diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 1fbf2b4..8e96a38 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -269,6 +269,13 @@ public actor DataFrame: Sendable {
     return DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols))
   }
 
+  /// Projects a set of expressions and returns a new ``DataFrame``.
+  /// - Parameter exprs: Expression strings
+  /// - Returns: A ``DataFrame`` with subset of columns.
+  public func selectExpr(_ exprs: String...) -> DataFrame {
+    return DataFrame(spark: self.spark, plan: SparkConnectClient.getProjectExprs(self.plan.root, exprs))
+  }
+
   /// Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain column name.
   /// - Parameter cols: Column names
   /// - Returns: A ``DataFrame`` with subset of columns.
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index fa7c392..00663b0 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -390,6 +390,22 @@ public actor SparkConnectClient {
     return plan
   }
 
+  static func getProjectExprs(_ child: Relation, _ exprs: [String]) -> Plan {
+    var project = Project()
+    project.input = child
+    let expressions: [Spark_Connect_Expression] = exprs.map {
+      var expression = Spark_Connect_Expression()
+      expression.exprType = .expressionString($0.toExpressionString)
+      return expression
+    }
+    project.expressions = expressions
+    var relation = Relation()
+    relation.project = project
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return plan
+  }
+
   static func getWithColumnRenamed(_ child: Relation, _ colsMap: [String: String]) -> Plan {
     var withColumnsRenamed = WithColumnsRenamed()
     withColumnsRenamed.input = child
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index e39e83b..327b009 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -208,6 +208,20 @@ struct DataFrameTests {
     try await #require(throws: Error.self) {
       let _ = try await spark.range(1).select("invalid").schema
     }
+    try await #require(throws: Error.self) {
+      let _ = try await spark.range(1).select("id + 1").schema
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func selectExpr() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let schema = try await spark.range(1).selectExpr("id + 1 as id2").schema
+    #expect(
+      schema
+        == #"{"struct":{"fields":[{"name":"id2","dataType":{"long":{}}}]}}"#
+    )
     await spark.stop()
   }
 
diff --git a/Tests/SparkConnectTests/SparkConnectClientTests.swift b/Tests/SparkConnectTests/SparkConnectClientTests.swift
index d5a824b..fe972d0 100644
--- a/Tests/SparkConnectTests/SparkConnectClientTests.swift
+++ b/Tests/SparkConnectTests/SparkConnectClientTests.swift
@@ -84,6 +84,7 @@ struct SparkConnectClientTests {
     await client.stop()
   }
 
+#if !os(Linux) // TODO: Enable this with the official Spark 4 docker image
   @Test
   func jsonToDdl() async throws {
     let client = SparkConnectClient(remote: TEST_REMOTE)
@@ -95,4 +96,5 @@ struct SparkConnectClientTests {
     }
     await client.stop()
   }
+#endif
 }
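
Usage sketch (not part of the patch): a minimal example of how the new selectExpr API could be driven end to end. It only uses calls that already appear in the tests above (SparkSession.builder.getOrCreate(), range, selectExpr, schema, stop); the @main wrapper, the second expression string, and the assumption that a Spark Connect server is reachable via the builder's default remote are illustrative, not taken from this change.

import SparkConnect

// Minimal usage sketch. Assumption: a Spark Connect server is reachable via
// whatever remote SparkSession.builder resolves by default.
@main
struct SelectExprExample {
  static func main() async throws {
    let spark = try await SparkSession.builder.getOrCreate()

    // Project SQL expression strings, mirroring the new selectExpr test above.
    let df = try await spark.range(5).selectExpr("id + 1 as id2", "id * 2 as doubled")

    // The schema string should list the aliased columns id2 and doubled (both long).
    print(try await df.schema)

    await spark.stop()
  }
}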