Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ public actor SparkConnectClient {
for param in self.url.path.split(separator: ";").dropFirst().filter({ !$0.isEmpty }) {
let kv = param.split(separator: "=")
switch String(kv[0]).lowercased() {
case URIParams.PARAM_SESSION_ID:
// SparkSession handles this.
break
case URIParams.PARAM_USER_AGENT:
clientType = String(kv[1])
case URIParams.PARAM_TOKEN:
Expand Down
10 changes: 9 additions & 1 deletion Sources/SparkConnect/SparkSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,19 @@ public actor SparkSession {
/// Runtime configuration interface for Spark.
public let conf: RuntimeConf

let regexSessionID = /;session_id=([a-zA-Z0-9-]+)/

/// Create a session that uses the specified connection string and userID.
/// - Parameters:
/// - connection: a string in a patter, `sc://{host}:{port}`
init(_ connection: String) {
self.client = SparkConnectClient(remote: connection)
// Since `Session ID` belongs to `SparkSession`, we handle this here.
if connection.contains(regexSessionID) {
self.sessionID = connection.firstMatch(of: regexSessionID)!.1.uppercased()
} else {
self.sessionID = UUID().uuidString
}
self.conf = RuntimeConf(self.client)
}

Expand All @@ -58,7 +66,7 @@ public actor SparkSession {
}

/// A unique session ID for this session from client.
nonisolated let sessionID: String = UUID().uuidString
nonisolated let sessionID: String

/// Get the current session ID
/// - Returns: the current session ID
Expand Down
11 changes: 11 additions & 0 deletions Tests/SparkConnectTests/SparkSessionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ struct SparkSessionTests {
await newSpark.stop()
}

@Test
func sessionID() async throws {
let spark1 = try await SparkSession.builder.getOrCreate()
await spark1.stop()
let remote = "sc://localhost/;session_id=\(spark1.sessionID)"
let spark2 = try await SparkSession.builder.remote(remote).getOrCreate()
await spark2.stop()
#expect(spark1.sessionID == spark2.sessionID)
#expect(spark1 == spark2)
}

@Test func userContext() async throws {
let spark = try await SparkSession.builder.getOrCreate()
#if os(macOS) || os(Linux)
Expand Down
Loading