14 changes: 12 additions & 2 deletions Sources/SparkConnect/DataFrame.swift
@@ -43,9 +43,19 @@ public actor DataFrame: Sendable {
   /// - Parameters:
   ///   - spark: A `SparkSession` instance to use.
   ///   - sqlText: A SQL statement.
-  init(spark: SparkSession, sqlText: String) async throws {
+  ///   - posArgs: An array of strings.
+  init(spark: SparkSession, sqlText: String, _ posArgs: [Sendable]? = nil) async throws {
     self.spark = spark
-    self.plan = sqlText.toSparkConnectPlan
+    if let posArgs {
+      self.plan = sqlText.toSparkConnectPlan(posArgs)
+    } else {
+      self.plan = sqlText.toSparkConnectPlan
+    }
   }
 
+  init(spark: SparkSession, sqlText: String, _ args: [String: Sendable]) async throws {
+    self.spark = spark
+    self.plan = sqlText.toSparkConnectPlan(args)
+  }
+
   public func getPlan() -> Sendable {
68 changes: 68 additions & 0 deletions Sources/SparkConnect/Extension.swift
@@ -31,6 +31,74 @@ extension String {
     return plan
   }
 
+  func toSparkConnectPlan(_ posArguments: [Sendable]) -> Plan {
+    var sql = Spark_Connect_SQL()
+    sql.query = self
+    sql.posArguments = posArguments.map {
+      var literal = ExpressionLiteral()
+      switch $0 {
+      case let value as Bool:
+        literal.boolean = value
+      case let value as Int8:
+        literal.byte = Int32(value)
+      case let value as Int16:
+        literal.short = Int32(value)
+      case let value as Int32:
+        literal.integer = value
+      case let value as Int64:
+        literal.long = value
+      case let value as Int:
+        literal.long = Int64(value)
+      case let value as String:
+        literal.string = value
+      default:
+        literal.string = $0 as! String
+      }
+      var expr = Spark_Connect_Expression()
+      expr.literal = literal
+      return expr
+    }
+    var relation = Relation()
+    relation.sql = sql
+    var plan = Plan()
+    plan.opType = Plan.OneOf_OpType.root(relation)
+    return plan
+  }
+
+  func toSparkConnectPlan(_ namedArguments: [String: Sendable]) -> Plan {
+    var sql = Spark_Connect_SQL()
+    sql.query = self
+    sql.namedArguments = namedArguments.mapValues { value in
+      var literal = ExpressionLiteral()
+      switch value {
+      case let value as Bool:
+        literal.boolean = value
+      case let value as Int8:
+        literal.byte = Int32(value)
+      case let value as Int16:
+        literal.short = Int32(value)
+      case let value as Int32:
+        literal.integer = value
+      case let value as Int64:
+        literal.long = value
+      case let value as Int:
+        literal.long = Int64(value)
+      case let value as String:
+        literal.string = value
+      default:
+        literal.string = value as! String
+      }
+      var expr = Spark_Connect_Expression()
+      expr.literal = literal
+      return expr
+    }
+    var relation = Relation()
+    relation.sql = sql
+    var plan = Plan()
+    plan.opType = Plan.OneOf_OpType.root(relation)
+    return plan
+  }
+
   /// Get a `UserContext` instance from a string.
   var toUserContext: UserContext {
     var context = UserContext()
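Both new helpers repeat the same value-to-literal switch, and the `default` branch force-casts to `String`, which would trap at runtime for an unsupported type such as `Double`. As a purely illustrative sketch (not part of this change; the helper name and the `String(describing:)` fallback are assumptions), the conversion could be shared like this:

// Hypothetical shared helper (not in this commit): converts one Sendable value
// into a Spark Connect literal expression so both toSparkConnectPlan overloads
// could reuse it. Unsupported types fall back to their string description
// instead of a force cast.
private func toLiteralExpression(_ value: Sendable) -> Spark_Connect_Expression {
  var literal = ExpressionLiteral()
  switch value {
  case let value as Bool: literal.boolean = value
  case let value as Int8: literal.byte = Int32(value)
  case let value as Int16: literal.short = Int32(value)
  case let value as Int32: literal.integer = value
  case let value as Int64: literal.long = value
  case let value as Int: literal.long = Int64(value)
  case let value as String: literal.string = value
  default: literal.string = String(describing: value)
  }
  var expr = Spark_Connect_Expression()
  expr.literal = literal
  return expr
}

With such a helper, `posArguments.map { toLiteralExpression($0) }` and `namedArguments.mapValues { toLiteralExpression($0) }` would produce the same literals without duplicating the switch.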
20 changes: 20 additions & 0 deletions Sources/SparkConnect/SparkSession.swift
@@ -112,6 +112,26 @@ public actor SparkSession {
     return try await DataFrame(spark: self, sqlText: sqlText)
   }
 
+  /// Executes a SQL query substituting positional parameters by the given arguments, returning the
+  /// result as a `DataFrame`.
+  /// - Parameters:
+  ///   - sqlText: A SQL statement with positional parameters to execute.
+  ///   - args: An array of strings that can be converted to SQL literal expressions.

Inline review comment on the args doc line above:
Member: An array of strings? Do they have to be strings? Looks like they can be any Sendable?
Member Author: Oh, right. It's a typo.

+  /// - Returns: A ``DataFrame``.
+  public func sql(_ sqlText: String, _ args: Sendable...) async throws -> DataFrame {
+    return try await DataFrame(spark: self, sqlText: sqlText, args)
+  }
+
+  /// Executes a SQL query substituting named parameters by the given arguments, returning the
+  /// result as a `DataFrame`.
+  /// - Parameters:
+  ///   - sqlText: A SQL statement with named parameters to execute.
+  ///   - args: A dictionary with key string and values.
+  /// - Returns: A ``DataFrame``.
+  public func sql(_ sqlText: String, args: [String: Sendable]) async throws -> DataFrame {
+    return try await DataFrame(spark: self, sqlText: sqlText, args)
+  }
+
   /// Returns a ``DataFrameReader`` that can be used to read non-streaming data in as a
   /// `DataFrame`
   public var read: DataFrameReader {
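For context, a minimal usage sketch of the two new overloads (illustrative only; it assumes a reachable Spark Connect server and mirrors the parameter binding exercised in the test further below):

// Illustrative usage of the new parameterized sql() overloads; assumes a
// Spark Connect server is reachable via SparkSession.builder.getOrCreate().
func runParameterizedQueries() async throws {
  let spark = try await SparkSession.builder.getOrCreate()

  // Positional parameters: each Sendable argument binds to a `?` in order.
  let positional = try await spark.sql("SELECT ? + ?, ?", 1, 2, "label").collect()

  // Named parameters: dictionary keys bind to `:name` placeholders.
  let named = try await spark.sql("SELECT :x + :y, :z", args: ["x": 1, "y": 2, "z": "label"]).collect()

  print(positional, named)
  await spark.stop()
}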
1 change: 1 addition & 0 deletions Sources/SparkConnect/TypeAliases.swift
@@ -28,6 +28,7 @@ typealias Drop = Spark_Connect_Drop
 typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
 typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
 typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
+typealias ExpressionLiteral = Spark_Connect_Expression.Literal
 typealias ExpressionString = Spark_Connect_Expression.ExpressionString
 typealias Filter = Spark_Connect_Filter
 typealias GroupType = Spark_Connect_Aggregate.GroupType
13 changes: 13 additions & 0 deletions Tests/SparkConnectTests/SparkSessionTests.swift
@@ -76,6 +76,19 @@ struct SparkSessionTests {
     await spark.stop()
   }
 
+  #if !os(Linux)
+  @Test
+  func sql() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let expected = [Row(true, 1, "a")]
+    if await spark.version.starts(with: "4.") {
+      #expect(try await spark.sql("SELECT ?, ?, ?", true, 1, "a").collect() == expected)
+      #expect(try await spark.sql("SELECT :x, :y, :z", args: ["x": true, "y": 1, "z": "a"]).collect() == expected)
+    }
+    await spark.stop()
+  }
+  #endif
+
   @Test
   func table() async throws {
     let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")