Skip to content

Commit fc9ceff

Browse files
committed
[SPARK-51857] Support token/userId/userAgent parameters in SparkConnectClient
1 parent 6241ca4 commit fc9ceff

File tree

4 files changed

+56
-24
lines changed

4 files changed

+56
-24
lines changed

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import GRPCProtobuf
2323

2424
/// Conceptually the remote spark session that communicates with the server
2525
public actor SparkConnectClient {
26-
let clientType: String = "swift"
26+
var clientType: String = "swift"
2727
let url: URL
2828
let host: String
2929
let port: Int
@@ -36,16 +36,36 @@ public actor SparkConnectClient {
3636
/// Create a client to use GRPCClient.
3737
/// - Parameters:
3838
/// - remote: A string to connect `Spark Connect` server.
39-
/// - user: A string for the user ID of this connection.
40-
init(remote: String, user: String, token: String? = nil) {
39+
init(remote: String) {
4140
self.url = URL(string: remote)!
4241
self.host = url.host() ?? "localhost"
4342
self.port = self.url.port ?? 15002
43+
var token: String? = nil
44+
let processInfo = ProcessInfo.processInfo
45+
#if os(macOS) || os(Linux)
46+
var userName = processInfo.environment["SPARK_USER"] ?? processInfo.userName
47+
#else
48+
var userName = processInfo.environment["SPARK_USER"] ?? ""
49+
#endif
50+
for param in self.url.path.split(separator: ";").dropFirst().filter({ !$0.isEmpty }) {
51+
let kv = param.split(separator: "=")
52+
switch String(kv[0]) {
53+
case URIParams.PARAM_USER_AGENT:
54+
clientType = String(kv[1])
55+
case URIParams.PARAM_TOKEN:
56+
token = String(kv[1])
57+
case URIParams.PARAM_USER_ID:
58+
userName = String(kv[1])
59+
default:
60+
// Print warning and ignore
61+
print("Unknown parameter: \(param)")
62+
}
63+
}
4464
self.token = token ?? ProcessInfo.processInfo.environment["SPARK_CONNECT_AUTHENTICATE_TOKEN"]
4565
if let token = self.token {
4666
self.intercepters.append(BearerTokenInterceptor(token: token))
4767
}
48-
self.userContext = user.toUserContext
68+
self.userContext = userName.toUserContext
4969
}
5070

5171
/// Stop the connection. Currently, this API is no-op because we don't reuse the connection yet.
@@ -574,4 +594,13 @@ public actor SparkConnectClient {
574594
return OneOf_Analyze.isStreaming(isStreaming)
575595
})
576596
}
597+
598+
private enum URIParams {
599+
static let PARAM_USER_ID = "userId"
600+
static let PARAM_USER_AGENT = "userAgent"
601+
static let PARAM_TOKEN = "token"
602+
static let PARAM_USE_SSL = "useSsl"
603+
static let PARAM_SESSION_ID = "sessionId"
604+
static let PARAM_GRPC_MAX_MESSAGE_SIZE = "grpcMaxMessageSize"
605+
}
577606
}

Sources/SparkConnect/SparkSession.swift

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
// under the License.
1818
//
1919

20-
import Dispatch
2120
import Foundation
2221

2322
/// The entry point to programming Spark with ``DataFrame`` API.
@@ -39,15 +38,8 @@ public actor SparkSession {
3938
/// Create a session that uses the specified connection string and userID.
4039
/// - Parameters:
4140
/// - connection: a string in a patter, `sc://{host}:{port}`
42-
/// - userID: an optional user ID. If absent, `SPARK_USER` environment or ``ProcessInfo.processInfo.userName`` is used.
43-
init(_ connection: String, _ userID: String? = nil) {
44-
let processInfo = ProcessInfo.processInfo
45-
#if os(macOS) || os(Linux)
46-
let userName = processInfo.environment["SPARK_USER"] ?? processInfo.userName
47-
#else
48-
let userName = processInfo.environment["SPARK_USER"] ?? ""
49-
#endif
50-
self.client = SparkConnectClient(remote: connection, user: userID ?? userName)
41+
init(_ connection: String) {
42+
self.client = SparkConnectClient(remote: connection)
5143
self.conf = RuntimeConf(self.client)
5244
}
5345

Tests/SparkConnectTests/RuntimeConfTests.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import Testing
2727
struct RuntimeConfTests {
2828
@Test
2929
func get() async throws {
30-
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
30+
let client = SparkConnectClient(remote: "sc://localhost")
3131
_ = try await client.connect(UUID().uuidString)
3232
let conf = RuntimeConf(client)
3333

@@ -42,7 +42,7 @@ struct RuntimeConfTests {
4242

4343
@Test
4444
func set() async throws {
45-
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
45+
let client = SparkConnectClient(remote: "sc://localhost")
4646
_ = try await client.connect(UUID().uuidString)
4747
let conf = RuntimeConf(client)
4848
try await conf.set("spark.test.key1", "value1")
@@ -52,7 +52,7 @@ struct RuntimeConfTests {
5252

5353
@Test
5454
func reset() async throws {
55-
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
55+
let client = SparkConnectClient(remote: "sc://localhost")
5656
_ = try await client.connect(UUID().uuidString)
5757
let conf = RuntimeConf(client)
5858

@@ -73,7 +73,7 @@ struct RuntimeConfTests {
7373

7474
@Test
7575
func getAll() async throws {
76-
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
76+
let client = SparkConnectClient(remote: "sc://localhost")
7777
_ = try await client.connect(UUID().uuidString)
7878
let conf = RuntimeConf(client)
7979
let map = try await conf.getAll()

Tests/SparkConnectTests/SparkConnectClientTests.swift

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,24 @@ import Testing
2727
struct SparkConnectClientTests {
2828
@Test
2929
func createAndStop() async throws {
30-
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
30+
let client = SparkConnectClient(remote: "sc://localhost")
31+
await client.stop()
32+
}
33+
34+
@Test
35+
func parameters() async throws {
36+
let client = SparkConnectClient(remote: "sc://host1:123/;token=abcd;userId=test;userAgent=myagent")
37+
#expect(await client.token == "abcd")
38+
#expect(await client.userContext.userID == "test")
39+
#expect(await client.clientType == "myagent")
40+
#expect(await client.host == "host1")
41+
#expect(await client.port == 123)
3142
await client.stop()
3243
}
3344

3445
@Test
3546
func connectWithInvalidUUID() async throws {
36-
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
47+
let client = SparkConnectClient(remote: "sc://localhost")
3748
try await #require(throws: SparkConnectError.InvalidSessionIDException) {
3849
let _ = try await client.connect("not-a-uuid-format")
3950
}
@@ -42,14 +53,14 @@ struct SparkConnectClientTests {
4253

4354
@Test
4455
func connect() async throws {
45-
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
56+
let client = SparkConnectClient(remote: "sc://localhost")
4657
let _ = try await client.connect(UUID().uuidString)
4758
await client.stop()
4859
}
4960

5061
@Test
5162
func tags() async throws {
52-
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
63+
let client = SparkConnectClient(remote: "sc://localhost")
5364
let _ = try await client.connect(UUID().uuidString)
5465
let plan = await client.getPlanRange(0, 1, 1)
5566

@@ -65,7 +76,7 @@ struct SparkConnectClientTests {
6576

6677
@Test
6778
func ddlParse() async throws {
68-
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
79+
let client = SparkConnectClient(remote: "sc://localhost")
6980
let _ = try await client.connect(UUID().uuidString)
7081
#expect(try await client.ddlParse("a int").simpleString == "struct<a:int>")
7182
await client.stop()
@@ -74,7 +85,7 @@ struct SparkConnectClientTests {
7485
#if !os(Linux) // TODO: Enable this with the offical Spark 4 docker image
7586
@Test
7687
func jsonToDdl() async throws {
77-
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
88+
let client = SparkConnectClient(remote: "sc://localhost")
7889
let _ = try await client.connect(UUID().uuidString)
7990
let json =
8091
#"{"type":"struct","fields":[{"name":"id","type":"long","nullable":false,"metadata":{}}]}"#

0 commit comments

Comments
 (0)