Skip to content

Commit c405943

Browse files
committed
[SPARK-51472] Add gRPC SparkConnectClient actor
### What changes were proposed in this pull request? This PR aims to add a Swift `SparkConnectClient` actor encapsulating `gRPC` connections which is similar to other language clients. - Swift (this PR) ```swift public actor SparkConnectClient { ``` - [Scala](https://github.com/apache/spark/blob/master/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala): ```scala private[sql] class SparkConnectClient( ``` - [Python](https://github.com/apache/spark/blob/master/python/pyspark/sql/connect/client/core.py#L597) ```python class SparkConnectClient(object): ``` This is a part of the following. - #1 ### Why are the changes needed? To use `gRPC` in the upper `SparkSession` and `DataFrame` layers easily. ### Does this PR introduce _any_ user-facing change? No, this is not released. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #7 from dongjoon-hyun/SPARK-51472. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 7e2826e commit c405943

File tree

5 files changed

+372
-0
lines changed

5 files changed

+372
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
//
2+
// Licensed to the Apache Software Foundation (ASF) under one
3+
// or more contributor license agreements. See the NOTICE file
4+
// distributed with this work for additional information
5+
// regarding copyright ownership. The ASF licenses this file
6+
// to you under the Apache License, Version 2.0 (the
7+
// "License"); you may not use this file except in compliance
8+
// with the License. You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing,
13+
// software distributed under the License is distributed on an
14+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, either express or implied. See the License for the
16+
// specific language governing permissions and limitations
17+
// under the License.
18+
//
19+
20+
import Foundation
21+
22+
extension String {
  /// Build a `Plan` whose root relation is a SQL relation that uses this string as the query text.
  var toSparkConnectPlan: Plan {
    var query = Spark_Connect_SQL()
    query.query = self

    var root = Relation()
    root.sql = query

    var result = Plan()
    result.opType = .root(root)
    return result
  }

  /// Build a `UserContext` in which this string is used as both the user ID and the user name.
  var toUserContext: UserContext {
    var user = UserContext()
    user.userID = self
    user.userName = self
    return user
  }

  /// Build a `KeyValue` with this string as its key; the value is left untouched.
  var toKeyValue: KeyValue {
    var pair = KeyValue()
    pair.key = self
    return pair
  }
}
49+
50+
extension [String: String] {
  /// Convert every entry of this dictionary into a `KeyValue` message.
  /// The resulting array follows the dictionary's iteration order.
  var toSparkConnectKeyValue: [KeyValue] {
    self.map { entry -> KeyValue in
      var pair = KeyValue()
      pair.key = entry.key
      pair.value = entry.value
      return pair
    }
  }
}
63+
64+
extension Data {
  /// Read the first 4 bytes of this data as an `Int32` in the host's native byte order.
  ///
  /// The bytes are read with an unaligned load, so this is safe for `Data` slices whose
  /// storage does not start on a 4-byte boundary.
  ///
  /// - Precondition: The data must contain at least `MemoryLayout<Int32>.size` (4) bytes.
  var int32: Int32 {
    // Reading past the end of the buffer is undefined behavior in release builds,
    // so fail deterministically with a clear message instead.
    precondition(
      count >= MemoryLayout<Int32>.size,
      "Data.int32 requires at least \(MemoryLayout<Int32>.size) bytes, but got \(count)")
    // `load(as:)` requires properly aligned memory, which `Data` does not guarantee.
    return withUnsafeBytes { $0.loadUnaligned(as: Int32.self) }
  }
}
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
//
2+
// Licensed to the Apache Software Foundation (ASF) under one
3+
// or more contributor license agreements. See the NOTICE file
4+
// distributed with this work for additional information
5+
// regarding copyright ownership. The ASF licenses this file
6+
// to you under the Apache License, Version 2.0 (the
7+
// "License"); you may not use this file except in compliance
8+
// with the License. You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing,
13+
// software distributed under the License is distributed on an
14+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, either express or implied. See the License for the
16+
// specific language governing permissions and limitations
17+
// under the License.
18+
19+
import Foundation
20+
import GRPCCore
21+
import GRPCNIOTransportHTTP2
22+
import GRPCProtobuf
23+
import Synchronization
24+
25+
/// Conceptually the remote spark session that communicates with the server.
///
/// Every request method opens its own plaintext HTTP/2 gRPC client for the duration of the
/// call; connections are not reused between calls.
public actor SparkConnectClient {
  /// The client type string attached to every request.
  let clientType: String = "swift"
  /// The parsed remote URL given at creation time.
  let url: URL
  /// The server host taken from `url`, or `localhost` when the URL has no host.
  let host: String
  /// The server port taken from `url`, or `15002` when the URL has no port.
  let port: Int
  /// The user context attached to every request.
  let userContext: UserContext
  /// The session ID recorded by ``connect(_:)``; `nil` until `connect` has been called.
  var sessionID: String? = nil

  /// Create a client to use GRPCClient.
  /// - Parameters:
  ///   - remote: A string to connect `Spark Connect` server.
  ///   - user: A string for the user ID of this connection.
  /// - Precondition: `remote` must be parsable by `URL(string:)`.
  init(remote: String, user: String) {
    // The initializer is non-throwing, so an unparsable remote is still a fatal
    // programmer error; fail with a message that names the offending input.
    guard let url = URL(string: remote) else {
      preconditionFailure("Invalid Spark Connect remote URL: \(remote)")
    }
    self.url = url
    self.host = url.host() ?? "localhost"
    self.port = url.port ?? 15002
    self.userContext = user.toUserContext
  }

  /// Stop the connection. Currently, this API is no-op because we don't reuse the connection yet.
  func stop() {
  }

  /// Connect to the `Spark Connect` server with the given session ID string.
  /// As a test connection, this sends the server `SparkVersion` request.
  /// - Parameter sessionID: A string for the session ID. It must be a valid UUID string.
  /// - Returns: An `AnalyzePlanResponse` instance for `SparkVersion`
  /// - Throws: `SparkConnectError.InvalidSessionIDException` if `sessionID` is not a UUID string,
  ///   or any error raised by the underlying gRPC call.
  func connect(_ sessionID: String) async throws -> AnalyzePlanResponse {
    // To prevent server-side `INVALID_HANDLE.FORMAT (SQLSTATE: HY000)` exception.
    // Validate before creating the transport so an invalid ID never touches the network.
    guard UUID(uuidString: sessionID) != nil else {
      throw SparkConnectError.InvalidSessionIDException
    }

    return try await withGRPCClient(
      transport: .http2NIOPosix(
        target: .dns(host: self.host, port: self.port),
        transportSecurity: .plaintext
      )
    ) { client in
      self.sessionID = sessionID
      let service = SparkConnectService.Client(wrapping: client)
      var request = AnalyzePlanRequest()
      request.clientType = clientType
      request.userContext = userContext
      request.sessionID = sessionID
      request.analyze = .sparkVersion(AnalyzePlanRequest.SparkVersion())
      return try await service.analyzePlan(request)
    }
  }

  /// Create a ``ConfigRequest`` instance for `Set` operation.
  /// - Parameter map: A map of key-value string pairs.
  /// - Returns: A ``ConfigRequest`` instance.
  func getConfigRequestSet(map: [String: String]) -> ConfigRequest {
    var set = ConfigRequest.Set()
    set.pairs = map.toSparkConnectKeyValue
    var request = ConfigRequest()
    request.operation = ConfigRequest.Operation()
    request.operation.opType = .set(set)
    return request
  }

  /// Request the server to set a map of configurations for this session.
  /// - Parameter map: A map of key-value pairs to set.
  /// - Returns: Always return true.
  /// - Throws: `SparkConnectError.InvalidSessionIDException` if ``connect(_:)`` has not recorded
  ///   a session ID yet, or any error raised by the underlying gRPC call.
  func setConf(map: [String: String]) async throws -> Bool {
    // Without a session there is nothing to configure; throw instead of trapping.
    guard let sessionID else {
      throw SparkConnectError.InvalidSessionIDException
    }

    return try await withGRPCClient(
      transport: .http2NIOPosix(
        target: .dns(host: self.host, port: self.port),
        transportSecurity: .plaintext
      )
    ) { client in
      let service = SparkConnectService.Client(wrapping: client)
      var request = getConfigRequestSet(map: map)
      request.clientType = clientType
      request.userContext = userContext
      request.sessionID = sessionID
      _ = try await service.config(request)
      return true
    }
  }

  /// Create a ``ConfigRequest`` instance for `Get` operation.
  /// - Parameter keys: An array of keys to get.
  /// - Returns: A `ConfigRequest` instance.
  func getConfigRequestGet(keys: [String]) -> ConfigRequest {
    var get = ConfigRequest.Get()
    get.keys = keys
    var request = ConfigRequest()
    request.operation = ConfigRequest.Operation()
    request.operation.opType = .get(get)
    return request
  }

  /// Request the server to get a value of the given key.
  /// - Parameter key: A string for key to look up.
  /// - Returns: A string for the value of the key. If the response carries no pairs at all,
  ///   an empty string is returned.
  /// - Throws: `SparkConnectError.InvalidSessionIDException` if ``connect(_:)`` has not recorded
  ///   a session ID yet, or any error raised by the underlying gRPC call.
  func getConf(_ key: String) async throws -> String {
    guard let sessionID else {
      throw SparkConnectError.InvalidSessionIDException
    }

    return try await withGRPCClient(
      transport: .http2NIOPosix(
        target: .dns(host: self.host, port: self.port),
        transportSecurity: .plaintext
      )
    ) { client in
      let service = SparkConnectService.Client(wrapping: client)
      var request = getConfigRequestGet(keys: [key])
      request.clientType = clientType
      request.userContext = userContext
      request.sessionID = sessionID
      let response = try await service.config(request)
      // Indexing `pairs[0]` would trap on an empty response; degrade to an empty value instead.
      return response.pairs.first?.value ?? ""
    }
  }

  /// Create a ``ConfigRequest`` for `GetAll` operation.
  /// - Returns: A `ConfigRequest` instance.
  func getConfigRequestGetAll() -> ConfigRequest {
    var request = ConfigRequest()
    request.operation = ConfigRequest.Operation()
    request.operation.opType = .getAll(ConfigRequest.GetAll())
    return request
  }

  /// Request the server to get all configurations.
  /// - Returns: A map of key-value pairs. If the response repeats a key, the last value wins.
  /// - Throws: `SparkConnectError.InvalidSessionIDException` if ``connect(_:)`` has not recorded
  ///   a session ID yet, or any error raised by the underlying gRPC call.
  func getConfAll() async throws -> [String: String] {
    guard let sessionID else {
      throw SparkConnectError.InvalidSessionIDException
    }

    return try await withGRPCClient(
      transport: .http2NIOPosix(
        target: .dns(host: self.host, port: self.port),
        transportSecurity: .plaintext
      )
    ) { client in
      let service = SparkConnectService.Client(wrapping: client)
      var request = getConfigRequestGetAll()
      request.clientType = clientType
      request.userContext = userContext
      request.sessionID = sessionID
      let response = try await service.config(request)
      return response.pairs.reduce(into: [String: String]()) { map, pair in
        map[pair.key] = pair.value
      }
    }
  }

  /// Create a `Plan` instance for `Range` relation.
  /// - Parameters:
  ///   - start: A start of the range.
  ///   - end: An end (exclusive) of the range.
  ///   - step: A step value for the range from `start` to `end`.
  /// - Returns: A `Plan` instance.
  func getPlanRange(_ start: Int64, _ end: Int64, _ step: Int64) -> Plan {
    var range = Range()
    range.start = start
    range.end = end
    range.step = step
    var relation = Relation()
    relation.range = range
    var plan = Plan()
    plan.opType = .root(relation)
    return plan
  }

  /// Create a ``ExecutePlanRequest`` instance with the given plan.
  /// The operation ID is created by UUID.
  /// - Parameters:
  ///   - sessionID: A session ID used only when this client has not recorded one via
  ///     ``connect(_:)``; the recorded session ID takes precedence.
  ///   - plan: A plan to execute.
  /// - Returns: An ``ExecutePlanRequest`` instance.
  func getExecutePlanRequest(_ sessionID: String, _ plan: Plan) async
    -> ExecutePlanRequest
  {
    var request = ExecutePlanRequest()
    request.clientType = clientType
    request.userContext = userContext
    // Prefer the connected session; fall back to the argument rather than trapping on `nil`.
    request.sessionID = self.sessionID ?? sessionID
    request.operationID = UUID().uuidString
    request.plan = plan
    return request
  }

  /// Create a ``AnalyzePlanRequest`` instance with the given plan.
  /// - Parameters:
  ///   - sessionID: A session ID used only when this client has not recorded one via
  ///     ``connect(_:)``; the recorded session ID takes precedence.
  ///   - plan: A plan to analyze.
  /// - Returns: An ``AnalyzePlanRequest`` instance
  func getAnalyzePlanRequest(_ sessionID: String, _ plan: Plan) async
    -> AnalyzePlanRequest
  {
    var request = AnalyzePlanRequest()
    request.clientType = clientType
    request.userContext = userContext
    // Prefer the connected session; fall back to the argument rather than trapping on `nil`.
    request.sessionID = self.sessionID ?? sessionID
    var schema = AnalyzePlanRequest.Schema()
    schema.plan = plan
    request.analyze = .schema(schema)
    return request
  }
}

Sources/SparkConnect/SparkConnectError.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@
2020
/// An enum for ``SparkConnect`` package errors
enum SparkConnectError: Error {
  /// An operation that is not supported was requested.
  case UnsupportedOperationException

  /// A session ID is not usable, e.g. it is not formatted as a UUID string.
  case InvalidSessionIDException
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//
2+
// Licensed to the Apache Software Foundation (ASF) under one
3+
// or more contributor license agreements. See the NOTICE file
4+
// distributed with this work for additional information
5+
// regarding copyright ownership. The ASF licenses this file
6+
// to you under the Apache License, Version 2.0 (the
7+
// "License"); you may not use this file except in compliance
8+
// with the License. You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing,
13+
// software distributed under the License is distributed on an
14+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, either express or implied. See the License for the
16+
// specific language governing permissions and limitations
17+
// under the License.
18+
19+
// Short names for the generated `Spark_Connect_*` types, in alphabetical order.
// NOTE: `Range` shadows the standard library's `Range` for unqualified use inside this module;
// spell `Swift.Range` when the standard library type is meant.
typealias AnalyzePlanRequest = Spark_Connect_AnalyzePlanRequest
typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
typealias ConfigRequest = Spark_Connect_ConfigRequest
typealias DataType = Spark_Connect_DataType
typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
typealias KeyValue = Spark_Connect_KeyValue
typealias Plan = Spark_Connect_Plan
typealias Range = Spark_Connect_Range
typealias Relation = Spark_Connect_Relation
typealias SparkConnectService = Spark_Connect_SparkConnectService
typealias UserContext = Spark_Connect_UserContext
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//
2+
// Licensed to the Apache Software Foundation (ASF) under one
3+
// or more contributor license agreements. See the NOTICE file
4+
// distributed with this work for additional information
5+
// regarding copyright ownership. The ASF licenses this file
6+
// to you under the Apache License, Version 2.0 (the
7+
// "License"); you may not use this file except in compliance
8+
// with the License. You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing,
13+
// software distributed under the License is distributed on an
14+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, either express or implied. See the License for the
16+
// specific language governing permissions and limitations
17+
// under the License.
18+
//
19+
20+
import Foundation
21+
import Testing
22+
23+
@testable import SparkConnect
24+
25+
/// A test suite for `SparkConnectClient`
@Suite(.serialized)
struct SparkConnectClientTests {
  /// Create a client for `sc://localhost` with the user `test`, as every test in this suite does.
  private func makeClient() -> SparkConnectClient {
    SparkConnectClient(remote: "sc://localhost", user: "test")
  }

  /// A client can be created and stopped without ever connecting.
  @Test
  func createAndStop() async throws {
    let client = makeClient()
    await client.stop()
  }

  /// A session ID that is not a UUID string is rejected with `InvalidSessionIDException`.
  @Test
  func connectWithInvalidUUID() async throws {
    let client = makeClient()
    try await #require(throws: SparkConnectError.InvalidSessionIDException) {
      _ = try await client.connect("not-a-uuid-format")
    }
    await client.stop()
  }

  /// Connecting with a freshly generated UUID succeeds.
  /// NOTE(review): this sends a real request, so it presumably needs a `Spark Connect`
  /// server reachable on localhost — confirm against the CI setup.
  @Test
  func connect() async throws {
    let client = makeClient()
    _ = try await client.connect(UUID().uuidString)
    await client.stop()
  }
}

0 commit comments

Comments
 (0)