Commit d5c5069

[SPARK-51483] Add SparkSession and DataFrame actors
### What changes were proposed in this pull request?

This PR aims to add `SparkSession` and `DataFrame` actors.
- `SparkSession.SparkContext` is defined as an empty `struct`, just as a type.
- `SparkSession.Builder` is defined to match the builder pattern.

### Why are the changes needed?

To allow users to start using this library. After this PR, we can run the tests against real `Spark Connect` servers.

### Does this PR introduce _any_ user-facing change?

No, this is not released yet.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #10 from dongjoon-hyun/SPARK-51483.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 18584e8 commit d5c5069

File tree

5 files changed: +569 -0 lines changed
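Assuming a Spark Connect server is reachable at the default `sc://localhost:15002`, the APIs added by this commit can be exercised roughly as follows. This is a minimal sketch, not part of the diff; the `SparkConnect` module name is an assumption.

```swift
import SparkConnect  // assumed module name; not shown in this diff

// Connect to a Spark Connect server; Builder.create() defaults to sc://localhost:15002.
let spark = try await SparkSession.builder.getOrCreate()

// Run a SQL statement and render the result as a text table.
let df = try await spark.sql("SELECT 1 AS id, 'Hello, Spark Connect' AS message")
try await df.show()

// Count the rows of a range DataFrame (ids 0..<10).
print(try await spark.range(10).count())  // 10

await spark.stop()
```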
Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
```swift
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
import Atomics
import Foundation
import GRPCCore
import GRPCNIOTransportHTTP2
import GRPCProtobuf
import NIOCore
import SwiftyTextTable
import Synchronization

/// A `DataFrame` which supports only SQL queries.
public actor DataFrame: Sendable {
  var spark: SparkSession
  var plan: Plan
  var schema: DataType? = nil
  private var batches: [RecordBatch] = [RecordBatch]()

  /// Create a new `DataFrame` instance with the given Spark session and plan.
  /// - Parameters:
  ///   - spark: A ``SparkSession`` instance to use.
  ///   - plan: A plan to execute.
  init(spark: SparkSession, plan: Plan) async throws {
    self.spark = spark
    self.plan = plan
  }

  /// Create a new `DataFrame` instance with the given `SparkSession` and a SQL statement.
  /// - Parameters:
  ///   - spark: A ``SparkSession`` instance to use.
  ///   - sqlText: A SQL statement.
  init(spark: SparkSession, sqlText: String) async throws {
    self.spark = spark
    self.plan = sqlText.toSparkConnectPlan
  }

  /// Set the schema. This is used to store the analyzed schema response from the `Spark Connect` server.
  /// - Parameter schema: A ``DataType`` instance describing the schema.
  private func setSchema(_ schema: DataType) {
    self.schema = schema
  }

  /// Add `Apache Arrow`'s `RecordBatch`s to the internal array.
  /// - Parameter batches: An array of ``RecordBatch``.
  private func addBatches(_ batches: [RecordBatch]) {
    self.batches.append(contentsOf: batches)
  }

  /// A method to access the underlying Spark's `RDD`.
  /// In `Spark Connect`, this feature is not allowed by design.
  public func rdd() throws {
    // SQLSTATE: 0A000
    // [UNSUPPORTED_CONNECT_FEATURE.RDD]
    // Feature is not supported in Spark Connect: Resilient Distributed Datasets (RDDs).
    throw SparkConnectError.UnsupportedOperationException
  }

  /// Return a `JSON` string of the data type because we cannot expose the internal type ``DataType``.
  /// - Returns: A `JSON` string.
  public func schema() async throws -> String {
    var dataType: String? = nil

    try await withGRPCClient(
      transport: .http2NIOPosix(
        target: .dns(host: spark.client.host, port: spark.client.port),
        transportSecurity: .plaintext
      )
    ) { client in
      let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
      let response = try await service.analyzePlan(
        spark.client.getAnalyzePlanRequest(spark.sessionID, plan))
      dataType = try response.schema.schema.jsonString()
    }
    return dataType!
  }

  /// Return the total number of rows.
  /// - Returns: An `Int64` value.
  public func count() async throws -> Int64 {
    let counter = Atomic(Int64(0))

    try await withGRPCClient(
      transport: .http2NIOPosix(
        target: .dns(host: spark.client.host, port: spark.client.port),
        transportSecurity: .plaintext
      )
    ) { client in
      let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
      try await service.executePlan(spark.client.getExecutePlanRequest(spark.sessionID, plan)) {
        response in
        for try await m in response.messages {
          counter.add(m.arrowBatch.rowCount, ordering: .relaxed)
        }
      }
    }
    return counter.load(ordering: .relaxed)
  }

  /// Execute the plan and try to fill `schema` and `batches`.
  private func execute() async throws {
    try await withGRPCClient(
      transport: .http2NIOPosix(
        target: .dns(host: spark.client.host, port: spark.client.port),
        transportSecurity: .plaintext
      )
    ) { client in
      let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
      try await service.executePlan(spark.client.getExecutePlanRequest(spark.sessionID, plan)) {
        response in
        for try await m in response.messages {
          if m.hasSchema {
            // The original schema should arrive before ArrowBatches.
            await self.setSchema(m.schema)
          }
          let ipcStreamBytes = m.arrowBatch.data
          if !ipcStreamBytes.isEmpty && m.arrowBatch.rowCount > 0 {
            let IPC_CONTINUATION_TOKEN = Int32(-1)
            // Schema
            assert(ipcStreamBytes[0..<4].int32 == IPC_CONTINUATION_TOKEN)
            let schemaSize = Int64(ipcStreamBytes[4..<8].int32)
            let schema = Data(ipcStreamBytes[8..<(8 + schemaSize)])

            // Arrow IPC Data
            assert(
              ipcStreamBytes[(8 + schemaSize)..<(8 + schemaSize + 4)].int32
                == IPC_CONTINUATION_TOKEN)
            var pos: Int64 = 8 + schemaSize + 4
            let dataHeaderSize = Int64(ipcStreamBytes[pos..<(pos + 4)].int32)
            pos += 4
            let dataHeader = Data(ipcStreamBytes[pos..<(pos + dataHeaderSize)])
            pos += dataHeaderSize
            let dataBodySize = Int64(ipcStreamBytes.count) - pos - 8
            let dataBody = Data(ipcStreamBytes[pos..<(pos + dataBodySize)])

            // Read ArrowBatches
            let reader = ArrowReader()
            let arrowResult = ArrowReader.makeArrowReaderResult()
            _ = reader.fromMessage(schema, dataBody: Data(), result: arrowResult)
            _ = reader.fromMessage(dataHeader, dataBody: dataBody, result: arrowResult)
            await self.addBatches(arrowResult.batches)
          }
        }
      }
    }
  }

  /// This is designed not to be supported, in order to simplify the Swift client.
  public func collect() async throws {
    throw SparkConnectError.UnsupportedOperationException
  }

  /// Execute the plan and show the result.
  public func show() async throws {
    try await execute()

    if let schema = self.schema {
      var columns: [TextTableColumn] = []
      for f in schema.struct.fields {
        columns.append(TextTableColumn(header: f.name))
      }
      var table = TextTable(columns: columns)
      for batch in self.batches {
        for i in 0..<batch.length {
          var values: [String] = []
          for column in batch.columns {
            let str = column.array as! AsString
            if column.data.isNull(i) {
              values.append("NULL")
            } else {
              values.append(str.asString(i))
            }
          }
          table.addRow(values: values)
        }
      }
      print(table.render())
    }
  }
}
```
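The `execute()` method above slices the Arrow IPC stream by hand rather than using a streaming reader. As a reading aid, the sketch below summarizes the byte layout that the code assumes for each `arrowBatch.data` payload; it is derived from the offsets in the code itself, not an authoritative description of the Arrow IPC format.

```swift
// Byte layout assumed by execute() when slicing ipcStreamBytes:
//
//   [0 ..< 4]                        Int32(-1)        IPC continuation token
//   [4 ..< 8]                        schemaSize       length of the schema message
//   [8 ..< 8 + schemaSize]           schema           Arrow schema message
//   [next 4 bytes]                   Int32(-1)        continuation token for the record batch
//   [next 4 bytes]                   dataHeaderSize   length of the record-batch header
//   [next dataHeaderSize bytes]      dataHeader       record-batch header message
//   [remaining bytes minus 8]        dataBody         record-batch body
//   [last 8 bytes]                   trailer          end-of-stream marker (skipped by the code)
//
// The two reader.fromMessage(...) calls then feed the schema message (with an
// empty body) and the record-batch header plus body into one ArrowReader
// result, whose batches are appended via addBatches(_:).
```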
Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
```swift
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//

import Foundation
import GRPCCore
import GRPCNIOTransportHTTP2
import GRPCProtobuf
import Synchronization

/// The entry point to programming Spark with the ``DataFrame`` API.
///
/// Use the builder to get a session.
///
/// ```swift
/// let spark = try await SparkSession.builder.getOrCreate()
/// ```
public actor SparkSession {

  public static let builder: Builder = Builder()

  let client: SparkConnectClient

  /// Runtime configuration interface for Spark.
  public let conf: RuntimeConf

  /// Create a session that uses the specified connection string and userID.
  /// - Parameters:
  ///   - connection: A string in a pattern, `sc://{host}:{port}`.
  ///   - userID: An optional user ID. If absent, the `SPARK_USER` environment variable or `ProcessInfo.processInfo.userName` is used.
  init(_ connection: String, _ userID: String? = nil) {
    let processInfo = ProcessInfo.processInfo
    let userName = processInfo.environment["SPARK_USER"] ?? processInfo.userName
    self.client = SparkConnectClient(remote: connection, user: userID ?? userName)
    self.conf = RuntimeConf(self.client)
  }

  /// The Spark version of the Spark Connect server. This is supposed to be overwritten while establishing connections.
  public var version: String = ""

  func setVersion(_ version: String) {
    self.version = version
  }

  /// A unique session ID for this session from the client.
  var sessionID: String = UUID().uuidString

  /// Get the current session ID.
  /// - Returns: The current session ID.
  func getSessionID() -> String {
    sessionID
  }

  /// A server-side generated session ID. This is supposed to be overwritten while establishing connections.
  var serverSideSessionID: String = ""

  /// A variable for ``SparkContext``. This is designed to throw exceptions by Apache Spark.
  var sparkContext: SparkContext {
    get throws {
      // SQLSTATE: 0A000
      // [UNSUPPORTED_CONNECT_FEATURE.SESSION_SPARK_CONTEXT]
      // Feature is not supported in Spark Connect: Access to the SparkContext.
      throw SparkConnectError.UnsupportedOperationException
    }
  }

  /// Stop the current client.
  public func stop() async {
    await client.stop()
  }

  /// Create a ``DataFrame`` with a single ``Int64`` column named `id`, containing elements in a
  /// range from 0 to `end` (exclusive) with step value 1.
  ///
  /// - Parameter end: A value for the end of range.
  /// - Returns: A ``DataFrame`` instance.
  public func range(_ end: Int64) async throws -> DataFrame {
    return try await range(0, end)
  }

  /// Create a ``DataFrame`` with a single ``Int64`` column named `id`, containing elements in a
  /// range from `start` to `end` (exclusive) with a step value (default: 1).
  ///
  /// - Parameters:
  ///   - start: A value for the start of range.
  ///   - end: A value for the end of range.
  ///   - step: A value for the step.
  /// - Returns: A ``DataFrame`` instance.
  public func range(_ start: Int64, _ end: Int64, _ step: Int64 = 1) async throws -> DataFrame {
    return try await DataFrame(spark: self, plan: client.getPlanRange(start, end, step))
  }

  /// Create a ``DataFrame`` for the given SQL statement.
  /// - Parameter sqlText: A SQL string.
  /// - Returns: A ``DataFrame`` instance.
  public func sql(_ sqlText: String) async throws -> DataFrame {
    return try await DataFrame(spark: self, sqlText: sqlText)
  }

  /// This is defined as the return type of the `SparkSession.sparkContext` method.
  /// This is an empty `Struct` type because the `sparkContext` method is designed to throw
  /// `UNSUPPORTED_CONNECT_FEATURE.SESSION_SPARK_CONTEXT`.
  struct SparkContext {
  }

  /// A builder to create a ``SparkSession``.
  public actor Builder {
    var sparkConf: [String: String] = [:]

    /// Set a new configuration.
    /// - Parameters:
    ///   - key: A string for the configuration key.
    ///   - value: A string for the configuration value.
    /// - Returns: self
    public func config(_ key: String, _ value: String) -> Builder {
      sparkConf[key] = value
      return self
    }

    /// Remove all stored configurations.
    /// - Returns: self
    func clear() -> Builder {
      sparkConf.removeAll()
      return self
    }

    /// Set a URL for the remote connection.
    /// - Parameter url: A connection string in a pattern, `sc://{host}:{port}`.
    /// - Returns: self
    public func remote(_ url: String) -> Builder {
      return config("spark.remote", url)
    }

    /// Set `appName` of this session.
    /// - Parameter name: A string for the application name.
    /// - Returns: self
    public func appName(_ name: String) -> Builder {
      return config("spark.app.name", name)
    }

    /// Enable `Apache Hive` metastore support configuration.
    /// - Returns: self
    func enableHiveSupport() -> Builder {
      return config("spark.sql.catalogImplementation", "hive")
    }

    /// Create a new ``SparkSession``. If `spark.remote` is not given, `sc://localhost:15002` is used.
    /// - Returns: A newly created `SparkSession`.
    func create() async throws -> SparkSession {
      let session = SparkSession(sparkConf["spark.remote"] ?? "sc://localhost:15002")
      let response = try await session.client.connect(session.sessionID)
      await session.setVersion(response.sparkVersion.version)
      let isSuccess = try await session.client.setConf(map: sparkConf)
      assert(isSuccess)
      return session
    }

    /// Create a ``SparkSession`` from the given configurations.
    /// - Returns: A Spark session.
    public func getOrCreate() async throws -> SparkSession {
      return try await create()
    }
  }
}
```
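For completeness, the `Builder` above also accepts explicit configuration before the session is created. The following is a minimal sketch assuming a reachable server; the host, port, application name, and extra config key are placeholders.

```swift
// Each public Builder method returns the builder, so calls can be chained
// under a single `try await`.
let spark = try await SparkSession.builder
  .remote("sc://localhost:15002")            // stored as "spark.remote"
  .appName("SparkConnectSwiftExample")       // stored as "spark.app.name"
  .config("spark.sql.ansi.enabled", "true")  // forwarded to the server via setConf
  .getOrCreate()

print(await spark.version)  // filled in from the server's response during create()
await spark.stop()
```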
