Skip to content

Commit c03cbfe

Browse files
committed
[SPARK-51636] Add StorageLevel struct
### What changes were proposed in this pull request? This PR aims to add a public `StorageLevel` struct. ### Why are the changes needed? To provide a similar interface without exposing protobuf-generated data structures. ### Does this PR introduce _any_ user-facing change? Although the APIs are changed, this is a revision to the unreleased version. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #30 from dongjoon-hyun/SPARK-51636. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent f97822d commit c03cbfe

File tree

5 files changed

+99
-21
lines changed

5 files changed

+99
-21
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,10 @@ public actor DataFrame: Sendable {
294294
return try await persist()
295295
}
296296

297-
public func persist(
298-
useDisk: Bool = true, useMemory: Bool = true, useOffHeap: Bool = false,
299-
deserialized: Bool = true, replication: Int32 = 1
300-
)
301-
async throws -> DataFrame
297+
/// Persist this `DataFrame` with the given storage level.
298+
/// - Parameter storageLevel: A storage level to apply.
299+
public func persist(storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK) async throws
300+
-> DataFrame
302301
{
303302
try await withGRPCClient(
304303
transport: .http2NIOPosix(
@@ -308,8 +307,7 @@ public actor DataFrame: Sendable {
308307
) { client in
309308
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
310309
_ = try await service.analyzePlan(
311-
spark.client.getPersist(
312-
spark.sessionID, plan, useDisk, useMemory, useOffHeap, deserialized, replication))
310+
spark.client.getPersist(spark.sessionID, plan, storageLevel))
313311
}
314312

315313
return self

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -256,23 +256,14 @@ public actor SparkConnectClient {
256256
return request
257257
}
258258

259-
func getPersist(
260-
_ sessionID: String, _ plan: Plan, _ useDisk: Bool = true, _ useMemory: Bool = true,
261-
_ useOffHeap: Bool = false, _ deserialized: Bool = true, _ replication: Int32 = 1
262-
) async
259+
func getPersist(_ sessionID: String, _ plan: Plan, _ storageLevel: StorageLevel) async
263260
-> AnalyzePlanRequest
264261
{
265262
return analyze(
266263
sessionID,
267264
{
268265
var persist = AnalyzePlanRequest.Persist()
269-
var level = StorageLevel()
270-
level.useDisk = useDisk
271-
level.useMemory = useMemory
272-
level.useOffHeap = useOffHeap
273-
level.deserialized = deserialized
274-
level.replication = replication
275-
persist.storageLevel = level
266+
persist.storageLevel = storageLevel.toSparkConnectStorageLevel
276267
persist.relation = plan.root
277268
return OneOf_Analyze.persist(persist)
278269
})
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
//
2+
// Licensed to the Apache Software Foundation (ASF) under one
3+
// or more contributor license agreements. See the NOTICE file
4+
// distributed with this work for additional information
5+
// regarding copyright ownership. The ASF licenses this file
6+
// to you under the Apache License, Version 2.0 (the
7+
// "License"); you may not use this file except in compliance
8+
// with the License. You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing,
13+
// software distributed under the License is distributed on an
14+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, either express or implied. See the License for the
16+
// specific language governing permissions and limitations
17+
// under the License.
18+
//
19+
20+
/// Flags for controlling the storage of an `RDD`. Each ``StorageLevel`` records whether to use memory,
21+
/// or `ExternalBlockStore`, whether to drop the `RDD` to disk if it falls out of memory or
22+
/// `ExternalBlockStore`, whether to keep the data in memory in a serialized format, and whether
23+
/// to replicate the `RDD` partitions on multiple nodes.
24+
public struct StorageLevel: Sendable {
25+
/// Whether the cache should use disk or not.
26+
var useDisk: Bool
27+
28+
/// Whether the cache should use memory or not.
29+
var useMemory: Bool
30+
31+
/// Whether the cache should use off-heap or not.
32+
var useOffHeap: Bool
33+
34+
/// Whether the cached data is deserialized or not.
35+
var deserialized: Bool
36+
37+
/// The number of replicas.
38+
var replication: Int32
39+
40+
init(useDisk: Bool, useMemory: Bool, useOffHeap: Bool, deserialized: Bool, replication: Int32 = 1)
41+
{
42+
self.useDisk = useDisk
43+
self.useMemory = useMemory
44+
self.useOffHeap = useOffHeap
45+
self.deserialized = deserialized
46+
self.replication = replication
47+
}
48+
49+
public static let NONE = StorageLevel(
50+
useDisk: false, useMemory: false, useOffHeap: false, deserialized: false)
51+
public static let DISK_ONLY = StorageLevel(
52+
useDisk: true, useMemory: false, useOffHeap: false, deserialized: false)
53+
public static let DISK_ONLY_2 = StorageLevel(
54+
useDisk: true, useMemory: false, useOffHeap: false, deserialized: false, replication: 2)
55+
public static let DISK_ONLY_3 = StorageLevel(
56+
useDisk: true, useMemory: false, useOffHeap: false, deserialized: false, replication: 3)
57+
public static let MEMORY_ONLY = StorageLevel(
58+
useDisk: false, useMemory: true, useOffHeap: false, deserialized: false)
59+
public static let MEMORY_ONLY_2 = StorageLevel(
60+
useDisk: false, useMemory: true, useOffHeap: false, deserialized: false, replication: 2)
61+
public static let MEMORY_AND_DISK = StorageLevel(
62+
useDisk: true, useMemory: true, useOffHeap: false, deserialized: false)
63+
public static let MEMORY_AND_DISK_2 = StorageLevel(
64+
useDisk: true, useMemory: true, useOffHeap: false, deserialized: false, replication: 2)
65+
public static let OFF_HEAP = StorageLevel(
66+
useDisk: true, useMemory: true, useOffHeap: true, deserialized: false)
67+
public static let MEMORY_AND_DISK_DESER = StorageLevel(
68+
useDisk: true, useMemory: true, useOffHeap: false, deserialized: true)
69+
}
70+
71+
extension StorageLevel {
72+
var toSparkConnectStorageLevel: Spark_Connect_StorageLevel {
73+
var level = Spark_Connect_StorageLevel()
74+
level.useDisk = self.useDisk
75+
level.useMemory = self.useMemory
76+
level.useOffHeap = self.useOffHeap
77+
level.deserialized = self.deserialized
78+
level.replication = self.replication
79+
return level
80+
}
81+
}
82+
83+
extension StorageLevel: CustomStringConvertible {
84+
public var description: String {
85+
return
86+
"StorageLevel(useDisk: \(useDisk), useMemory: \(useMemory), useOffHeap: \(useOffHeap), deserialized: \(deserialized), replication: \(replication))"
87+
}
88+
}

Sources/SparkConnect/TypeAliases.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,5 @@ typealias Read = Spark_Connect_Read
3434
typealias Relation = Spark_Connect_Relation
3535
typealias SparkConnectService = Spark_Connect_SparkConnectService
3636
typealias Sort = Spark_Connect_Sort
37-
typealias StorageLevel = Spark_Connect_StorageLevel
3837
typealias UserContext = Spark_Connect_UserContext
3938
typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,15 +236,17 @@ struct DataFrameTests {
236236
func persist() async throws {
237237
let spark = try await SparkSession.builder.getOrCreate()
238238
#expect(try await spark.range(20).persist().count() == 20)
239-
#expect(try await spark.range(21).persist(useDisk: false).count() == 21)
239+
#expect(try await spark.range(21).persist(storageLevel: StorageLevel.MEMORY_ONLY).count() == 21)
240240
await spark.stop()
241241
}
242242

243243
@Test
244244
func persistInvalidStorageLevel() async throws {
245245
let spark = try await SparkSession.builder.getOrCreate()
246246
try await #require(throws: Error.self) {
247-
let _ = try await spark.range(9999).persist(replication: 0).count()
247+
var invalidLevel = StorageLevel.DISK_ONLY
248+
invalidLevel.replication = 0
249+
let _ = try await spark.range(9999).persist(storageLevel: invalidLevel).count()
248250
}
249251
await spark.stop()
250252
}

0 commit comments

Comments
 (0)