Skip to content

Commit 8f10457

Browse files
committed
Add sort, orderBy
1 parent 70f7280 commit 8f10457

File tree

4 files changed

+41
-4
lines changed

4 files changed

+41
-4
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,18 @@ public actor DataFrame: Sendable {
194194
}
195195

196196
/// Returns a new `DataFrame` whose plan is produced by `SparkConnectClient.getProject`
/// from this frame's plan root and the given column names.
///
/// - Parameter cols: Column names forwarded unchanged to `getProject`.
/// - Returns: A `DataFrame` sharing this frame's `spark` session and holding the new plan.
public func select(_ cols: String...) -> DataFrame {
197-
let plan = SparkConnectClient.getProject(self.plan.root, cols)
198-
return DataFrame(spark: self.spark, plan: plan)
197+
return DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols))
198+
}
199+
200+
/// Returns a new `DataFrame` whose plan sorts this frame by the given column names.
///
/// The sort relation itself is assembled by `SparkConnectClient.getSort`.
///
/// - Parameter cols: Column names to sort by, forwarded unchanged to `getSort`.
/// - Returns: A `DataFrame` sharing this frame's `spark` session and holding the sort plan.
public func sort(_ cols: String...) -> DataFrame {
  let sortPlan = SparkConnectClient.getSort(self.plan.root, cols)
  return DataFrame(spark: self.spark, plan: sortPlan)
}
203+
204+
/// Returns a new `DataFrame` whose plan orders this frame by the given column names.
///
/// Builds its plan with the same `SparkConnectClient.getSort` call as `sort(_:)`.
///
/// - Parameter cols: Column names to order by, forwarded unchanged to `getSort`.
/// - Returns: A `DataFrame` sharing this frame's `spark` session and holding the sort plan.
public func orderBy(_ cols: String...) -> DataFrame {
  let orderedPlan = SparkConnectClient.getSort(self.plan.root, cols)
  return DataFrame(spark: self.spark, plan: orderedPlan)
}
200207

201208
/// Returns a new `DataFrame` whose plan is produced by `SparkConnectClient.getLimit`
/// from this frame's plan root and `n`.
///
/// - Parameter n: The limit value forwarded unchanged to `getLimit`.
/// - Returns: A `DataFrame` sharing this frame's `spark` session and holding the new plan.
public func limit(_ n: Int32) -> DataFrame {
202-
let plan = SparkConnectClient.getLimit(self.plan.root, n)
203-
return DataFrame(spark: self.spark, plan: plan)
209+
return DataFrame(spark: self.spark, plan: SparkConnectClient.getLimit(self.plan.root, n))
204210
}
205211
}

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,22 @@ public actor SparkConnectClient {
269269
return plan
270270
}
271271

272+
/// Builds a `Plan` whose root relation is a global, ascending sort of `child`.
///
/// - Parameters:
///   - child: The input relation to sort.
///   - cols: Column names to sort by; each becomes one `SortOrder` entry, in the given order.
/// - Returns: A `Plan` whose `opType` is `.root` wrapping the sort relation.
static func getSort(_ child: Relation, _ cols: [String]) -> Plan {
  var sort = Sort()
  sort.input = child
  // Ask for a total ordering explicitly; an unset `is_global` leaves the choice between a
  // global sort and a per-partition sort to the server's default.
  sort.isGlobal = true
  sort.order = cols.map { name in
    var order = Spark_Connect_Expression.SortOrder()
    order.child.exprType = .unresolvedAttribute(name.toUnresolvedAttribute)
    // Spark's `sort`/`orderBy` on plain column names means ascending with nulls first.
    // The protobuf defaults are the `*_UNSPECIFIED` values, so state both explicitly
    // instead of depending on how the server interprets "unspecified".
    // NOTE(review): confirm the generated case names against the project's `.pb.swift`.
    order.direction = .ascending
    order.nullOrdering = .sortNullsFirst
    return order
  }
  var relation = Relation()
  relation.sort = sort
  var plan = Plan()
  plan.opType = .root(relation)
  return plan
}
287+
272288
static func getLimit(_ child: Relation, _ n: Int32) -> Plan {
273289
var limit = Limit()
274290
limit.input = child

Sources/SparkConnect/TypeAliases.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,6 @@ typealias Limit = Spark_Connect_Limit
2828
typealias Range = Spark_Connect_Range
2929
typealias Relation = Spark_Connect_Relation
3030
typealias SparkConnectService = Spark_Connect_SparkConnectService
31+
typealias Sort = Spark_Connect_Sort
3132
typealias UserContext = Spark_Connect_UserContext
3233
typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,20 @@ struct DataFrameTests {
116116
await spark.stop()
117117
}
118118

119+
// Sorts `range(10)` by "id" and expects the resulting row count to be 10.
// NOTE(review): only the count is asserted; the order of the returned rows is not checked.
@Test
120+
func sort() async throws {
121+
let spark = try await SparkSession.builder.getOrCreate()
122+
#expect(try await spark.range(10).sort("id").count() == 10)
123+
await spark.stop()
124+
}
125+
126+
// Orders `range(10)` by "id" and expects the resulting row count to be 10.
// NOTE(review): only the count is asserted; the order of the returned rows is not checked.
@Test
127+
func orderBy() async throws {
128+
let spark = try await SparkSession.builder.getOrCreate()
129+
#expect(try await spark.range(10).orderBy("id").count() == 10)
130+
await spark.stop()
131+
}
132+
119133
@Test
120134
func table() async throws {
121135
let spark = try await SparkSession.builder.getOrCreate()

0 commit comments

Comments
 (0)