Commit 33b4b08
[SPARK-51689] Support DataFrameWriter
### What changes were proposed in this pull request?

This PR aims to support `DataFrameWriter`.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No, because this is a new addition to the unreleased version.

### How was this patch tested?

Pass the CIs with the newly added test suite. Or, manual test.

```
$ swift test --filter DataFrameWriterTests
Building for debugging...
[13/13] Compiling SparkConnectTests DataFrameWriterTests.swift
Build complete! (3.63s)
Test Suite 'Selected tests' started at 2025-04-02 17:01:14.353.
Test Suite 'SparkConnectPackageTests.xctest' started at 2025-04-02 17:01:14.354.
Test Suite 'SparkConnectPackageTests.xctest' passed at 2025-04-02 17:01:14.354.
	 Executed 0 tests, with 0 failures (0 unexpected) in 0.000 (0.000) seconds
Test Suite 'Selected tests' passed at 2025-04-02 17:01:14.354.
	 Executed 0 tests, with 0 failures (0 unexpected) in 0.000 (0.002) seconds
􀟈 Test run started.
􀄵 Testing Library Version: 102 (arm64e-apple-macos13.0)
􀟈 Suite DataFrameWriterTests started.
􀟈 Test orc() started.
􀟈 Test pathAlreadyExist() started.
􀟈 Test csv() started.
􀟈 Test overwrite() started.
􀟈 Test sortByBucketBy() started.
􀟈 Test json() started.
􀟈 Test save() started.
􀟈 Test partitionBy() started.
􀟈 Test parquet() started.
􁁛 Test sortByBucketBy() passed after 0.072 seconds.
􁁛 Test pathAlreadyExist() passed after 0.396 seconds.
􁁛 Test overwrite() passed after 0.504 seconds.
􁁛 Test orc() passed after 0.515 seconds.
􁁛 Test json() passed after 0.524 seconds.
􁁛 Test parquet() passed after 0.539 seconds.
􁁛 Test partitionBy() passed after 0.549 seconds.
􁁛 Test csv() passed after 0.569 seconds.
􁁛 Test save() passed after 1.001 seconds.
􁁛 Suite DataFrameWriterTests passed after 1.002 seconds.
􁁛 Test run with 9 tests passed after 1.002 seconds.
```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #37 from dongjoon-hyun/SPARK-51689.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
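For orientation, the test names in the log above map one-to-one onto the new writer entry points. A hypothetical snippet exercising a few of them, assuming a connected `SparkSession` value `spark` with a `sql` method (neither is part of this patch):

```swift
// Hypothetical: build a small DataFrame and write it out via the new API.
let df = try await spark.sql("SELECT * FROM VALUES (1, 'a'), (2, 'b') T(id, name)")

// Format-specific shorthand: sets the source and saves to the path.
try await df.write.mode("overwrite").json("/tmp/t_json")

// Partitioned output through the same save(path) route.
try await df.write.partitionBy("id").orc("/tmp/t_orc")
```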
1 parent 97cdaa5

File tree: 6 files changed, +371 −0 lines

Sources/SparkConnect/DataFrame.swift

11 additions & 0 deletions

```diff
@@ -50,6 +50,10 @@ public actor DataFrame: Sendable {
     self.plan = sqlText.toSparkConnectPlan
   }
 
+  public func getPlan() -> Sendable {
+    return self.plan
+  }
+
   /// Set the schema. This is used to store the analyzed schema response from `Spark Connect` server.
   /// - Parameter schema: <#schema description#>
   private func setSchema(_ schema: DataType) {
@@ -382,4 +386,11 @@ public actor DataFrame: Sendable {
       print(response.treeString.treeString)
     }
   }
+
+  /// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
+  var write: DataFrameWriter {
+    get {
+      return DataFrameWriter(df: self)
+    }
+  }
 }
```
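Each access to `write` vends a fresh `DataFrameWriter`, so builder state such as format, mode, and options never leaks between writes. A minimal usage sketch, assuming an already-constructed `DataFrame` value `df` (the setup is not part of this diff):

```swift
// DataFrame is an actor, so the property access and the writer calls are
// awaited from outside; a new writer is created on every access to `write`.
let writer = await df.write
try await writer.mode("overwrite").parquet("/tmp/sample")
```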
Sources/SparkConnect/DataFrameWriter.swift (new file)

199 additions & 0 deletions

```swift
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
import Atomics
import Foundation
import GRPCCore
import GRPCNIOTransportHTTP2
import GRPCProtobuf
import NIOCore
import SwiftyTextTable
import Synchronization

/// An interface used to write a `DataFrame` to external storage systems
/// (e.g. file systems, key-value stores, etc). Use `DataFrame.write` to access this.
public actor DataFrameWriter: Sendable {
  var source: String? = nil

  var saveMode: String = "default"

  // TODO: Case-insensitive Map
  var extraOptions: [String: String] = [:]

  var partitioningColumns: [String]? = nil

  var bucketColumnNames: [String]? = nil

  var numBuckets: Int32? = nil

  var sortColumnNames: [String]? = nil

  var clusteringColumns: [String]? = nil

  let df: DataFrame

  init(df: DataFrame) {
    self.df = df
  }

  /// Specifies the output data source format.
  /// - Parameter source: A string.
  /// - Returns: A `DataFrameWriter`.
  public func format(_ source: String) -> DataFrameWriter {
    self.source = source
    return self
  }

  /// Specifies the behavior when data or table already exists. Options include:
  /// `overwrite`, `append`, `ignore`, `error` or `errorifexists` (default).
  ///
  /// - Parameter saveMode: A string for save mode.
  /// - Returns: A `DataFrameWriter`.
  public func mode(_ saveMode: String) -> DataFrameWriter {
    self.saveMode = saveMode
    return self
  }

  /// Adds an output option for the underlying data source.
  /// - Parameters:
  ///   - key: A key string.
  ///   - value: A value string.
  /// - Returns: A `DataFrameWriter`.
  public func option(_ key: String, _ value: String) -> DataFrameWriter {
    self.extraOptions[key] = value
    return self
  }

  /// Partitions the output by the given columns on the file system.
  /// - Parameter columns: Column names to partition by.
  /// - Returns: A `DataFrameWriter`.
  public func partitionBy(_ columns: String...) -> DataFrameWriter {
    self.partitioningColumns = columns
    return self
  }

  /// Buckets the output by the given columns.
  /// - Parameters:
  ///   - numBuckets: The number of buckets.
  ///   - columnNames: Column names to bucket by.
  /// - Returns: A `DataFrameWriter`.
  public func bucketBy(numBuckets: Int32, _ columnNames: String...) -> DataFrameWriter {
    self.numBuckets = numBuckets
    self.bucketColumnNames = columnNames
    return self
  }

  /// Sorts the output in each bucket by the given columns.
  /// - Parameter columnNames: Column names to sort by.
  /// - Returns: A `DataFrameWriter`.
  public func sortBy(_ columnNames: String...) -> DataFrameWriter {
    self.sortColumnNames = columnNames
    return self
  }

  /// Clusters the output by the given columns.
  /// - Parameter columnNames: Column names to cluster by.
  /// - Returns: A `DataFrameWriter`.
  public func clusterBy(_ columnNames: String...) -> DataFrameWriter {
    self.clusteringColumns = columnNames
    return self
  }

  /// Saves the content of the `DataFrame`, for data sources that don't require a path
  /// (e.g. external key-value stores).
  public func save() async throws {
    return try await saveInternal(nil)
  }

  /// Saves the content of the `DataFrame` at the specified path, for data sources backed
  /// by a local or distributed file system.
  /// - Parameter path: A path string.
  public func save(_ path: String) async throws {
    try await saveInternal(path)
  }

  private func saveInternal(_ path: String?) async throws {
    var write = WriteOperation()
    let plan = await self.df.getPlan() as! Plan
    write.input = plan.root
    write.mode = self.saveMode.toSaveMode
    if let path = path {
      write.path = path
    }

    // Cannot both be set
    // require(!(builder.hasPath && builder.hasTable))

    if let source = self.source {
      write.source = source
    }
    if let sortColumnNames = self.sortColumnNames {
      write.sortColumnNames = sortColumnNames
    }
    if let partitioningColumns = self.partitioningColumns {
      write.partitioningColumns = partitioningColumns
    }
    if let clusteringColumns = self.clusteringColumns {
      write.clusteringColumns = clusteringColumns
    }
    if let numBuckets = self.numBuckets {
      var bucketBy = WriteOperation.BucketBy()
      bucketBy.numBuckets = numBuckets
      if let bucketColumnNames = self.bucketColumnNames {
        bucketBy.bucketColumnNames = bucketColumnNames
      }
      write.bucketBy = bucketBy
    }

    for option in self.extraOptions {
      write.options[option.key] = option.value
    }

    var command = Spark_Connect_Command()
    command.writeOperation = write

    _ = try await df.spark.client.execute(df.spark.sessionID, command)
  }

  /// Saves the content of the `DataFrame` in CSV format at the specified path.
  /// - Parameter path: A path string.
  public func csv(_ path: String) async throws {
    self.source = "csv"
    return try await save(path)
  }

  /// Saves the content of the `DataFrame` in JSON format at the specified path.
  /// - Parameter path: A path string.
  public func json(_ path: String) async throws {
    self.source = "json"
    return try await save(path)
  }

  /// Saves the content of the `DataFrame` in ORC format at the specified path.
  /// - Parameter path: A path string.
  public func orc(_ path: String) async throws {
    self.source = "orc"
    return try await save(path)
  }

  /// Saves the content of the `DataFrame` in Parquet format at the specified path.
  /// - Parameter path: A path string.
  public func parquet(_ path: String) async throws {
    self.source = "parquet"
    return try await save(path)
  }

  /// Saves the content of the `DataFrame` in a text file at the specified path.
  /// The DataFrame must have only one column that is of string type.
  /// Each row becomes a new line in the output file.
  ///
  /// - Parameter path: A path string.
  public func text(_ path: String) async throws {
    self.source = "text"
    return try await save(path)
  }
}
```
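Because each builder method mutates the writer and returns `self`, calls chain naturally, and `save(path)` assembles a single `WriteOperation` command from the accumulated state. A hypothetical sketch of the chained style, assuming a `DataFrame` value `df` and a server that honors the standard `compression` data source option (neither is part of this file):

```swift
// Builder chaining: format/mode/option each return the writer itself.
// DataFrameWriter is an actor, so one leading `try await` covers the
// whole chain of cross-actor calls.
try await df.write
  .format("orc")
  .mode("overwrite")
  .option("compression", "zstd")
  .save("/tmp/spark-51689")
```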

Sources/SparkConnect/Extension.swift

11 additions & 0 deletions

```diff
@@ -69,6 +69,17 @@ extension String {
     }
     return mode
   }
+
+  var toSaveMode: SaveMode {
+    return switch self.lowercased() {
+    case "append": SaveMode.append
+    case "overwrite": SaveMode.overwrite
+    case "error": SaveMode.errorIfExists
+    case "errorifexists": SaveMode.errorIfExists
+    case "ignore": SaveMode.ignore
+    default: SaveMode.errorIfExists
+    }
+  }
 }
 
 extension [String: String] {
```
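The conversion lowercases its input before matching, so mode strings are case-insensitive, and anything unrecognized (including the writer's initial `"default"`) falls back to `errorIfExists`. A few illustrative assertions, written for this note rather than taken from the patch:

```swift
// Case-insensitive: the string is lowercased before the switch.
assert("APPEND".toSaveMode == .append)
assert("Overwrite".toSaveMode == .overwrite)
// "error" and "errorifexists" are synonyms for the same mode.
assert("error".toSaveMode == .errorIfExists)
// Unrecognized strings, e.g. the writer's initial "default", fall back.
assert("default".toSaveMode == .errorIfExists)
```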

Sources/SparkConnect/SparkConnectClient.swift

26 additions & 0 deletions

```diff
@@ -361,4 +361,30 @@ public actor SparkConnectClient {
     plan.opType = .root(relation)
     return plan
   }
+
+  var result: [ExecutePlanResponse] = []
+  private func addResponse(_ response: ExecutePlanResponse) {
+    self.result.append(response)
+  }
+
+  func execute(_ sessionID: String, _ command: Command) async throws -> [ExecutePlanResponse] {
+    self.result.removeAll()
+    try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: self.host, port: self.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      let service = SparkConnectService.Client(wrapping: client)
+      var plan = Plan()
+      plan.opType = .command(command)
+      try await service.executePlan(getExecutePlanRequest(sessionID, plan)) { response in
+        for try await m in response.messages {
+          await self.addResponse(m)
+        }
+      }
+    }
+    return result
+  }
 }
```
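The new `execute` buffers every streamed `ExecutePlanResponse` into an actor-isolated array before returning, which keeps the gRPC stream handling inside the client. A sketch of how the writer path drives it, mirroring `saveInternal` above (the populated `write` operation and the `client`/`sessionID` values are assumed to exist):

```swift
// Wrap the WriteOperation in a Command and execute it as a command plan.
var command = Command()
command.writeOperation = write
// A write produces no result rows to consume; success is signaled by the
// response stream completing without throwing, so the responses are dropped.
_ = try await client.execute(sessionID, command)
```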

Sources/SparkConnect/TypeAliases.swift

4 additions & 0 deletions

```diff
@@ -18,11 +18,13 @@
 
 typealias AnalyzePlanRequest = Spark_Connect_AnalyzePlanRequest
 typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
+typealias Command = Spark_Connect_Command
 typealias ConfigRequest = Spark_Connect_ConfigRequest
 typealias DataSource = Spark_Connect_Read.DataSource
 typealias DataType = Spark_Connect_DataType
 typealias DayTimeInterval = Spark_Connect_DataType.DayTimeInterval
 typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
+typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
 typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
 typealias ExpressionString = Spark_Connect_Expression.ExpressionString
 typealias Filter = Spark_Connect_Filter
@@ -35,9 +37,11 @@ typealias Project = Spark_Connect_Project
 typealias Range = Spark_Connect_Range
 typealias Read = Spark_Connect_Read
 typealias Relation = Spark_Connect_Relation
+typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
 typealias SparkConnectService = Spark_Connect_SparkConnectService
 typealias Sort = Spark_Connect_Sort
 typealias StructType = Spark_Connect_DataType.Struct
 typealias UserContext = Spark_Connect_UserContext
 typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
+typealias WriteOperation = Spark_Connect_WriteOperation
 typealias YearMonthInterval = Spark_Connect_DataType.YearMonthInterval
```
