diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index c67e30e..cb1c2e1 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -52,6 +52,9 @@ public actor SparkConnectClient { for param in self.url.path.split(separator: ";").dropFirst().filter({ !$0.isEmpty }) { let kv = param.split(separator: "=") switch String(kv[0]).lowercased() { + case URIParams.PARAM_SESSION_ID: + // SparkSession handles this. + break case URIParams.PARAM_USER_AGENT: clientType = String(kv[1]) case URIParams.PARAM_TOKEN: diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift index e565126..ed25f5e 100644 --- a/Sources/SparkConnect/SparkSession.swift +++ b/Sources/SparkConnect/SparkSession.swift @@ -35,11 +35,19 @@ public actor SparkSession { /// Runtime configuration interface for Spark. public let conf: RuntimeConf + let regexSessionID = /;session_id=([a-zA-Z0-9-]+)/ + /// Create a session that uses the specified connection string and userID. /// - Parameters: /// - connection: a string in a patter, `sc://{host}:{port}` init(_ connection: String) { self.client = SparkConnectClient(remote: connection) + // Since `Session ID` belongs to `SparkSession`, we handle this here. + if connection.contains(regexSessionID) { + self.sessionID = connection.firstMatch(of: regexSessionID)!.1.uppercased() + } else { + self.sessionID = UUID().uuidString + } self.conf = RuntimeConf(self.client) } @@ -58,7 +66,7 @@ public actor SparkSession { } /// A unique session ID for this session from client. - nonisolated let sessionID: String = UUID().uuidString + nonisolated let sessionID: String /// Get the current session ID /// - Returns: the current session ID diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift index c2b5aa2..5124c72 100644 --- a/Tests/SparkConnectTests/SparkSessionTests.swift +++ b/Tests/SparkConnectTests/SparkSessionTests.swift @@ -50,6 +50,17 @@ struct SparkSessionTests { await newSpark.stop() } + @Test + func sessionID() async throws { + let spark1 = try await SparkSession.builder.getOrCreate() + await spark1.stop() + let remote = "sc://localhost/;session_id=\(spark1.sessionID)" + let spark2 = try await SparkSession.builder.remote(remote).getOrCreate() + await spark2.stop() + #expect(spark1.sessionID == spark2.sessionID) + #expect(spark1 == spark2) + } + @Test func userContext() async throws { let spark = try await SparkSession.builder.getOrCreate() #if os(macOS) || os(Linux)