Skip to content

Commit 7597ad5

Browse files
committed
[SPARK-51508] Support collect(): [[String]] for DataFrame
1 parent 68067bf commit 7597ad5

File tree

5 files changed: +51 −11 lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public actor DataFrame: Sendable {

   /// Add `Apache Arrow`'s `RecordBatch`s to the internal array.
   /// - Parameter batches: An array of ``RecordBatch``.
-  private func addBathes(_ batches: [RecordBatch]) {
+  private func addBatches(_ batches: [RecordBatch]) {
     self.batches.append(contentsOf: batches)
   }

@@ -153,16 +153,35 @@ public actor DataFrame: Sendable {
           let arrowResult = ArrowReader.makeArrowReaderResult()
           _ = reader.fromMessage(schema, dataBody: Data(), result: arrowResult)
           _ = reader.fromMessage(dataHeader, dataBody: dataBody, result: arrowResult)
-          await self.addBathes(arrowResult.batches)
+          await self.addBatches(arrowResult.batches)
         }
       }
     }
   }

-  /// This is designed not to support this feature in order to simplify the Swift client.
-  public func collect() async throws {
-    throw SparkConnectError.UnsupportedOperationException
+  /// Execute the plan and return the result as ``[[String]]``.
+  /// - Returns: ``[[String]]``
+  public func collect() async throws -> [[String]] {
+    try await execute()
+
+    var result: [[String]] = []
+    for batch in self.batches {
+      for i in 0..<batch.length {
+        var values: [String] = []
+        for column in batch.columns {
+          let str = column.array as! AsString
+          if column.data.isNull(i) {
+            values.append("NULL")
+          } else {
+            values.append(str.asString(i))
+          }
+        }
+        result.append(values)
+      }
+    }
+
+    return result
   }

   /// Execute the plan and show the result.

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,11 @@ public actor SparkConnectClient {
     let expressions: [Spark_Connect_Expression.SortOrder] = cols.map {
       var expression = Spark_Connect_Expression.SortOrder()
       expression.child.exprType = .unresolvedAttribute($0.toUnresolvedAttribute)
+      expression.direction = .ascending
       return expression
     }
     sort.order = expressions
+    sort.isGlobal = true
     var relation = Relation()
     relation.sort = sort
     var plan = Plan()

Sources/SparkConnect/SparkSession.swift

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,10 @@ public actor SparkSession {
   /// - userID: an optional user ID. If absent, `SPARK_USER` environment or ``ProcessInfo.processInfo.userName`` is used.
   init(_ connection: String, _ userID: String? = nil) {
     let processInfo = ProcessInfo.processInfo
-    #if os(iOS) || os(watchOS) || os(tvOS)
-      let userName = processInfo.environment["SPARK_USER"] ?? ""
-    #elseif os(macOS) || os(Linux)
+    #if os(macOS) || os(Linux)
       let userName = processInfo.environment["SPARK_USER"] ?? processInfo.userName
     #else
-      assert(false, "Unsupported platform")
+      let userName = processInfo.environment["SPARK_USER"] ?? ""
     #endif
     self.client = SparkConnectClient(remote: connection, user: userID ?? userName)
     self.conf = RuntimeConf(self.client)

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,19 +125,25 @@ struct DataFrameTests {
     await spark.stop()
   }

+  #if !os(Linux)
   @Test
   func sort() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    #expect(try await spark.range(10).sort("id").count() == 10)
+    let expected = (1...10).map{ [String($0)] }
+    #expect(try await spark.range(10, 0, -1).sort("id").collect() == expected)
     await spark.stop()
   }

   @Test
   func orderBy() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    #expect(try await spark.range(10).orderBy("id").count() == 10)
+    let expected = (1...10).map{ [String($0)] }
+    print(expected)
+    print(try await spark.range(10, 0, -1).orderBy("id").collect())
+    #expect(try await spark.range(10, 0, -1).orderBy("id").collect() == expected)
     await spark.stop()
   }
+  #endif

   @Test
   func table() async throws {
@@ -153,6 +159,17 @@ struct DataFrameTests {
   }

   #if !os(Linux)
+  @Test
+  func collect() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.range(0).collect().isEmpty)
+    #expect(
+      try await spark.sql(
+        "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false, 'def')"
+      ).collect() == [["1", "true", "abc"], ["NULL", "NULL", "NULL"], ["3", "false", "def"]])
+    await spark.stop()
+  }
+
   @Test
   func show() async throws {
     let spark = try await SparkSession.builder.getOrCreate()

Tests/SparkConnectTests/SparkSessionTests.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@ struct SparkSessionTests {

   @Test func userContext() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
+    #if os(macOS) || os(Linux)
     let defaultUserContext = ProcessInfo.processInfo.userName.toUserContext
+    #else
+    let defaultUserContext = "".toUserContext
+    #endif
     #expect(await spark.client.userContext == defaultUserContext)
     await spark.stop()
   }

0 commit comments

Comments (0)