Skip to content

Commit bb8b9fb

Browse files
committed
[SPARK-51986] Support Parameterized SQL queries in sql API
### What changes were proposed in this pull request? This PR aims to support `Parameterized SQL queries` in the `sql` API. ### Why are the changes needed? For feature parity, we should support this GA feature. - apache/spark#38864 (Since Spark 3.4.0) - apache/spark#40623 (Since Spark 3.4.0) - apache/spark#41568 (Since Spark 3.5.0) - apache/spark#48965 (GA Since Spark 4.0.0) ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Passes the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #103 from dongjoon-hyun/SPARK-51986. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 1ab93ea commit bb8b9fb

File tree

5 files changed

+114
-2
lines changed

5 files changed

+114
-2
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,19 @@ public actor DataFrame: Sendable {
4343
/// Create a ``DataFrame`` from a SQL statement, optionally substituting
/// positional (`?`) parameters by the given arguments.
/// - Parameters:
///   - spark: A `SparkSession` instance to use.
///   - sqlText: A SQL statement.
///   - posArgs: Optional positional arguments (`Bool`, integer, or `String`
///     values — not only strings) that replace `?` placeholders in `sqlText`.
init(spark: SparkSession, sqlText: String, _ posArgs: [Sendable]? = nil) async throws {
  self.spark = spark
  // Preserve the original non-parameterized plan path when no arguments are given.
  if let posArgs {
    self.plan = sqlText.toSparkConnectPlan(posArgs)
  } else {
    self.plan = sqlText.toSparkConnectPlan
  }
}

/// Create a ``DataFrame`` from a SQL statement with named (`:name`) parameters.
/// - Parameters:
///   - spark: A `SparkSession` instance to use.
///   - sqlText: A SQL statement with named parameters.
///   - args: A dictionary mapping parameter names to ``Sendable`` values.
init(spark: SparkSession, sqlText: String, _ args: [String: Sendable]) async throws {
  self.spark = spark
  self.plan = sqlText.toSparkConnectPlan(args)
}
5060

5161
public func getPlan() -> Sendable {

Sources/SparkConnect/Extension.swift

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,74 @@ extension String {
3131
return plan
3232
}
3333

34+
/// Create a Spark Connect `Plan` for this SQL statement with positional parameters.
/// - Parameter posArguments: Values substituted for `?` placeholders, converted to
///   SQL literal expressions. Supported types: `Bool`, `Int8`, `Int16`, `Int32`,
///   `Int64`, `Int`, and `String`; any other type falls back to its textual form.
/// - Returns: A `Plan` whose root relation is the parameterized SQL query.
func toSparkConnectPlan(_ posArguments: [Sendable]) -> Plan {
  var sql = Spark_Connect_SQL()
  sql.query = self
  sql.posArguments = posArguments.map { value in
    var literal = ExpressionLiteral()
    switch value {
    case let value as Bool:
      literal.boolean = value
    case let value as Int8:
      literal.byte = Int32(value)
    case let value as Int16:
      literal.short = Int32(value)
    case let value as Int32:
      literal.integer = value
    case let value as Int64:
      literal.long = value
    case let value as Int:
      literal.long = Int64(value)
    case let value as String:
      literal.string = value
    default:
      // Previously `value as! String`, which crashes at runtime for any
      // unsupported Sendable (e.g. Double). Fall back to a textual
      // representation instead of trapping.
      literal.string = String(describing: value)
    }
    var expr = Spark_Connect_Expression()
    expr.literal = literal
    return expr
  }
  var relation = Relation()
  relation.sql = sql
  var plan = Plan()
  plan.opType = Plan.OneOf_OpType.root(relation)
  return plan
}
67+
68+
/// Create a Spark Connect `Plan` for this SQL statement with named parameters.
/// - Parameter namedArguments: A dictionary mapping parameter names (used as
///   `:name` in the query) to values converted to SQL literal expressions.
///   Supported types: `Bool`, `Int8`, `Int16`, `Int32`, `Int64`, `Int`, and
///   `String`; any other type falls back to its textual form.
/// - Returns: A `Plan` whose root relation is the parameterized SQL query.
func toSparkConnectPlan(_ namedArguments: [String: Sendable]) -> Plan {
  var sql = Spark_Connect_SQL()
  sql.query = self
  sql.namedArguments = namedArguments.mapValues { value in
    var literal = ExpressionLiteral()
    switch value {
    case let value as Bool:
      literal.boolean = value
    case let value as Int8:
      literal.byte = Int32(value)
    case let value as Int16:
      literal.short = Int32(value)
    case let value as Int32:
      literal.integer = value
    case let value as Int64:
      literal.long = value
    case let value as Int:
      literal.long = Int64(value)
    case let value as String:
      literal.string = value
    default:
      // Previously `value as! String`, which crashes at runtime for any
      // unsupported Sendable (e.g. Double). Fall back to a textual
      // representation instead of trapping.
      literal.string = String(describing: value)
    }
    var expr = Spark_Connect_Expression()
    expr.literal = literal
    return expr
  }
  var relation = Relation()
  relation.sql = sql
  var plan = Plan()
  plan.opType = Plan.OneOf_OpType.root(relation)
  return plan
}
101+
34102
/// Get a `UserContext` instance from a string.
35103
var toUserContext: UserContext {
36104
var context = UserContext()

Sources/SparkConnect/SparkSession.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,26 @@ public actor SparkSession {
112112
return try await DataFrame(spark: self, sqlText: sqlText)
113113
}
114114

115+
/// Run a SQL statement whose positional (`?`) placeholders are bound, in order,
/// to the supplied arguments, returning the result as a `DataFrame`.
/// - Parameters:
///   - sqlText: A SQL statement with positional parameters to execute.
///   - args: ``Sendable`` values that can be converted to SQL literal expressions.
/// - Returns: A ``DataFrame``.
public func sql(_ sqlText: String, _ args: Sendable...) async throws -> DataFrame {
  let dataFrame = try await DataFrame(spark: self, sqlText: sqlText, args)
  return dataFrame
}

/// Run a SQL statement whose named (`:name`) parameters are bound to the
/// supplied dictionary values, returning the result as a `DataFrame`.
/// - Parameters:
///   - sqlText: A SQL statement with named parameters to execute.
///   - args: A dictionary with key string and ``Sendable`` value.
/// - Returns: A ``DataFrame``.
public func sql(_ sqlText: String, args: [String: Sendable]) async throws -> DataFrame {
  let dataFrame = try await DataFrame(spark: self, sqlText: sqlText, args)
  return dataFrame
}
134+
115135
/// Returns a ``DataFrameReader`` that can be used to read non-streaming data in as a
116136
/// `DataFrame`
117137
public var read: DataFrameReader {

Sources/SparkConnect/TypeAliases.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ typealias Drop = Spark_Connect_Drop
2828
typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
// Shorthand for the protobuf literal expression type used for SQL parameter binding.
typealias ExpressionLiteral = Spark_Connect_Expression.Literal
typealias ExpressionString = Spark_Connect_Expression.ExpressionString
typealias Filter = Spark_Connect_Filter
typealias GroupType = Spark_Connect_Aggregate.GroupType

Tests/SparkConnectTests/SparkSessionTests.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,19 @@ struct SparkSessionTests {
7676
await spark.stop()
7777
}
7878

79+
#if !os(Linux)
/// Verify parameterized SQL with both positional (`?`) and named (`:name`) arguments.
@Test
func sql() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let expected = [Row(true, 1, "a")]
  // Parameterized SQL is GA only since Spark 4.0.0, so gate on the server version.
  if await spark.version.starts(with: "4.") {
    let positional = try await spark.sql("SELECT ?, ?, ?", true, 1, "a").collect()
    #expect(positional == expected)
    let named = try await spark.sql("SELECT :x, :y, :z", args: ["x": true, "y": 1, "z": "a"]).collect()
    #expect(named == expected)
  }
  await spark.stop()
}
#endif
91+
7992
@Test
8093
func table() async throws {
8194
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")

0 commit comments

Comments
 (0)