From 0090672c235fc11eedf36783056bd5d7d83ebec4 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Fri, 2 May 2025 10:33:34 -0700
Subject: [PATCH 1/2] [SPARK-51986] Support `Parameterized SQL queries` in `sql` API

---
 Sources/SparkConnect/DataFrame.swift          | 14 +++-
 Sources/SparkConnect/Extension.swift          | 68 +++++++++++++++++++
 Sources/SparkConnect/SparkSession.swift       | 20 ++++++
 Sources/SparkConnect/TypeAliases.swift        |  1 +
 .../SparkConnectTests/SparkSessionTests.swift | 13 ++++
 5 files changed, 114 insertions(+), 2 deletions(-)

diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 5531917..cbe4793 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -43,9 +43,19 @@ public actor DataFrame: Sendable {
   /// - Parameters:
   ///   - spark: A `SparkSession` instance to use.
   ///   - sqlText: A SQL statement.
-  init(spark: SparkSession, sqlText: String) async throws {
+  ///   - posArgs: An optional array of ``Sendable`` values that can be converted to SQL literal expressions.
+  init(spark: SparkSession, sqlText: String, _ posArgs: [Sendable]? = nil) async throws {
     self.spark = spark
-    self.plan = sqlText.toSparkConnectPlan
+    if let posArgs {
+      self.plan = sqlText.toSparkConnectPlan(posArgs)
+    } else {
+      self.plan = sqlText.toSparkConnectPlan
+    }
+  }
+
+  init(spark: SparkSession, sqlText: String, _ args: [String: Sendable]) async throws {
+    self.spark = spark
+    self.plan = sqlText.toSparkConnectPlan(args)
   }
 
   public func getPlan() -> Sendable {
diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift
index 5d75b3d..e841fa4 100644
--- a/Sources/SparkConnect/Extension.swift
+++ b/Sources/SparkConnect/Extension.swift
@@ -31,6 +31,74 @@ extension String {
     return plan
   }
 
+  func toSparkConnectPlan(_ posArguments: [Sendable]) -> Plan {
+    var sql = Spark_Connect_SQL()
+    sql.query = self
+    sql.posArguments = posArguments.map {
+      var literal = ExpressionLiteral()
+      switch $0 {
+      case let value as Bool:
+        literal.boolean = value
+      case let value as Int8:
+        literal.byte = Int32(value)
+      case let value as Int16:
+        literal.short = Int32(value)
+      case let value as Int32:
+        literal.integer = value
+      case let value as Int64:
+        literal.long = value
+      case let value as Int:
+        literal.long = Int64(value)
+      case let value as String:
+        literal.string = value
+      default:
+        literal.string = $0 as! String
+      }
+      var expr = Spark_Connect_Expression()
+      expr.literal = literal
+      return expr
+    }
+    var relation = Relation()
+    relation.sql = sql
+    var plan = Plan()
+    plan.opType = Plan.OneOf_OpType.root(relation)
+    return plan
+  }
+
+  func toSparkConnectPlan(_ namedArguments: [String: Sendable]) -> Plan {
+    var sql = Spark_Connect_SQL()
+    sql.query = self
+    sql.namedArguments = namedArguments.mapValues { value in
+      var literal = ExpressionLiteral()
+      switch value {
+      case let value as Bool:
+        literal.boolean = value
+      case let value as Int8:
+        literal.byte = Int32(value)
+      case let value as Int16:
+        literal.short = Int32(value)
+      case let value as Int32:
+        literal.integer = value
+      case let value as Int64:
+        literal.long = value
+      case let value as Int:
+        literal.long = Int64(value)
+      case let value as String:
+        literal.string = value
+      default:
+        literal.string = value as! String
+      }
+      var expr = Spark_Connect_Expression()
+      expr.literal = literal
+      return expr
+    }
+    var relation = Relation()
+    relation.sql = sql
+    var plan = Plan()
+    plan.opType = Plan.OneOf_OpType.root(relation)
+    return plan
+  }
+
   /// Get a `UserContext` instance from a string.
   var toUserContext: UserContext {
     var context = UserContext()
diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift
index b06370e..2ae4705 100644
--- a/Sources/SparkConnect/SparkSession.swift
+++ b/Sources/SparkConnect/SparkSession.swift
@@ -112,6 +112,26 @@ public actor SparkSession {
     return try await DataFrame(spark: self, sqlText: sqlText)
   }
 
+  /// Executes a SQL query substituting positional parameters by the given arguments, returning the
+  /// result as a `DataFrame`.
+  /// - Parameters:
+  ///   - sqlText: A SQL statement with positional parameters to execute.
+  ///   - args: An array of strings that can be converted to SQL literal expressions.
+  /// - Returns: A ``DataFrame``.
+  public func sql(_ sqlText: String, _ args: Sendable...) async throws -> DataFrame {
+    return try await DataFrame(spark: self, sqlText: sqlText, args)
+  }
+
+  /// Executes a SQL query substituting named parameters by the given arguments, returning the
+  /// result as a `DataFrame`.
+  /// - Parameters:
+  ///   - sqlText: A SQL statement with named parameters to execute.
+  ///   - args: A dictionary with key string and values.
+  /// - Returns: A ``DataFrame``.
+  public func sql(_ sqlText: String, args: [String: Sendable]) async throws -> DataFrame {
+    return try await DataFrame(spark: self, sqlText: sqlText, args)
+  }
+
   /// Returns a ``DataFrameReader`` that can be used to read non-streaming data in as a
   /// `DataFrame`
   public var read: DataFrameReader {
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index 60f0fb8..41547f8 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -28,6 +28,7 @@ typealias Drop = Spark_Connect_Drop
 typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
 typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
 typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
+typealias ExpressionLiteral = Spark_Connect_Expression.Literal
 typealias ExpressionString = Spark_Connect_Expression.ExpressionString
 typealias Filter = Spark_Connect_Filter
 typealias GroupType = Spark_Connect_Aggregate.GroupType
diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift
index 2bc887e..69f0aee 100644
--- a/Tests/SparkConnectTests/SparkSessionTests.swift
+++ b/Tests/SparkConnectTests/SparkSessionTests.swift
@@ -76,6 +76,19 @@ struct SparkSessionTests {
     await spark.stop()
   }
 
+#if !os(Linux)
+  @Test
+  func sql() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let expected = [Row(true, 1, "a")]
+    if await spark.version.starts(with: "4.") {
+      #expect(try await spark.sql("SELECT ?, ?, ?", true, 1, "a").collect() == expected)
+      #expect(try await spark.sql("SELECT :x, :y, :z", args: ["x": true, "y": 1, "z": "a"]).collect() == expected)
+    }
+    await spark.stop()
+  }
+#endif
+
   @Test
   func table() async throws {
     let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")

From ab8f43e33a1a8dc450a84b66a9f2bfed2f60a760 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Fri, 2 May 2025 14:34:41 -0700
Subject: [PATCH 2/2] Address comment

---
 Sources/SparkConnect/SparkSession.swift | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift
index 2ae4705..ebaf190 100644
--- a/Sources/SparkConnect/SparkSession.swift
+++ b/Sources/SparkConnect/SparkSession.swift
@@ -116,7 +116,7 @@ public actor SparkSession {
   /// result as a `DataFrame`.
   /// - Parameters:
   ///   - sqlText: A SQL statement with positional parameters to execute.
-  ///   - args: An array of strings that can be converted to SQL literal expressions.
+  ///   - args: ``Sendable`` values that can be converted to SQL literal expressions.
   /// - Returns: A ``DataFrame``.
   public func sql(_ sqlText: String, _ args: Sendable...) async throws -> DataFrame {
     return try await DataFrame(spark: self, sqlText: sqlText, args)
@@ -126,7 +126,7 @@ public actor SparkSession {
   /// result as a `DataFrame`.
   /// - Parameters:
   ///   - sqlText: A SQL statement with named parameters to execute.
-  ///   - args: A dictionary with key string and values.
+  ///   - args: A dictionary of parameter names and ``Sendable`` values that can be converted to SQL literal expressions.
   /// - Returns: A ``DataFrame``.
   public func sql(_ sqlText: String, args: [String: Sendable]) async throws -> DataFrame {
     return try await DataFrame(spark: self, sqlText: sqlText, args)
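
Example usage of the two `sql` overloads added above (a minimal sketch, not part of the
patch: `runParameterizedQueries` is a hypothetical helper name, and the snippet assumes a
reachable Spark Connect server; per the new test, parameterized SQL is only exercised
against Spark 4.x servers):

  import SparkConnect

  func runParameterizedQueries() async throws {
    let spark = try await SparkSession.builder.getOrCreate()

    // Positional parameters: each `?` is bound, in order, to one of the trailing
    // `Sendable` arguments (Bool, Int8/16/32/64, Int, and String become Spark literals).
    let byPosition = try await spark.sql("SELECT ?, ?, ?", true, 1, "a")
    let positionalRows = try await byPosition.collect()  // expected: one row (true, 1, "a")

    // Named parameters: each `:name` placeholder is resolved from the dictionary
    // passed via the `args:` label.
    let byName = try await spark.sql("SELECT :x, :y, :z", args: ["x": true, "y": 1, "z": "a"])
    let namedRows = try await byName.collect()  // expected: one row (true, 1, "a")

    print(positionalRows, namedRows)
    await spark.stop()
  }

Note that `toSparkConnectPlan` maps only Bool, the fixed-width integers, Int, and String to
literals; any other `Sendable` argument falls through to the `as! String` force cast in the
default case and will trap at runtime.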