
Commit c1817d7

[SPARK-51693] Support storageLevel for DataFrame
### What changes were proposed in this pull request?

This PR aims to support `DataFrame.storageLevel`.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No. This is a new addition to the unreleased version.

### How was this patch tested?

Pass the CIs.

```
$ swift test --filter DataFrameTests.storageLevel
􀟈 Suite DataFrameTests started.
􀟈 Test storageLevel() started.
􁁛 Test storageLevel() passed after 0.075 seconds.
􁁛 Suite DataFrameTests passed after 0.075 seconds.
􁁛 Test run with 1 test passed after 0.075 seconds.
```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #38 from dongjoon-hyun/SPARK-51693.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 33b4b08 commit c1817d7
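
The sketch below condenses the `DataFrameTests.storageLevel` test added by this commit into application-style code. It is illustrative only, not part of the commit: `demoStorageLevel` is a hypothetical helper name, a reachable Spark Connect server is assumed, and because `storageLevel` is declared without an access modifier in this diff, same-module (or `@testable`) access is assumed as well.

```swift
import SparkConnect

// Illustrative sketch of DataFrame.storageLevel, condensed from the new
// DataFrameTests.storageLevel test. Assumes a reachable Spark Connect server
// and same-module (or @testable) access, since the property is internal here.
func demoStorageLevel() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let df = try await spark.range(1)

  _ = try await df.unpersist()
  // Each read of `storageLevel` issues an AnalyzePlan (GetStorageLevel) RPC.
  let unpersisted = try await df.storageLevel
  print(unpersisted == StorageLevel.NONE)   // true, per the test in this commit

  _ = try await df.persist(storageLevel: StorageLevel.MEMORY_ONLY)
  print(try await df.storageLevel)          // prints the StorageLevel description

  await spark.stop()
}
```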

File tree

4 files changed, +62 -0 lines changed


Sources/SparkConnect/DataFrame.swift

Lines changed: 15 additions & 0 deletions
```diff
@@ -342,6 +342,21 @@ public actor DataFrame: Sendable {
     return self
   }
 
+  var storageLevel: StorageLevel {
+    get async throws {
+      try await withGRPCClient(
+        transport: .http2NIOPosix(
+          target: .dns(host: spark.client.host, port: spark.client.port),
+          transportSecurity: .plaintext
+        )
+      ) { client in
+        let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
+        return try await service
+          .analyzePlan(spark.client.getStorageLevel(spark.sessionID, plan)).getStorageLevel.storageLevel.toStorageLevel
+      }
+    }
+  }
+
   public func explain() async throws {
     try await explain("simple")
   }
```

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 11 additions & 0 deletions
```diff
@@ -282,6 +282,17 @@ public actor SparkConnectClient {
     })
   }
 
+  func getStorageLevel(_ sessionID: String, _ plan: Plan) async -> AnalyzePlanRequest
+  {
+    return analyze(
+      sessionID,
+      {
+        var level = AnalyzePlanRequest.GetStorageLevel()
+        level.relation = plan.root
+        return OneOf_Analyze.getStorageLevel(level)
+      })
+  }
+
   func getExplain(_ sessionID: String, _ plan: Plan, _ mode: String) async -> AnalyzePlanRequest
   {
     return analyze(
```

Sources/SparkConnect/StorageLevel.swift

Lines changed: 18 additions & 0 deletions
```diff
@@ -78,6 +78,12 @@ extension StorageLevel {
     level.replication = self.replication
     return level
   }
+
+  public static func == (lhs: StorageLevel, rhs: StorageLevel) -> Bool {
+    return lhs.useDisk == rhs.useDisk && lhs.useMemory == rhs.useMemory
+      && lhs.useOffHeap == rhs.useOffHeap && lhs.deserialized == rhs.deserialized
+      && lhs.replication == rhs.replication
+  }
 }
 
 extension StorageLevel: CustomStringConvertible {
@@ -86,3 +92,15 @@ extension StorageLevel: CustomStringConvertible {
     "StorageLevel(useDisk: \(useDisk), useMemory: \(useMemory), useOffHeap: \(useOffHeap), deserialized: \(deserialized), replication: \(replication))"
   }
 }
+
+extension Spark_Connect_StorageLevel {
+  var toStorageLevel: StorageLevel {
+    return StorageLevel(
+      useDisk: self.useDisk,
+      useMemory: self.useMemory,
+      useOffHeap: self.useOffHeap,
+      deserialized: self.deserialized,
+      replication: self.replication
+    )
+  }
+}
```
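
The field-wise `==` and the `toStorageLevel` bridge from the generated `Spark_Connect_StorageLevel` protobuf message are what let the server-reported level be compared against the predefined constants in the tests. A minimal illustrative sketch follows; it is not part of the commit and assumes same-module access plus the predefined constants (e.g. `StorageLevel.MEMORY_ONLY`) defined elsewhere in `StorageLevel.swift`.

```swift
// Illustrative sketch of the new helpers in this file; assumes same-module
// access and the predefined StorageLevel constants from StorageLevel.swift.
var proto = Spark_Connect_StorageLevel()   // generated SwiftProtobuf message
proto.useMemory = true
proto.replication = 1

let level = proto.toStorageLevel           // protobuf message -> StorageLevel
print(level)                               // CustomStringConvertible description

// Field-wise comparison via the new `==` operator:
print(level == StorageLevel.MEMORY_ONLY)
```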

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 18 additions & 0 deletions
```diff
@@ -308,4 +308,22 @@ struct DataFrameTests {
     await spark.stop()
   }
   #endif
+
+  @Test
+  func storageLevel() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(1)
+
+    _ = try await df.unpersist()
+    #expect(try await df.storageLevel == StorageLevel.NONE)
+    _ = try await df.persist()
+    #expect(try await df.storageLevel == StorageLevel.MEMORY_AND_DISK)
+
+    _ = try await df.unpersist()
+    #expect(try await df.storageLevel == StorageLevel.NONE)
+    _ = try await df.persist(storageLevel: StorageLevel.MEMORY_ONLY)
+    #expect(try await df.storageLevel == StorageLevel.MEMORY_ONLY)
+
+    await spark.stop()
+  }
 }
```
