Skip to content

Commit 672ce65

Browse files
committed
[SPARK-51570] Support filter/where for DataFrame
### What changes were proposed in this pull request? This PR aims o support `filter` and `where` for `DataFrame`. ### Why are the changes needed? For feature parity. ### Does this PR introduce _any_ user-facing change? No, this is a new addition. ### How was this patch tested? Pass the CIs and manual test on MacOS with Apache Spark 4.0.0 RC3. ``` $ sbin/start-connect-server.sh ``` ``` $ swift test --filter DataFrameTests ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #25 from dongjoon-hyun/SPARK-51570. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent bebd5e1 commit 672ce65

File tree

5 files changed

+49
-2
lines changed

5 files changed

+49
-2
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,20 @@ public actor DataFrame: Sendable {
225225
return DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols))
226226
}
227227

228+
/// Return a new ``DataFrame`` with filtered rows using the given expression.
229+
/// - Parameter conditionExpr: A string to filter.
230+
/// - Returns: A ``DataFrame`` with subset of rows.
231+
public func filter(_ conditionExpr: String) -> DataFrame {
232+
return DataFrame(spark: self.spark, plan: SparkConnectClient.getFilter(self.plan.root, conditionExpr))
233+
}
234+
235+
/// Return a new ``DataFrame`` with filtered rows using the given expression.
236+
/// - Parameter conditionExpr: A string to filter.
237+
/// - Returns: A ``DataFrame`` with subset of rows.
238+
public func `where`(_ conditionExpr: String) -> DataFrame {
239+
return filter(conditionExpr)
240+
}
241+
228242
/// Return a new ``DataFrame`` sorted by the specified column(s).
229243
/// - Parameter cols: Column names.
230244
/// - Returns: A sorted ``DataFrame``

Sources/SparkConnect/Extension.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ extension String {
5151
attribute.unparsedIdentifier = self
5252
return attribute
5353
}
54+
55+
var toExpressionString: ExpressionString {
56+
var expression = ExpressionString()
57+
expression.expression = self
58+
return expression
59+
}
5460
}
5561

5662
extension [String: String] {

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,17 @@ public actor SparkConnectClient {
307307
return plan
308308
}
309309

310+
static func getFilter(_ child: Relation, _ conditionExpr: String) -> Plan {
311+
var filter = Filter()
312+
filter.input = child
313+
filter.condition.expressionString = conditionExpr.toExpressionString
314+
var relation = Relation()
315+
relation.filter = filter
316+
var plan = Plan()
317+
plan.opType = .root(relation)
318+
return plan
319+
}
320+
310321
static func getSort(_ child: Relation, _ cols: [String]) -> Plan {
311322
var sort = Sort()
312323
sort.input = child

Sources/SparkConnect/TypeAliases.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
2121
typealias ConfigRequest = Spark_Connect_ConfigRequest
2222
typealias DataType = Spark_Connect_DataType
2323
typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
24-
typealias Plan = Spark_Connect_Plan
25-
typealias Project = Spark_Connect_Project
24+
typealias ExpressionString = Spark_Connect_Expression.ExpressionString
25+
typealias Filter = Spark_Connect_Filter
2626
typealias KeyValue = Spark_Connect_KeyValue
2727
typealias Limit = Spark_Connect_Limit
2828
typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
29+
typealias Plan = Spark_Connect_Plan
30+
typealias Project = Spark_Connect_Project
2931
typealias Range = Spark_Connect_Range
3032
typealias Relation = Spark_Connect_Relation
3133
typealias SparkConnectService = Spark_Connect_SparkConnectService

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,20 @@ struct DataFrameTests {
107107
await spark.stop()
108108
}
109109

110+
@Test
111+
func filter() async throws {
112+
let spark = try await SparkSession.builder.getOrCreate()
113+
#expect(try await spark.range(2025).filter("id % 2 == 0").count() == 1013)
114+
await spark.stop()
115+
}
116+
117+
@Test
118+
func `where`() async throws {
119+
let spark = try await SparkSession.builder.getOrCreate()
120+
#expect(try await spark.range(2025).where("id % 2 == 1").count() == 1012)
121+
await spark.stop()
122+
}
123+
110124
@Test
111125
func limit() async throws {
112126
let spark = try await SparkSession.builder.getOrCreate()

0 commit comments

Comments
 (0)