Skip to content

Commit 6d81f69

Browse files
committed
[SPARK-51570] Support filter/where for DataFrame
1 parent b70eeb7 commit 6d81f69

File tree

5 files changed

+49
-2
lines changed

5 files changed

+49
-2
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,20 @@ public actor DataFrame: Sendable {
219219
return DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols))
220220
}
221221

222+
/// Return a new ``DataFrame`` with filtered rows using the given expression.
223+
/// - Parameter conditionExpr: A string to filter.
224+
/// - Returns: A sorted ``DataFrame``
225+
public func filter(_ conditionExpr: String) -> DataFrame {
226+
return DataFrame(spark: self.spark, plan: SparkConnectClient.getFilter(self.plan.root, conditionExpr))
227+
}
228+
229+
/// Return a new ``DataFrame`` with filtered rows using the given expression.
230+
/// - Parameter conditionExpr: A string to filter.
231+
/// - Returns: A sorted ``DataFrame``
232+
public func `where`(_ conditionExpr: String) -> DataFrame {
233+
return filter(conditionExpr)
234+
}
235+
222236
/// Return a new ``DataFrame`` sorted by the specified column(s).
223237
/// - Parameter cols: Column names.
224238
/// - Returns: A sorted ``DataFrame``

Sources/SparkConnect/Extension.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ extension String {
5151
attribute.unparsedIdentifier = self
5252
return attribute
5353
}
54+
55+
var toExpressionString: ExpressionString {
56+
var expression = ExpressionString()
57+
expression.expression = self
58+
return expression
59+
}
5460
}
5561

5662
extension [String: String] {

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,17 @@ public actor SparkConnectClient {
307307
return plan
308308
}
309309

310+
static func getFilter(_ child: Relation, _ conditionExpr: String) -> Plan {
311+
var filter = Filter()
312+
filter.input = child
313+
filter.condition.expressionString = conditionExpr.toExpressionString
314+
var relation = Relation()
315+
relation.filter = filter
316+
var plan = Plan()
317+
plan.opType = .root(relation)
318+
return plan
319+
}
320+
310321
static func getSort(_ child: Relation, _ cols: [String]) -> Plan {
311322
var sort = Sort()
312323
sort.input = child

Sources/SparkConnect/TypeAliases.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
2121
typealias ConfigRequest = Spark_Connect_ConfigRequest
2222
typealias DataType = Spark_Connect_DataType
2323
typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
24-
typealias Plan = Spark_Connect_Plan
25-
typealias Project = Spark_Connect_Project
24+
typealias ExpressionString = Spark_Connect_Expression.ExpressionString
25+
typealias Filter = Spark_Connect_Filter
2626
typealias KeyValue = Spark_Connect_KeyValue
2727
typealias Limit = Spark_Connect_Limit
2828
typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
29+
typealias Plan = Spark_Connect_Plan
30+
typealias Project = Spark_Connect_Project
2931
typealias Range = Spark_Connect_Range
3032
typealias Relation = Spark_Connect_Relation
3133
typealias SparkConnectService = Spark_Connect_SparkConnectService

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,20 @@ struct DataFrameTests {
107107
await spark.stop()
108108
}
109109

110+
@Test
111+
func filter() async throws {
112+
let spark = try await SparkSession.builder.getOrCreate()
113+
#expect(try await spark.range(2025).filter("id % 2 == 0").count() == 1013)
114+
await spark.stop()
115+
}
116+
117+
@Test
118+
func `where`() async throws {
119+
let spark = try await SparkSession.builder.getOrCreate()
120+
#expect(try await spark.range(2025).where("id % 2 == 1").count() == 1012)
121+
await spark.stop()
122+
}
123+
110124
@Test
111125
func limit() async throws {
112126
let spark = try await SparkSession.builder.getOrCreate()

0 commit comments

Comments
 (0)