Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions Sources/SparkConnect/SparkConnectClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import GRPCProtobuf

/// Conceptually the remote spark session that communicates with the server
public actor SparkConnectClient {
let clientType: String = "swift"
var clientType: String = "swift"
let url: URL
let host: String
let port: Int
Expand All @@ -36,16 +36,36 @@ public actor SparkConnectClient {
/// Create a client to use GRPCClient.
/// - Parameters:
/// - remote: A string to connect `Spark Connect` server.
/// - user: A string for the user ID of this connection.
init(remote: String, user: String, token: String? = nil) {
init(remote: String) {
// NOTE(review): the force-unwrap below traps when `remote` cannot be parsed as a URL.
self.url = URL(string: remote)!
// Fall back to "localhost" and port 15002 when the URL does not provide them.
self.host = url.host() ?? "localhost"
self.port = self.url.port ?? 15002
// Defaults that the URL parameters parsed below may override.
var token: String? = nil
let processInfo = ProcessInfo.processInfo
// Default user name: `SPARK_USER` if set; otherwise the process user name on
// macOS/Linux, or an empty string on other platforms.
#if os(macOS) || os(Linux)
var userName = processInfo.environment["SPARK_USER"] ?? processInfo.userName
#else
var userName = processInfo.environment["SPARK_USER"] ?? ""
#endif
// Parameters are the ";"-separated pieces of the URL path after the first piece,
// e.g. "/;token=abcd;userId=test" yields "token=abcd" and "userId=test".
for param in self.url.path.split(separator: ";").dropFirst().filter({ !$0.isEmpty }) {
// NOTE(review): `split(separator:)` omits empty pieces, so a parameter without a
// value ("token" or "token=") produces a single element and `kv[1]` traps, and a
// bare "=" produces no elements so `kv[0]` traps. A value that itself contains "="
// is cut at that character because only `kv[1]` is read.
let kv = param.split(separator: "=")
switch String(kv[0]) {
case URIParams.PARAM_USER_AGENT:
clientType = String(kv[1])
case URIParams.PARAM_TOKEN:
token = String(kv[1])
case URIParams.PARAM_USER_ID:
userName = String(kv[1])
default:
// Print warning and ignore
// Any key not matched above lands here, including the `URIParams` keys that have
// no case in this switch (`useSsl`, `sessionId`, `grpcMaxMessageSize`).
print("Unknown parameter: \(param)")
}
}
// A token given in the URL takes precedence over the
// `SPARK_CONNECT_AUTHENTICATE_TOKEN` environment variable.
self.token = token ?? ProcessInfo.processInfo.environment["SPARK_CONNECT_AUTHENTICATE_TOKEN"]
// When a token is available, register a `BearerTokenInterceptor` built from it.
if let token = self.token {
self.intercepters.append(BearerTokenInterceptor(token: token))
}
self.userContext = user.toUserContext
self.userContext = userName.toUserContext
}

/// Stop the connection. Currently, this API is no-op because we don't reuse the connection yet.
Expand Down Expand Up @@ -574,4 +594,13 @@ public actor SparkConnectClient {
return OneOf_Analyze.isStreaming(isStreaming)
})
}

/// Keys of the `key=value` parameters that can appear, separated by `;`, in the path of a
/// connection string such as `sc://host:port/;token=abcd;userId=test`.
private enum URIParams {
/// Overrides the user name from which the client's user context is built.
static let PARAM_USER_ID = "userId"
/// Overrides the client's `clientType` value.
static let PARAM_USER_AGENT = "userAgent"
/// Authentication token; when present it is used instead of the
/// `SPARK_CONNECT_AUTHENTICATE_TOKEN` environment variable.
static let PARAM_TOKEN = "token"
// NOTE(review): the three keys below are declared but are not matched by the
// initializer's `switch` in the visible code, so they are currently reported as
// "Unknown parameter" — confirm whether handling is planned.
static let PARAM_USE_SSL = "useSsl"
static let PARAM_SESSION_ID = "sessionId"
static let PARAM_GRPC_MAX_MESSAGE_SIZE = "grpcMaxMessageSize"
}
}
12 changes: 2 additions & 10 deletions Sources/SparkConnect/SparkSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
// under the License.
//

import Dispatch
import Foundation

/// The entry point to programming Spark with ``DataFrame`` API.
Expand All @@ -39,15 +38,8 @@ public actor SparkSession {
/// Create a session that uses the specified connection string and userID.
/// - Parameters:
/// - connection: a string in a pattern, `sc://{host}:{port}`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to specify possible parameters for connection?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the review. I'll document that part in the final step. Currently, I'm still working on the remaining parameters, such as useSsl (security), sessionId (requires reconnection support), and grpcMaxMessageSize.

/// - userID: an optional user ID. If absent, `SPARK_USER` environment or ``ProcessInfo.processInfo.userName`` is used.
init(_ connection: String, _ userID: String? = nil) {
// Default user name: `SPARK_USER` if set; otherwise the process user name on
// macOS/Linux, or an empty string on other platforms.
let processInfo = ProcessInfo.processInfo
#if os(macOS) || os(Linux)
let userName = processInfo.environment["SPARK_USER"] ?? processInfo.userName
#else
let userName = processInfo.environment["SPARK_USER"] ?? ""
#endif
self.client = SparkConnectClient(remote: connection, user: userID ?? userName)
init(_ connection: String) {
self.client = SparkConnectClient(remote: connection)
// The runtime configuration is built on the same client as the session.
self.conf = RuntimeConf(self.client)
}

Expand Down
8 changes: 4 additions & 4 deletions Tests/SparkConnectTests/RuntimeConfTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import Testing
struct RuntimeConfTests {
@Test
func get() async throws {
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
let client = SparkConnectClient(remote: "sc://localhost")
_ = try await client.connect(UUID().uuidString)
let conf = RuntimeConf(client)

Expand All @@ -42,7 +42,7 @@ struct RuntimeConfTests {

@Test
func set() async throws {
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
let client = SparkConnectClient(remote: "sc://localhost")
_ = try await client.connect(UUID().uuidString)
let conf = RuntimeConf(client)
try await conf.set("spark.test.key1", "value1")
Expand All @@ -52,7 +52,7 @@ struct RuntimeConfTests {

@Test
func reset() async throws {
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
let client = SparkConnectClient(remote: "sc://localhost")
_ = try await client.connect(UUID().uuidString)
let conf = RuntimeConf(client)

Expand All @@ -73,7 +73,7 @@ struct RuntimeConfTests {

@Test
func getAll() async throws {
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
let client = SparkConnectClient(remote: "sc://localhost")
_ = try await client.connect(UUID().uuidString)
let conf = RuntimeConf(client)
let map = try await conf.getAll()
Expand Down
23 changes: 17 additions & 6 deletions Tests/SparkConnectTests/SparkConnectClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,24 @@ import Testing
struct SparkConnectClientTests {
// Builds a client from a bare `sc://localhost` connection string and stops it
// without ever calling `connect`.
@Test
func createAndStop() async throws {
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
let client = SparkConnectClient(remote: "sc://localhost")
await client.stop()
}

@Test
func parameters() async throws {
  // Connection string that carries an explicit host and port plus the three
  // `;`-separated parameters checked below.
  let remote = "sc://host1:123/;token=abcd;userId=test;userAgent=myagent"
  let client = SparkConnectClient(remote: remote)

  // Endpoint taken from the authority part of the URL.
  #expect(await client.host == "host1")
  #expect(await client.port == 123)

  // Values taken from the `key=value` parameters.
  #expect(await client.token == "abcd")
  #expect(await client.userContext.userID == "test")
  #expect(await client.clientType == "myagent")

  await client.stop()
}

@Test
func connectWithInvalidUUID() async throws {
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
let client = SparkConnectClient(remote: "sc://localhost")
try await #require(throws: SparkConnectError.InvalidSessionIDException) {
let _ = try await client.connect("not-a-uuid-format")
}
Expand All @@ -42,14 +53,14 @@ struct SparkConnectClientTests {

// Connects using a freshly generated UUID string and then stops the client; the test
// passes when `connect` does not throw.
// NOTE(review): presumably needs a Spark Connect server reachable at localhost — confirm.
@Test
func connect() async throws {
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
let client = SparkConnectClient(remote: "sc://localhost")
let _ = try await client.connect(UUID().uuidString)
await client.stop()
}

@Test
func tags() async throws {
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
let client = SparkConnectClient(remote: "sc://localhost")
let _ = try await client.connect(UUID().uuidString)
let plan = await client.getPlanRange(0, 1, 1)

Expand All @@ -65,7 +76,7 @@ struct SparkConnectClientTests {

@Test
func ddlParse() async throws {
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
let client = SparkConnectClient(remote: "sc://localhost")
let _ = try await client.connect(UUID().uuidString)
#expect(try await client.ddlParse("a int").simpleString == "struct<a:int>")
await client.stop()
Expand All @@ -74,7 +85,7 @@ struct SparkConnectClientTests {
#if !os(Linux) // TODO: Enable this with the official Spark 4 docker image
@Test
func jsonToDdl() async throws {
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
let client = SparkConnectClient(remote: "sc://localhost")
let _ = try await client.connect(UUID().uuidString)
let json =
#"{"type":"struct","fields":[{"name":"id","type":"long","nullable":false,"metadata":{}}]}"#
Expand Down
Loading