Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions Sources/SparkConnect/DataFrame.swift
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,20 @@ public actor DataFrame: Sendable {
return DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols))
}

/// Return a new ``DataFrame`` with filtered rows using the given expression.
/// - Parameter conditionExpr: A string to filter.
/// - Returns: A ``DataFrame`` with subset of rows.
public func filter(_ conditionExpr: String) -> DataFrame {
return DataFrame(spark: self.spark, plan: SparkConnectClient.getFilter(self.plan.root, conditionExpr))
}

/// Return a new ``DataFrame`` with filtered rows using the given expression.
/// - Parameter conditionExpr: A string to filter.
/// - Returns: A ``DataFrame`` with subset of rows.
public func `where`(_ conditionExpr: String) -> DataFrame {
return filter(conditionExpr)
}

/// Return a new ``DataFrame`` sorted by the specified column(s).
/// - Parameter cols: Column names.
/// - Returns: A sorted ``DataFrame``
Expand Down
6 changes: 6 additions & 0 deletions Sources/SparkConnect/Extension.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ extension String {
attribute.unparsedIdentifier = self
return attribute
}

var toExpressionString: ExpressionString {
var expression = ExpressionString()
expression.expression = self
return expression
}
}

extension [String: String] {
Expand Down
11 changes: 11 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,17 @@ public actor SparkConnectClient {
return plan
}

static func getFilter(_ child: Relation, _ conditionExpr: String) -> Plan {
var filter = Filter()
filter.input = child
filter.condition.expressionString = conditionExpr.toExpressionString
var relation = Relation()
relation.filter = filter
var plan = Plan()
plan.opType = .root(relation)
return plan
}

static func getSort(_ child: Relation, _ cols: [String]) -> Plan {
var sort = Sort()
sort.input = child
Expand Down
6 changes: 4 additions & 2 deletions Sources/SparkConnect/TypeAliases.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
typealias ConfigRequest = Spark_Connect_ConfigRequest
typealias DataType = Spark_Connect_DataType
typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
typealias Plan = Spark_Connect_Plan
typealias Project = Spark_Connect_Project
typealias ExpressionString = Spark_Connect_Expression.ExpressionString
typealias Filter = Spark_Connect_Filter
typealias KeyValue = Spark_Connect_KeyValue
typealias Limit = Spark_Connect_Limit
typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
typealias Plan = Spark_Connect_Plan
typealias Project = Spark_Connect_Project
typealias Range = Spark_Connect_Range
typealias Relation = Spark_Connect_Relation
typealias SparkConnectService = Spark_Connect_SparkConnectService
Expand Down
14 changes: 14 additions & 0 deletions Tests/SparkConnectTests/DataFrameTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ struct DataFrameTests {
await spark.stop()
}

@Test
func filter() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.range(2025).filter("id % 2 == 0").count() == 1013)
await spark.stop()
}

@Test
func `where`() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.range(2025).where("id % 2 == 1").count() == 1012)
await spark.stop()
}

@Test
func limit() async throws {
let spark = try await SparkSession.builder.getOrCreate()
Expand Down
Loading