From 76b00def1788409fd247057ff3f4b2eec8bb772b Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Thu, 27 Mar 2025 12:44:25 -0700
Subject: [PATCH 1/2] [SPARK-51689] Support `DataFrameWriter`

---
 Sources/SparkConnect/DataFrame.swift          |  11 +
 Sources/SparkConnect/DataFrameWriter.swift    | 199 ++++++++++++++++++
 Sources/SparkConnect/Extension.swift          |  11 +
 Sources/SparkConnect/SparkConnectClient.swift |  26 +++
 Sources/SparkConnect/TypeAliases.swift        |   4 +
 .../DataFrameWriterTests.swift                | 117 ++++++++++
 6 files changed, 368 insertions(+)
 create mode 100644 Sources/SparkConnect/DataFrameWriter.swift
 create mode 100644 Tests/SparkConnectTests/DataFrameWriterTests.swift

diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 2588ef7..b96e02d 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -50,6 +50,10 @@ public actor DataFrame: Sendable {
     self.plan = sqlText.toSparkConnectPlan
   }
 
+  public func getPlan() -> Sendable {
+    return self.plan
+  }
+
   /// Set the schema. This is used to store the analyzed schema response from the `Spark Connect` server.
   /// - Parameter schema: The analyzed schema as a ``DataType``.
   private func setSchema(_ schema: DataType) {
@@ -382,4 +386,11 @@ public actor DataFrame: Sendable {
       print(response.treeString.treeString)
     }
   }
+
+  /// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
+  var write: DataFrameWriter {
+    get {
+      return DataFrameWriter(df: self)
+    }
+  }
 }
diff --git a/Sources/SparkConnect/DataFrameWriter.swift b/Sources/SparkConnect/DataFrameWriter.swift
new file mode 100644
index 0000000..6ebc514
--- /dev/null
+++ b/Sources/SparkConnect/DataFrameWriter.swift
@@ -0,0 +1,199 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+import Atomics
+import Foundation
+import GRPCCore
+import GRPCNIOTransportHTTP2
+import GRPCProtobuf
+import NIOCore
+import SwiftyTextTable
+import Synchronization
+
+/// An interface used to write a `DataFrame` to external storage systems
+/// (e.g. file systems, key-value stores, etc). Use `DataFrame.write` to access this.
+public actor DataFrameWriter: Sendable {
+  var source: String? = nil
+
+  var saveMode: String = "default"
+
+  // TODO: Case-insensitive Map
+  var extraOptions: [String: String] = [:]
+
+  var partitioningColumns: [String]? = nil
+
+  var bucketColumnNames: [String]? = nil
+
+  var numBuckets: Int32? = nil
+
+  var sortColumnNames: [String]? = nil
+
+  var clusteringColumns: [String]? = nil
+
+  let df: DataFrame
+
+  init(df: DataFrame) {
+    self.df = df
+  }
+
+  /// Specifies the output data source format.
+  /// - Parameter source: A string.
+  /// - Returns: A `DataFrameWriter`.
+  public func format(_ source: String) -> DataFrameWriter {
+    self.source = source
+    return self
+  }
+
+  /// Specifies the behavior when the data or table already exists. Options include:
+  /// `overwrite`, `append`, `ignore`, `error` or `errorifexists` (default).
+  ///
+  /// - Parameter saveMode: A string for the save mode.
+  /// - Returns: A `DataFrameWriter`.
+  public func mode(_ saveMode: String) -> DataFrameWriter {
+    self.saveMode = saveMode
+    return self
+  }
+
+  /// Adds an output option for the underlying data source.
+  /// - Parameters:
+  ///   - key: A key string.
+  ///   - value: A value string.
+  /// - Returns: A `DataFrameWriter`.
+  public func option(_ key: String, _ value: String) -> DataFrameWriter {
+    self.extraOptions[key] = value
+    return self
+  }
+
+  /// Partitions the output by the given columns on the file system.
+  public func partitionBy(_ columns: String...) -> DataFrameWriter {
+    self.partitioningColumns = columns
+    return self
+  }
+
+  /// Buckets the output by the given columns.
+  public func bucketBy(numBuckets: Int32, _ columnNames: String...) -> DataFrameWriter {
+    self.numBuckets = numBuckets
+    self.bucketColumnNames = columnNames
+    return self
+  }
+
+  /// Sorts the output in each bucket by the given columns.
+  public func sortBy(_ columnNames: String...) -> DataFrameWriter {
+    self.sortColumnNames = columnNames
+    return self
+  }
+
+  /// Clusters the output by the given columns.
+  public func clusterBy(_ columnNames: String...) -> DataFrameWriter {
+    self.clusteringColumns = columnNames
+    return self
+  }
+
+  /// Saves the content of the `DataFrame` to a data source that doesn't require a path
+  /// (e.g. an external key-value store).
+  public func save() async throws {
+    return try await saveInternal(nil)
+  }
+
+  /// Saves the content of the `DataFrame` to a data source that requires a path
+  /// (e.g. data backed by a local or distributed file system).
+  /// - Parameter path: A path string.
+  public func save(_ path: String) async throws {
+    try await saveInternal(path)
+  }
+
+  private func saveInternal(_ path: String?) async throws {
+    var write = WriteOperation()
+    let plan = await self.df.getPlan() as! Plan
+    write.input = plan.root
+    write.mode = self.saveMode.toSaveMode
+    if let path = path {
+      write.path = path
+    }
+
+    // Cannot both be set
+    // require(!(builder.hasPath && builder.hasTable))
+
+    if let source = self.source {
+      write.source = source
+    }
+    if let sortColumnNames = self.sortColumnNames {
+      write.sortColumnNames = sortColumnNames
+    }
+    if let partitioningColumns = self.partitioningColumns {
+      write.partitioningColumns = partitioningColumns
+    }
+    if let clusteringColumns = self.clusteringColumns {
+      write.clusteringColumns = clusteringColumns
+    }
+    if let numBuckets = self.numBuckets {
+      var bucketBy = WriteOperation.BucketBy()
+      bucketBy.numBuckets = numBuckets
+      if let bucketColumnNames = self.bucketColumnNames {
+        bucketBy.bucketColumnNames = bucketColumnNames
+      }
+      write.bucketBy = bucketBy
+    }
+
+    for option in self.extraOptions {
+      write.options[option.key] = option.value
+    }
+
+    var command = Spark_Connect_Command()
+    command.writeOperation = write
+
+    _ = try await df.spark.client.execute(df.spark.sessionID, command)
+  }
+
+  /// Saves the content of the `DataFrame` in CSV format at the specified path.
+  /// - Parameter path: A path string.
+  public func csv(_ path: String) async throws {
+    self.source = "csv"
+    return try await save(path)
+  }
+
+  /// Saves the content of the `DataFrame` in JSON format at the specified path.
+  /// - Parameter path: A path string.
+  public func json(_ path: String) async throws {
+    self.source = "json"
+    return try await save(path)
+  }
+
+  /// Saves the content of the `DataFrame` in ORC format at the specified path.
+  /// - Parameter path: A path string.
+  public func orc(_ path: String) async throws {
+    self.source = "orc"
+    return try await save(path)
+  }
+
+  /// Saves the content of the `DataFrame` in Parquet format at the specified path.
+  /// - Parameter path: A path string.
+  public func parquet(_ path: String) async throws {
+    self.source = "parquet"
+    return try await save(path)
+  }
+
+  /// Saves the content of the `DataFrame` in a text file at the specified path.
+  /// The DataFrame must have only one column that is of string type.
+  /// Each row becomes a new line in the output file.
+  ///
+  /// - Parameter path: A path string.
+  public func text(_ path: String) async throws {
+    self.source = "text"
+    return try await save(path)
+  }
+}
diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift
index 7fdbaee..0594ea2 100644
--- a/Sources/SparkConnect/Extension.swift
+++ b/Sources/SparkConnect/Extension.swift
@@ -69,6 +69,17 @@ extension String {
     }
     return mode
   }
+
+  var toSaveMode: SaveMode {
+    return switch self.lowercased() {
+    case "append": SaveMode.append
+    case "overwrite": SaveMode.overwrite
+    case "error": SaveMode.errorIfExists
+    case "errorifexists": SaveMode.errorIfExists
+    case "ignore": SaveMode.ignore
+    default: SaveMode.errorIfExists
+    }
+  }
 }
 
 extension [String: String] {
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index 99e0c11..9dfde0d 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -361,4 +361,30 @@ public actor SparkConnectClient {
     plan.opType = .root(relation)
     return plan
   }
+
+  var result: [ExecutePlanResponse] = []
+  private func addResponse(_ response: ExecutePlanResponse) {
+    self.result.append(response)
+  }
+
+  func execute(_ sessionID: String, _ command: Command) async throws -> [ExecutePlanResponse] {
+    self.result.removeAll()
+    try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: self.host, port: self.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      let service = SparkConnectService.Client(wrapping: client)
+      var plan = Plan()
+      plan.opType = .command(command)
+      try await service.executePlan(getExecutePlanRequest(sessionID, plan)) {
+        response in
+        for try await m in response.messages {
+          await self.addResponse(m)
+        }
+      }
+    }
+    return result
+  }
 }
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index ff8fd11..f82c4f5 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -18,11 +18,13 @@
 typealias AnalyzePlanRequest = Spark_Connect_AnalyzePlanRequest
 typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
+typealias Command = Spark_Connect_Command
 typealias ConfigRequest = Spark_Connect_ConfigRequest
 typealias DataSource = Spark_Connect_Read.DataSource
 typealias DataType = Spark_Connect_DataType
 typealias DayTimeInterval = Spark_Connect_DataType.DayTimeInterval
 typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
+typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
 typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
 typealias ExpressionString = Spark_Connect_Expression.ExpressionString
 typealias Filter = Spark_Connect_Filter
@@ -35,9 +37,11 @@ typealias Project = Spark_Connect_Project
 typealias Range = Spark_Connect_Range
 typealias Read = Spark_Connect_Read
 typealias Relation = Spark_Connect_Relation
+typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
 typealias SparkConnectService = Spark_Connect_SparkConnectService
 typealias Sort = Spark_Connect_Sort
 typealias StructType = Spark_Connect_DataType.Struct
 typealias UserContext = Spark_Connect_UserContext
 typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
+typealias WriteOperation = Spark_Connect_WriteOperation
 typealias YearMonthInterval = Spark_Connect_DataType.YearMonthInterval
diff --git a/Tests/SparkConnectTests/DataFrameWriterTests.swift b/Tests/SparkConnectTests/DataFrameWriterTests.swift
new file mode 100644
index 0000000..26d9038
--- /dev/null
+++ b/Tests/SparkConnectTests/DataFrameWriterTests.swift
@@ -0,0 +1,117 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+
+import Foundation
+import Testing
+
+@testable import SparkConnect
+
+/// A test suite for `DataFrameWriter`
+struct DataFrameWriterTests {
+
+  @Test
+  func csv() async throws {
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let spark = try await SparkSession.builder.getOrCreate()
+    try await spark.range(2025).write.csv(tmpDir)
+    #expect(try await spark.read.csv(tmpDir).count() == 2025)
+    await spark.stop()
+  }
+
+  @Test
+  func json() async throws {
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let spark = try await SparkSession.builder.getOrCreate()
+    try await spark.range(2025).write.json(tmpDir)
+    #expect(try await spark.read.json(tmpDir).count() == 2025)
+    await spark.stop()
+  }
+
+  @Test
+  func orc() async throws {
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let spark = try await SparkSession.builder.getOrCreate()
+    try await spark.range(2025).write.orc(tmpDir)
+    #expect(try await spark.read.orc(tmpDir).count() == 2025)
+    await spark.stop()
+  }
+
+  @Test
+  func parquet() async throws {
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let spark = try await SparkSession.builder.getOrCreate()
+    try await spark.range(2025).write.parquet(tmpDir)
+    #expect(try await spark.read.parquet(tmpDir).count() == 2025)
+    await spark.stop()
+  }
+
+  @Test
+  func pathAlreadyExist() async throws {
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let spark = try await SparkSession.builder.getOrCreate()
+    try await spark.range(2025).write.csv(tmpDir)
+    try await #require(throws: Error.self) {
+      try await spark.range(2025).write.csv(tmpDir)
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func overwrite() async throws {
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let spark = try await SparkSession.builder.getOrCreate()
+    try await spark.range(2025).write.csv(tmpDir)
+    try await spark.range(2025).write.mode("overwrite").csv(tmpDir)
+    await spark.stop()
+  }
+
+  @Test
+  func save() async throws {
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let spark = try await SparkSession.builder.getOrCreate()
+    for format in ["csv", "json", "orc", "parquet"] {
+      try await spark.range(2025).write.format(format).mode("overwrite").save(tmpDir)
+      #expect(try await spark.read.format(format).load(tmpDir).count() == 2025)
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func partitionBy() async throws {
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let spark = try await SparkSession.builder.getOrCreate()
+    try await spark.sql("SELECT 1 col1, 2 col2").write.partitionBy("col2").csv(tmpDir)
+    #expect(try await spark.read.csv("\(tmpDir)/col2=2").count() == 1)
+    await spark.stop()
+  }
+
+  @Test
+  func sortByBucketBy() async throws {
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.sql("SELECT 1 col1, 2 col2")
+    try await #require(throws: Error.self) {
+      try await df.write.sortBy("col2").csv(tmpDir)
+    }
+    try await #require(throws: Error.self) {
+      try await df.write.sortBy("col2").bucketBy(numBuckets: 3, "col2").csv(tmpDir)
+    }
+    await spark.stop()
+  }
+}

From 2829f3f2c17ee5af574f9205a64028ae0a2d6b8c Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Wed, 2 Apr 2025 22:45:54 +0900
Subject: [PATCH 2/2] Address comment

---
 Tests/SparkConnectTests/DataFrameWriterTests.swift | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/Tests/SparkConnectTests/DataFrameWriterTests.swift b/Tests/SparkConnectTests/DataFrameWriterTests.swift
index 26d9038..3bb1142 100644
--- a/Tests/SparkConnectTests/DataFrameWriterTests.swift
+++ b/Tests/SparkConnectTests/DataFrameWriterTests.swift
@@ -112,6 +112,9 @@ struct DataFrameWriterTests {
     try await #require(throws: Error.self) {
       try await df.write.sortBy("col2").bucketBy(numBuckets: 3, "col2").csv(tmpDir)
     }
+    try await #require(throws: Error.self) {
+      try await df.write.bucketBy(numBuckets: 3, "col2").csv(tmpDir)
+    }
     await spark.stop()
   }
 }
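
Below is a minimal usage sketch of the `DataFrameWriter` API these patches add, written in the style of the test suite (including `@testable import`, since `write` is not declared `public` in this change). It is not part of the patches: the test name, the `/tmp` output path, and the `compression` option value are illustrative assumptions, and it assumes a Spark Connect server is reachable via `SparkSession.builder.getOrCreate()`.

```swift
import Foundation
import Testing

@testable import SparkConnect

// Hypothetical end-to-end exercise of the writer: the generic
// format/mode/option/save chain, then a format-specific shortcut.
@Test
func writerUsageSketch() async throws {
  let path = "/tmp/" + UUID().uuidString  // illustrative output location
  let spark = try await SparkSession.builder.getOrCreate()

  // Generic entry point: chain format, mode, and options, then save(path).
  try await spark.range(2025).write
    .format("orc")
    .mode("overwrite")
    .option("compression", "zstd")  // assumed pass-through data source option
    .save(path)
  #expect(try await spark.read.format("orc").load(path).count() == 2025)

  // Shortcuts like parquet(path) set the source and delegate to save(path).
  try await spark.range(2025).write.mode("overwrite").parquet(path)
  #expect(try await spark.read.parquet(path).count() == 2025)

  await spark.stop()
}
```

Because `DataFrameWriter` is an actor, the builder-style calls are suspension points; a single `try await` at the head of each chained expression covers them, which is the same pattern the tests above rely on.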