Skip to content

Commit 68067bf

Browse files
committed
[SPARK-51504] Support select/limit/sort/orderBy/isEmpty for DataFrame
### What changes were proposed in this pull request?

This PR aims to support `select`, `limit`, `sort`, `orderBy`, and `isEmpty` of the `DataFrame` API.

### Why are the changes needed?

To support the `select`/`limit`/`sort`/`orderBy`/`isEmpty` API.

### Does this PR introduce _any_ user-facing change?

No. This is a new addition.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #16 from dongjoon-hyun/SPARK-51504.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent af298fc commit 68067bf

File tree

5 files changed

+159
-1
lines changed

5 files changed

+159
-1
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public actor DataFrame: Sendable {
3636
/// - Parameters:
3737
/// - spark: A ``SparkSession`` instance to use.
3838
/// - plan: A plan to execute.
39-
init(spark: SparkSession, plan: Plan) async throws {
39+
init(spark: SparkSession, plan: Plan) {
4040
self.spark = spark
4141
self.plan = plan
4242
}
@@ -192,4 +192,38 @@ public actor DataFrame: Sendable {
192192
print(table.render())
193193
}
194194
}
195+
196+
/// Projects a set of expressions and returns a new ``DataFrame``.
197+
/// - Parameter cols: Column names
198+
/// - Returns: A ``DataFrame`` with subset of columns.
199+
public func select(_ cols: String...) -> DataFrame {
200+
return DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols))
201+
}
202+
203+
/// Return a new ``DataFrame`` sorted by the specified column(s).
204+
/// - Parameter cols: Column names.
205+
/// - Returns: A sorted ``DataFrame``
206+
public func sort(_ cols: String...) -> DataFrame {
207+
return DataFrame(spark: self.spark, plan: SparkConnectClient.getSort(self.plan.root, cols))
208+
}
209+
210+
/// Return a new ``DataFrame`` sorted by the specified column(s).
211+
/// - Parameter cols: Column names.
212+
/// - Returns: A sorted ``DataFrame``
213+
public func orderBy(_ cols: String...) -> DataFrame {
214+
return DataFrame(spark: self.spark, plan: SparkConnectClient.getSort(self.plan.root, cols))
215+
}
216+
217+
/// Limits the result count to the number specified.
218+
/// - Parameter n: Number of records to return. Will return this number of records or all records if the ``DataFrame`` contains less than this number of records.
219+
/// - Returns: A subset of the records
220+
public func limit(_ n: Int32) -> DataFrame {
221+
return DataFrame(spark: self.spark, plan: SparkConnectClient.getLimit(self.plan.root, n))
222+
}
223+
224+
/// Checks if the ``DataFrame`` is empty and returns a boolean value.
225+
/// - Returns: `true` if the ``DataFrame`` is empty, `false` otherwise.
226+
public func isEmpty() async throws -> Bool {
227+
return try await select().limit(1).count() == 0
228+
}
195229
}

Sources/SparkConnect/Extension.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ extension String {
4545
keyValue.key = self
4646
return keyValue
4747
}
48+
49+
var toUnresolvedAttribute: UnresolvedAttribute {
50+
var attribute = UnresolvedAttribute()
51+
attribute.unparsedIdentifier = self
52+
return attribute
53+
}
4854
}
4955

5056
extension [String: String] {

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,4 +252,47 @@ public actor SparkConnectClient {
252252
request.analyze = .schema(schema)
253253
return request
254254
}
255+
256+
static func getProject(_ child: Relation, _ cols: [String]) -> Plan {
257+
var project = Project()
258+
project.input = child
259+
let expressions: [Spark_Connect_Expression] = cols.map {
260+
var expression = Spark_Connect_Expression()
261+
expression.exprType = .unresolvedAttribute($0.toUnresolvedAttribute)
262+
return expression
263+
}
264+
project.expressions = expressions
265+
var relation = Relation()
266+
relation.project = project
267+
var plan = Plan()
268+
plan.opType = .root(relation)
269+
return plan
270+
}
271+
272+
static func getSort(_ child: Relation, _ cols: [String]) -> Plan {
273+
var sort = Sort()
274+
sort.input = child
275+
let expressions: [Spark_Connect_Expression.SortOrder] = cols.map {
276+
var expression = Spark_Connect_Expression.SortOrder()
277+
expression.child.exprType = .unresolvedAttribute($0.toUnresolvedAttribute)
278+
return expression
279+
}
280+
sort.order = expressions
281+
var relation = Relation()
282+
relation.sort = sort
283+
var plan = Plan()
284+
plan.opType = .root(relation)
285+
return plan
286+
}
287+
288+
static func getLimit(_ child: Relation, _ n: Int32) -> Plan {
289+
var limit = Limit()
290+
limit.input = child
291+
limit.limit = n
292+
var relation = Relation()
293+
relation.limit = limit
294+
var plan = Plan()
295+
plan.opType = .root(relation)
296+
return plan
297+
}
255298
}

Sources/SparkConnect/TypeAliases.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
// Short aliases for the SwiftProtobuf-generated Spark Connect types.
// Kept in alphabetical order so new aliases have an obvious insertion point.
typealias ConfigRequest = Spark_Connect_ConfigRequest
typealias DataType = Spark_Connect_DataType
typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
typealias KeyValue = Spark_Connect_KeyValue
typealias Limit = Spark_Connect_Limit
typealias Plan = Spark_Connect_Plan
typealias Project = Spark_Connect_Project
typealias Range = Spark_Connect_Range
typealias Relation = Spark_Connect_Relation
typealias Sort = Spark_Connect_Sort
typealias SparkConnectService = Spark_Connect_SparkConnectService
typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
typealias UserContext = Spark_Connect_UserContext

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,77 @@ struct DataFrameTests {
6868
await spark.stop()
6969
}
7070

71+
@Test
72+
func selectNone() async throws {
73+
let spark = try await SparkSession.builder.getOrCreate()
74+
let emptySchema = try await spark.range(1).select().schema()
75+
#expect(emptySchema == #"{"struct":{}}"#)
76+
await spark.stop()
77+
}
78+
79+
@Test
80+
func select() async throws {
81+
let spark = try await SparkSession.builder.getOrCreate()
82+
let schema = try await spark.range(1).select("id").schema()
83+
#expect(
84+
schema
85+
== #"{"struct":{"fields":[{"name":"id","dataType":{"long":{}}}]}}"#
86+
)
87+
await spark.stop()
88+
}
89+
90+
@Test
91+
func selectMultipleColumns() async throws {
92+
let spark = try await SparkSession.builder.getOrCreate()
93+
let schema = try await spark.sql("SELECT * FROM VALUES (1, 2)").select("col2", "col1").schema()
94+
#expect(
95+
schema
96+
== #"{"struct":{"fields":[{"name":"col2","dataType":{"integer":{}}},{"name":"col1","dataType":{"integer":{}}}]}}"#
97+
)
98+
await spark.stop()
99+
}
100+
101+
@Test
102+
func selectInvalidColumn() async throws {
103+
let spark = try await SparkSession.builder.getOrCreate()
104+
try await #require(throws: Error.self) {
105+
let _ = try await spark.range(1).select("invalid").schema()
106+
}
107+
await spark.stop()
108+
}
109+
110+
@Test
111+
func limit() async throws {
112+
let spark = try await SparkSession.builder.getOrCreate()
113+
#expect(try await spark.range(10).limit(0).count() == 0)
114+
#expect(try await spark.range(10).limit(1).count() == 1)
115+
#expect(try await spark.range(10).limit(2).count() == 2)
116+
#expect(try await spark.range(10).limit(15).count() == 10)
117+
await spark.stop()
118+
}
119+
120+
@Test
121+
func isEmpty() async throws {
122+
let spark = try await SparkSession.builder.getOrCreate()
123+
#expect(try await spark.range(0).isEmpty())
124+
#expect(!(try await spark.range(1).isEmpty()))
125+
await spark.stop()
126+
}
127+
128+
@Test
129+
func sort() async throws {
130+
let spark = try await SparkSession.builder.getOrCreate()
131+
#expect(try await spark.range(10).sort("id").count() == 10)
132+
await spark.stop()
133+
}
134+
135+
@Test
136+
func orderBy() async throws {
137+
let spark = try await SparkSession.builder.getOrCreate()
138+
#expect(try await spark.range(10).orderBy("id").count() == 10)
139+
await spark.stop()
140+
}
141+
71142
@Test
72143
func table() async throws {
73144
let spark = try await SparkSession.builder.getOrCreate()

0 commit comments

Comments
 (0)