diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index b96e02d..96c36be 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -342,6 +342,21 @@ public actor DataFrame: Sendable {
     return self
   }
 
+  var storageLevel: StorageLevel {
+    get async throws {
+      try await withGRPCClient(
+        transport: .http2NIOPosix(
+          target: .dns(host: spark.client.host, port: spark.client.port),
+          transportSecurity: .plaintext
+        )
+      ) { client in
+        let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
+        return try await service
+          .analyzePlan(spark.client.getStorageLevel(spark.sessionID, plan)).getStorageLevel.storageLevel.toStorageLevel
+      }
+    }
+  }
+
   public func explain() async throws {
     try await explain("simple")
   }
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index 9dfde0d..3314c55 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -282,6 +282,17 @@
     })
   }
 
+  func getStorageLevel(_ sessionID: String, _ plan: Plan) async -> AnalyzePlanRequest
+  {
+    return analyze(
+      sessionID,
+      {
+        var level = AnalyzePlanRequest.GetStorageLevel()
+        level.relation = plan.root
+        return OneOf_Analyze.getStorageLevel(level)
+      })
+  }
+
   func getExplain(_ sessionID: String, _ plan: Plan, _ mode: String) async -> AnalyzePlanRequest
   {
     return analyze(
diff --git a/Sources/SparkConnect/StorageLevel.swift b/Sources/SparkConnect/StorageLevel.swift
index c2e6f04..524b507 100644
--- a/Sources/SparkConnect/StorageLevel.swift
+++ b/Sources/SparkConnect/StorageLevel.swift
@@ -78,6 +78,12 @@
     level.replication = self.replication
     return level
   }
+
+  public static func == (lhs: StorageLevel, rhs: StorageLevel) -> Bool {
+    return lhs.useDisk == rhs.useDisk && lhs.useMemory == rhs.useMemory
+      && lhs.useOffHeap == rhs.useOffHeap && lhs.deserialized == rhs.deserialized
+      && lhs.replication == rhs.replication
+  }
 }
 
 extension StorageLevel: CustomStringConvertible {
@@ -86,3 +92,15 @@
     "StorageLevel(useDisk: \(useDisk), useMemory: \(useMemory), useOffHeap: \(useOffHeap), deserialized: \(deserialized), replication: \(replication))"
   }
 }
+
+extension Spark_Connect_StorageLevel {
+  var toStorageLevel: StorageLevel {
+    return StorageLevel(
+      useDisk: self.useDisk,
+      useMemory: self.useMemory,
+      useOffHeap: self.useOffHeap,
+      deserialized: self.deserialized,
+      replication: self.replication
+    )
+  }
+}
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index f9dd37e..7e903bd 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -308,4 +308,22 @@ struct DataFrameTests {
     await spark.stop()
   }
   #endif
+
+  @Test
+  func storageLevel() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(1)
+
+    _ = try await df.unpersist()
+    #expect(try await df.storageLevel == StorageLevel.NONE)
+    _ = try await df.persist()
+    #expect(try await df.storageLevel == StorageLevel.MEMORY_AND_DISK)
+
+    _ = try await df.unpersist()
+    #expect(try await df.storageLevel == StorageLevel.NONE)
+    _ = try await df.persist(storageLevel: StorageLevel.MEMORY_ONLY)
+    #expect(try await df.storageLevel == StorageLevel.MEMORY_ONLY)
+
+    await spark.stop()
+  }
 }