Skip to content

Commit 2000073

Browse files
committed
[SPARK-51560] Support cache/persist/unpersist for DataFrame
1 parent 968b77c commit 2000073

File tree

4 files changed

+103
-0
lines changed

4 files changed

+103
-0
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,4 +245,38 @@ public actor DataFrame: Sendable {
245245
public func isEmpty() async throws -> Bool {
246246
return try await select().limit(1).count() == 0
247247
}
248+
249+
public func cache() async throws -> DataFrame {
250+
return try await persist()
251+
}
252+
253+
public func persist(useDisk: Bool = true, useMemory: Bool = true, useOffHeap: Bool = false, deserialized: Bool = true, replication: Int32 = 1)
254+
async throws -> DataFrame
255+
{
256+
try await withGRPCClient(
257+
transport: .http2NIOPosix(
258+
target: .dns(host: spark.client.host, port: spark.client.port),
259+
transportSecurity: .plaintext
260+
)
261+
) { client in
262+
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
263+
_ = try await service.analyzePlan(spark.client.getPersist(spark.sessionID, plan, useDisk, useMemory, useOffHeap, deserialized, replication))
264+
}
265+
266+
return self
267+
}
268+
269+
public func unpersist(blocking: Bool = false) async throws -> DataFrame {
270+
try await withGRPCClient(
271+
transport: .http2NIOPosix(
272+
target: .dns(host: spark.client.host, port: spark.client.port),
273+
transportSecurity: .plaintext
274+
)
275+
) { client in
276+
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
277+
_ = try await service.analyzePlan(spark.client.getUnpersist(spark.sessionID, plan, blocking))
278+
}
279+
280+
return self
281+
}
248282
}

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,41 @@ public actor SparkConnectClient {
256256
return request
257257
}
258258

259+
func getPersist(
260+
_ sessionID: String, _ plan: Plan, _ useDisk: Bool = true, _ useMemory: Bool = true,
261+
_ useOffHeap: Bool = false, _ deserialized: Bool = true, _ replication: Int32 = 1
262+
) async
263+
-> AnalyzePlanRequest
264+
{
265+
return analyze(
266+
sessionID,
267+
{
268+
var persist = AnalyzePlanRequest.Persist()
269+
var level = StorageLevel()
270+
level.useDisk = useDisk
271+
level.useMemory = useMemory
272+
level.useOffHeap = useOffHeap
273+
level.deserialized = deserialized
274+
level.replication = replication
275+
persist.storageLevel = level
276+
persist.relation = plan.root
277+
return OneOf_Analyze.persist(persist)
278+
})
279+
}
280+
281+
func getUnpersist(_ sessionID: String, _ plan: Plan, _ blocking: Bool = false) async
282+
-> AnalyzePlanRequest
283+
{
284+
return analyze(
285+
sessionID,
286+
{
287+
var unpersist = AnalyzePlanRequest.Unpersist()
288+
unpersist.relation = plan.root
289+
unpersist.blocking = blocking
290+
return OneOf_Analyze.unpersist(unpersist)
291+
})
292+
}
293+
259294
static func getProject(_ child: Relation, _ cols: [String]) -> Plan {
260295
var project = Project()
261296
project.input = child

Sources/SparkConnect/TypeAliases.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@ typealias Range = Spark_Connect_Range
3030
typealias Relation = Spark_Connect_Relation
3131
typealias SparkConnectService = Spark_Connect_SparkConnectService
3232
typealias Sort = Spark_Connect_Sort
33+
typealias StorageLevel = Spark_Connect_StorageLevel
3334
typealias UserContext = Spark_Connect_UserContext
3435
typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,4 +194,37 @@ struct DataFrameTests {
194194
await spark.stop()
195195
}
196196
#endif
197+
198+
@Test
199+
func cache() async throws {
200+
let spark = try await SparkSession.builder.getOrCreate()
201+
#expect(try await spark.range(10).cache().count() == 10)
202+
await spark.stop()
203+
}
204+
205+
@Test
206+
func persist() async throws {
207+
let spark = try await SparkSession.builder.getOrCreate()
208+
#expect(try await spark.range(20).persist().count() == 20)
209+
#expect(try await spark.range(21).persist(useDisk: false).count() == 21)
210+
await spark.stop()
211+
}
212+
213+
@Test
214+
func persistInvalidStorageLevel() async throws {
215+
let spark = try await SparkSession.builder.getOrCreate()
216+
try await #require(throws: Error.self) {
217+
let _ = try await spark.range(9999).persist(replication: 0).count()
218+
}
219+
await spark.stop()
220+
}
221+
222+
@Test
223+
func unpersist() async throws {
224+
let spark = try await SparkSession.builder.getOrCreate()
225+
let df = try await spark.range(30)
226+
#expect(try await df.persist().count() == 30)
227+
#expect(try await df.unpersist().count() == 30)
228+
await spark.stop()
229+
}
197230
}

0 commit comments

Comments
 (0)