diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index 81c92e4..8ee8f96 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -219,6 +219,20 @@ public actor DataFrame: Sendable { return DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols)) } + /// Return a new ``DataFrame`` with filtered rows using the given expression. + /// - Parameter conditionExpr: A string to filter. + /// - Returns: A ``DataFrame`` with subset of rows. + public func filter(_ conditionExpr: String) -> DataFrame { + return DataFrame(spark: self.spark, plan: SparkConnectClient.getFilter(self.plan.root, conditionExpr)) + } + + /// Return a new ``DataFrame`` with filtered rows using the given expression. + /// - Parameter conditionExpr: A string to filter. + /// - Returns: A ``DataFrame`` with subset of rows. + public func `where`(_ conditionExpr: String) -> DataFrame { + return filter(conditionExpr) + } + /// Return a new ``DataFrame`` sorted by the specified column(s). /// - Parameter cols: Column names. /// - Returns: A sorted ``DataFrame`` diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift index 3b2f839..848a96e 100644 --- a/Sources/SparkConnect/Extension.swift +++ b/Sources/SparkConnect/Extension.swift @@ -51,6 +51,12 @@ extension String { attribute.unparsedIdentifier = self return attribute } + + var toExpressionString: ExpressionString { + var expression = ExpressionString() + expression.expression = self + return expression + } } extension [String: String] { diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index aefd844..6b980cf 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -307,6 +307,17 @@ public actor SparkConnectClient { return plan } + static func getFilter(_ child: Relation, _ conditionExpr: String) -> Plan { + var filter = Filter() + filter.input = child + filter.condition.expressionString = conditionExpr.toExpressionString + var relation = Relation() + relation.filter = filter + var plan = Plan() + plan.opType = .root(relation) + return plan + } + static func getSort(_ child: Relation, _ cols: [String]) -> Plan { var sort = Sort() sort.input = child diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift index 92aa78e..2823e5f 100644 --- a/Sources/SparkConnect/TypeAliases.swift +++ b/Sources/SparkConnect/TypeAliases.swift @@ -21,11 +21,13 @@ typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse typealias ConfigRequest = Spark_Connect_ConfigRequest typealias DataType = Spark_Connect_DataType typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest -typealias Plan = Spark_Connect_Plan -typealias Project = Spark_Connect_Project +typealias ExpressionString = Spark_Connect_Expression.ExpressionString +typealias Filter = Spark_Connect_Filter typealias KeyValue = Spark_Connect_KeyValue typealias Limit = Spark_Connect_Limit typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze +typealias Plan = Spark_Connect_Plan +typealias Project = Spark_Connect_Project typealias Range = Spark_Connect_Range typealias Relation = Spark_Connect_Relation typealias SparkConnectService = Spark_Connect_SparkConnectService diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index 552374d..c7170d3 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -107,6 +107,20 @@ struct DataFrameTests { await spark.stop() } + @Test + func filter() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(2025).filter("id % 2 == 0").count() == 1013) + await spark.stop() + } + + @Test + func `where`() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(2025).where("id % 2 == 1").count() == 1012) + await spark.stop() + } + @Test func limit() async throws { let spark = try await SparkSession.builder.getOrCreate()