11 changes: 11 additions & 0 deletions Sources/SparkConnect/DataFrame.swift
@@ -50,6 +50,10 @@ public actor DataFrame: Sendable {
self.plan = sqlText.toSparkConnectPlan
}

/// Returns the current plan as an opaque `Sendable` value so that other actors
/// (e.g. ``DataFrameWriter``) can read it across isolation boundaries.
public func getPlan() -> Sendable {
return self.plan
}

/// Set the schema. This is used to store the analyzed schema response from the `Spark Connect` server.
/// - Parameter schema: The analyzed schema ``DataType`` returned by the server.
private func setSchema(_ schema: DataType) {
@@ -382,4 +386,11 @@ public actor DataFrame: Sendable {
print(response.treeString.treeString)
}
}

/// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
public var write: DataFrameWriter {
DataFrameWriter(df: self)
}
}
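As a quick illustration of the new accessor, a minimal usage sketch (the `SparkSession.builder` setup and the SQL text here are assumptions for illustration, not part of this diff):

// Minimal sketch; session setup is assumed from the existing SparkSession API.
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql("SELECT * FROM range(10)")
let writer = await df.write  // a DataFrameWriter bound to this DataFrame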
199 changes: 199 additions & 0 deletions Sources/SparkConnect/DataFrameWriter.swift
@@ -0,0 +1,199 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
import Atomics
import Foundation
import GRPCCore
import GRPCNIOTransportHTTP2
import GRPCProtobuf
import NIOCore
import SwiftyTextTable
import Synchronization

/// An interface used to write a `DataFrame` to external storage systems
/// (e.g. file systems, key-value stores, etc). Use `DataFrame.write` to access this.
public actor DataFrameWriter: Sendable {
var source: String? = nil

var saveMode: String = "default"

// TODO: Case-insensitive Map
var extraOptions: [String: String] = [:]

var partitioningColumns: [String]? = nil

var bucketColumnNames: [String]? = nil

var numBuckets: Int32? = nil

var sortColumnNames: [String]? = nil

var clusteringColumns: [String]? = nil

let df: DataFrame

init(df: DataFrame) {
self.df = df
}

/// Specifies the output data source format.
/// - Parameter source: A string.
/// - Returns: A `DataFrameWriter`.
public func format(_ source: String) -> DataFrameWriter {
self.source = source
return self
}

/// Specifies the behavior when data or table already exists. Options include:
/// `overwrite`, `append`, `ignore`, `error` or `errorifexists` (default).
///
/// - Parameter saveMode: A string for save mode.
/// - Returns: A `DataFrameWriter`
public func mode(_ saveMode: String) -> DataFrameWriter {
self.saveMode = saveMode
return self
}

/// Adds an output option for the underlying data source.
/// - Parameters:
/// - key: A key string.
/// - value: A value string.
/// - Returns: A `DataFrameWriter`.
public func option(_ key: String, _ value: String) -> DataFrameWriter {
self.extraOptions[key] = value
return self
}

/// Partitions the output by the given columns on the file system.
/// - Parameter columns: Column name strings.
/// - Returns: A `DataFrameWriter`.
public func partitionBy(_ columns: String...) -> DataFrameWriter {
self.partitioningColumns = columns
return self
}

/// Buckets the output by the given columns.
/// - Parameters:
///   - numBuckets: The number of buckets.
///   - columnNames: Column name strings.
/// - Returns: A `DataFrameWriter`.
public func bucketBy(numBuckets: Int32, _ columnNames: String...) -> DataFrameWriter {
self.numBuckets = numBuckets
self.bucketColumnNames = columnNames
return self
}

/// Sorts the output in each bucket by the given columns.
/// - Parameter columnNames: Column name strings.
/// - Returns: A `DataFrameWriter`.
public func sortBy(_ columnNames: String...) -> DataFrameWriter {
self.sortColumnNames = columnNames
return self
}

/// Clusters the output by the given columns.
/// - Parameter columnNames: Column name strings.
/// - Returns: A `DataFrameWriter`.
public func clusterBy(_ columnNames: String...) -> DataFrameWriter {
self.clusteringColumns = columnNames
return self
}

/// Saves the content of the `DataFrame`, for data sources that don't require a path (e.g. external
/// key-value stores).
public func save() async throws {
return try await saveInternal(nil)
}

/// Saves the content of the `DataFrame` at the specified path, for data sources that require a
/// path (e.g. data backed by a local or distributed file system).
/// - Parameter path: A path string.
public func save(_ path: String) async throws {
try await saveInternal(path)
}

private func saveInternal(_ path: String?) async throws {
var write = WriteOperation()
let plan = await self.df.getPlan() as! Plan
write.input = plan.root
write.mode = self.saveMode.toSaveMode
if let path = path {
write.path = path
}

// Cannot both be set
// require(!(builder.hasPath && builder.hasTable))

if let source = self.source {
write.source = source
}
if let sortColumnNames = self.sortColumnNames {
write.sortColumnNames = sortColumnNames
}
if let partitioningColumns = self.partitioningColumns {
write.partitioningColumns = partitioningColumns
}
if let clusteringColumns = self.clusteringColumns {
write.clusteringColumns = clusteringColumns
}
if let numBuckets = self.numBuckets {
var bucketBy = WriteOperation.BucketBy()
bucketBy.numBuckets = numBuckets
if let bucketColumnNames = self.bucketColumnNames {
bucketBy.bucketColumnNames = bucketColumnNames
}
write.bucketBy = bucketBy
}

for option in self.extraOptions {
write.options[option.key] = option.value
}

var command = Spark_Connect_Command()
command.writeOperation = write

_ = try await df.spark.client.execute(df.spark.sessionID, command)
}

/// Saves the content of the `DataFrame` in CSV format at the specified path.
/// - Parameter path: A path string
public func csv(_ path: String) async throws {
self.source = "csv"
return try await save(path)
}

/// Saves the content of the `DataFrame` in JSON format at the specified path.
/// - Parameter path: A path string
public func json(_ path: String) async throws {
self.source = "json"
return try await save(path)
}

/// Saves the content of the `DataFrame` in ORC format at the specified path.
/// - Parameter path: A path string
public func orc(_ path: String) async throws {
self.source = "orc"
return try await save(path)
}

/// Saves the content of the `DataFrame` in Parquet format at the specified path.
/// - Parameter path: A path string
public func parquet(_ path: String) async throws {
self.source = "parquet"
return try await save(path)
}

/// Saves the content of the `DataFrame` in a text file at the specified path.
/// The DataFrame must have only one column that is of string type.
/// Each row becomes a new line in the output file.
///
/// - Parameter path: A path string
public func text(_ path: String) async throws {
self.source = "text"
return try await save(path)
}
}
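Putting the writer together, a hedged end-to-end sketch (the output paths, the SQL text, and the "compression" option value are illustrative assumptions, not taken from this PR):

// Illustrative only; paths and option values are assumptions.
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.sql("SELECT * FROM range(10)")

// Format-specific shortcut:
try await df.write.parquet("/tmp/example/parquet")

// Generic save with chained configuration; a single `try await` covers the
// whole chain of actor calls:
try await df.write
  .format("orc")
  .mode("overwrite")
  .option("compression", "zlib")
  .save("/tmp/example/orc")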
11 changes: 11 additions & 0 deletions Sources/SparkConnect/Extension.swift
@@ -69,6 +69,17 @@ extension String {
}
return mode
}

var toSaveMode: SaveMode {
return switch self.lowercased() {
case "append": SaveMode.append
case "overwrite": SaveMode.overwrite
case "error": SaveMode.errorIfExists
case "errorIfExists": SaveMode.errorIfExists
case "ignore": SaveMode.ignore
default: SaveMode.errorIfExists
}
}
}

extension [String: String] {
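Since the input is lowercased before matching, the mapping is case-insensitive; note that a mixed-case `case "errorIfExists"` pattern could never match after `lowercased()`, which is why the pattern above is all-lowercase. A small in-module behavior sketch:

// Behavior sketch for the mapping above (case-insensitive by construction).
assert("Append".toSaveMode == SaveMode.append)
assert("errorIfExists".toSaveMode == SaveMode.errorIfExists)
assert("bogus".toSaveMode == SaveMode.errorIfExists)  // unknown strings fall back to the default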
26 changes: 26 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
@@ -361,4 +361,30 @@ public actor SparkConnectClient {
plan.opType = .root(relation)
return plan
}

var result: [ExecutePlanResponse] = []

/// Appends a streamed response. Actor-isolated so the gRPC response-stream
/// closure can accumulate results safely.
private func addResponse(_ response: ExecutePlanResponse) {
self.result.append(response)
}

/// Executes the given command in the given session and collects every streamed response.
/// - Parameters:
///   - sessionID: A session ID string.
///   - command: A `Command` to execute on the `Spark Connect` server.
/// - Returns: An array of `ExecutePlanResponse`.
func execute(_ sessionID: String, _ command: Command) async throws -> [ExecutePlanResponse] {
self.result.removeAll()
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: self.host, port: self.port),
transportSecurity: .plaintext
)
) { client in
let service = SparkConnectService.Client(wrapping: client)
var plan = Plan()
plan.opType = .command(command)
try await service.executePlan(getExecutePlanRequest(sessionID, plan)) {
response in
for try await m in response.messages {
await self.addResponse(m)
}
}
}
return result
}
}
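For reference, a sketch of how a caller drives this, mirroring `saveInternal` in DataFrameWriter above (assumes `write` is a fully-built `WriteOperation` and `client`/`sessionID` come from the surrounding session):

// Sketch: wrap a WriteOperation in a Command and execute it in the session.
var command = Spark_Connect_Command()
command.writeOperation = write
let responses = try await client.execute(sessionID, command)
print("received \(responses.count) response message(s)")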
4 changes: 4 additions & 0 deletions Sources/SparkConnect/TypeAliases.swift
@@ -18,11 +18,13 @@

typealias AnalyzePlanRequest = Spark_Connect_AnalyzePlanRequest
typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
typealias Command = Spark_Connect_Command
typealias ConfigRequest = Spark_Connect_ConfigRequest
typealias DataSource = Spark_Connect_Read.DataSource
typealias DataType = Spark_Connect_DataType
typealias DayTimeInterval = Spark_Connect_DataType.DayTimeInterval
typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
typealias ExpressionString = Spark_Connect_Expression.ExpressionString
typealias Filter = Spark_Connect_Filter
@@ -35,9 +37,11 @@ typealias Project = Spark_Connect_Project
typealias Range = Spark_Connect_Range
typealias Read = Spark_Connect_Read
typealias Relation = Spark_Connect_Relation
typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
typealias SparkConnectService = Spark_Connect_SparkConnectService
typealias Sort = Spark_Connect_Sort
typealias StructType = Spark_Connect_DataType.Struct
typealias UserContext = Spark_Connect_UserContext
typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
typealias WriteOperation = Spark_Connect_WriteOperation
typealias YearMonthInterval = Spark_Connect_DataType.YearMonthInterval