Commit 98e9554
[SPARK-51508] Support collect(): [[String?]] for DataFrame
### What changes were proposed in this pull request?

This PR aims to support `DataFrame.collect()` with the return type `[[String?]]`, an array of optional-`String` arrays.

### Why are the changes needed?

There are two main goals.

1. Provide one of the simplest possible implementations of the `collect()` API.
2. Use it as interim test coverage until we implement a more generic return type, including **complex types**.

### Does this PR introduce _any_ user-facing change?

No, this is not released yet.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #17 from dongjoon-hyun/SPARK-51508.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent: 68067bf
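For readers skimming the diff below, here is a minimal usage sketch of the new API. It assumes a session obtained via `SparkSession.builder.getOrCreate()`, exactly as in this commit's tests; the SQL literal and the resulting values mirror the new `collect()` test.

```swift
import SparkConnect

let spark = try await SparkSession.builder.getOrCreate()

// Each row is a [String?]; SQL NULL comes back as nil.
let rows: [[String?]] = try await spark.sql(
  "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null)"
).collect()
// rows == [["1", "true", "abc"], [nil, nil, nil]]

await spark.stop()
```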

5 files changed (+49, −11 lines)

Sources/SparkConnect/DataFrame.swift

Lines changed: 24 additions & 5 deletions
```diff
@@ -58,7 +58,7 @@ public actor DataFrame: Sendable {
 
   /// Add `Apache Arrow`'s `RecordBatch`s to the internal array.
   /// - Parameter batches: An array of ``RecordBatch``.
-  private func addBathes(_ batches: [RecordBatch]) {
+  private func addBatches(_ batches: [RecordBatch]) {
     self.batches.append(contentsOf: batches)
   }
 
@@ -153,16 +153,35 @@ public actor DataFrame: Sendable {
             let arrowResult = ArrowReader.makeArrowReaderResult()
             _ = reader.fromMessage(schema, dataBody: Data(), result: arrowResult)
             _ = reader.fromMessage(dataHeader, dataBody: dataBody, result: arrowResult)
-            await self.addBathes(arrowResult.batches)
+            await self.addBatches(arrowResult.batches)
           }
         }
       }
     }
   }
 
-  /// This is designed not to support this feature in order to simplify the Swift client.
-  public func collect() async throws {
-    throw SparkConnectError.UnsupportedOperationException
+  /// Execute the plan and return the result as ``[[String?]]``.
+  /// - Returns: ``[[String?]]``
+  public func collect() async throws -> [[String?]] {
+    try await execute()
+
+    var result: [[String?]] = []
+    for batch in self.batches {
+      for i in 0..<batch.length {
+        var values: [String?] = []
+        for column in batch.columns {
+          let str = column.array as! AsString
+          if column.data.isNull(i) {
+            values.append(nil)
+          } else {
+            values.append(str.asString(i))
+          }
+        }
+        result.append(values)
+      }
+    }
+
+    return result
   }
 
   /// Execute the plan and show the result.
```
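A note on the implementation above: every cell is rendered through arrow-swift's `AsString` conformance (`column.array as! AsString`), so this interim API returns strings for all column types. For instance, collecting a one-column `BIGINT` DataFrame still yields `String?` values; the result shape below is inferred from Spark's `range` semantics and the tests in this commit, and is illustrative only.

```swift
// Illustrative: a BIGINT column still collects as String? cells.
let rows = try await spark.range(3).collect()
// rows == [["0"], ["1"], ["2"]]
```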

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 2 additions & 0 deletions
```diff
@@ -275,9 +275,11 @@ public actor SparkConnectClient {
     let expressions: [Spark_Connect_Expression.SortOrder] = cols.map {
       var expression = Spark_Connect_Expression.SortOrder()
       expression.child.exprType = .unresolvedAttribute($0.toUnresolvedAttribute)
+      expression.direction = .ascending
       return expression
     }
     sort.order = expressions
+    sort.isGlobal = true
     var relation = Relation()
     relation.sort = sort
     var plan = Plan()
```
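These two defaults are what make the rewritten sort tests deterministic: `direction = .ascending` gives each `SortOrder` an explicit direction, and `isGlobal = true` requests a total ordering rather than a per-partition one. The observable effect, mirroring the updated tests below:

```swift
// A descending range comes back globally ascending after sort("id").
let rows = try await spark.range(10, 0, -1).sort("id").collect()
// rows == (1...10).map { [String($0)] }, i.e. [["1"], ["2"], ..., ["10"]]
```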

Sources/SparkConnect/SparkSession.swift

Lines changed: 2 additions & 4 deletions
```diff
@@ -45,12 +45,10 @@ public actor SparkSession {
   /// - userID: an optional user ID. If absent, `SPARK_USER` environment or ``ProcessInfo.processInfo.userName`` is used.
   init(_ connection: String, _ userID: String? = nil) {
     let processInfo = ProcessInfo.processInfo
-    #if os(iOS) || os(watchOS) || os(tvOS)
-      let userName = processInfo.environment["SPARK_USER"] ?? ""
-    #elseif os(macOS) || os(Linux)
+    #if os(macOS) || os(Linux)
       let userName = processInfo.environment["SPARK_USER"] ?? processInfo.userName
     #else
-      assert(false, "Unsupported platform")
+      let userName = processInfo.environment["SPARK_USER"] ?? ""
     #endif
     self.client = SparkConnectClient(remote: connection, user: userID ?? userName)
     self.conf = RuntimeConf(self.client)
```
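Net effect: `SPARK_USER` takes precedence on every platform, and previously unsupported platforms now fall back to an empty user name instead of tripping an assertion. Paraphrased as a single resolution chain (`platformFallback` is an illustrative stand-in, not an identifier in this commit):

```swift
// userID argument > SPARK_USER env var > platform fallback
// (processInfo.userName on macOS/Linux, "" elsewhere).
let platformFallback = ""  // illustrative stand-in
let user = userID
  ?? ProcessInfo.processInfo.environment["SPARK_USER"]
  ?? platformFallback
```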

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 17 additions & 2 deletions
```diff
@@ -125,19 +125,23 @@ struct DataFrameTests {
     await spark.stop()
   }
 
+  #if !os(Linux)
   @Test
   func sort() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    #expect(try await spark.range(10).sort("id").count() == 10)
+    let expected = (1...10).map{ [String($0)] }
+    #expect(try await spark.range(10, 0, -1).sort("id").collect() == expected)
     await spark.stop()
   }
 
   @Test
   func orderBy() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    #expect(try await spark.range(10).orderBy("id").count() == 10)
+    let expected = (1...10).map{ [String($0)] }
+    #expect(try await spark.range(10, 0, -1).orderBy("id").collect() == expected)
     await spark.stop()
   }
+  #endif
 
   @Test
   func table() async throws {
@@ -153,6 +157,17 @@ struct DataFrameTests {
   }
 
   #if !os(Linux)
+  @Test
+  func collect() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.range(0).collect().isEmpty)
+    #expect(
+      try await spark.sql(
+        "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false, 'def')"
+      ).collect() == [["1", "true", "abc"], [nil, nil, nil], ["3", "false", "def"]])
+    await spark.stop()
+  }
+
   @Test
   func show() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
```

Tests/SparkConnectTests/SparkSessionTests.swift

Lines changed: 4 additions & 0 deletions
```diff
@@ -41,7 +41,11 @@ struct SparkSessionTests {
 
   @Test func userContext() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
+    #if os(macOS) || os(Linux)
     let defaultUserContext = ProcessInfo.processInfo.userName.toUserContext
+    #else
+    let defaultUserContext = "".toUserContext
+    #endif
     #expect(await spark.client.userContext == defaultUserContext)
     await spark.stop()
   }
```
