diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index 8732c9e..ae7503a 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -62,6 +62,12 @@ public actor DataFrame: Sendable { self.batches.append(contentsOf: batches) } + /// Return the `SparkSession` of this `DataFrame`. + /// - Returns: A `SparkSession` + public func sparkSession() -> SparkSession { + return self.spark + } + /// A method to access the underlying Spark's `RDD`. /// In `Spark Connect`, this feature is not allowed by design. public func rdd() throws { diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift index 848a96e..da330cc 100644 --- a/Sources/SparkConnect/Extension.swift +++ b/Sources/SparkConnect/Extension.swift @@ -77,3 +77,9 @@ extension Data { /// Get an `Int32` value from unsafe 4 bytes. var int32: Int32 { withUnsafeBytes({ $0.load(as: Int32.self) }) } } + +extension SparkSession: Equatable { + public static func == (lhs: SparkSession, rhs: SparkSession) -> Bool { + return lhs.sessionID == rhs.sessionID + } +} diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift index e68deef..524f46a 100644 --- a/Sources/SparkConnect/SparkSession.swift +++ b/Sources/SparkConnect/SparkSession.swift @@ -62,7 +62,7 @@ public actor SparkSession { } /// A unique session ID for this session from client. - var sessionID: String = UUID().uuidString + nonisolated let sessionID: String = UUID().uuidString /// Get the current session ID /// - Returns: the current session ID diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index c7170d3..c7f5ba0 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -23,6 +23,13 @@ import Testing /// A test suite for `DataFrame` struct DataFrameTests { + @Test + func sparkSession() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(1).sparkSession() == spark) + await spark.stop() + } + @Test func rdd() async throws { let spark = try await SparkSession.builder.getOrCreate()