Skip to content

Commit 209e93e

Browse files
committed
[SPARK-51560] Support cache/persist/unpersist for DataFrame
### What changes were proposed in this pull request? This PR aims to support `cache`, `persist`, and `unpersist` for `DataFrame`. ### Why are the changes needed? For feature parity. ### Does this PR introduce _any_ user-facing change? No. This is a new addition. ### How was this patch tested? Pass the CIs. ``` $ swift test --filter DataFrameTests ... 􀟈 Test run started. 􀄵 Testing Library Version: 102 (arm64e-apple-macos13.0) 􀟈 Suite DataFrameTests started. 􀟈 Test orderBy() started. 􀟈 Test isEmpty() started. 􀟈 Test show() started. 􀟈 Test persist() started. 􀟈 Test showCommand() started. 􀟈 Test table() started. 􀟈 Test selectMultipleColumns() started. 􀟈 Test showNull() started. 􀟈 Test schema() started. 􀟈 Test selectNone() started. 􀟈 Test rdd() started. 􀟈 Test sort() started. 􀟈 Test unpersist() started. 􀟈 Test limit() started. 􀟈 Test count() started. 􀟈 Test cache() started. 􀟈 Test selectInvalidColumn() started. 􀟈 Test collect() started. 􀟈 Test countNull() started. 􀟈 Test select() started. 􀟈 Test persistInvalidStorageLevel() started. 􁁛 Test rdd() passed after 0.571 seconds. 􁁛 Test selectNone() passed after 1.347 seconds. 􁁛 Test select() passed after 1.354 seconds. 􁁛 Test selectMultipleColumns() passed after 1.354 seconds. 􁁛 Test selectInvalidColumn() passed after 1.395 seconds. 􁁛 Test schema() passed after 1.747 seconds. ++ || ++ ++ 􁁛 Test showCommand() passed after 1.885 seconds. +-----------+-----------+-------------+ | namespace | tableName | isTemporary | +-----------+-----------+-------------+ +-----------+-----------+-------------+ +------+-------+------+ | col1 | col2 | col3 | +------+-------+------+ | 1 | true | abc | | NULL | NULL | NULL | | 3 | false | def | +------+-------+------+ 􁁛 Test showNull() passed after 1.890 seconds. 
+------+-------+ | col1 | col2 | +------+-------+ | true | false | +------+-------+ +------+------+ | col1 | col2 | +------+------+ | 1 | 2 | +------+------+ +------+------+ | col1 | col2 | +------+------+ | abc | def | | ghi | jkl | +------+------+ 􁁛 Test show() passed after 1.975 seconds. 􁁛 Test collect() passed after 2.045 seconds. 􁁛 Test countNull() passed after 2.566 seconds. 􁁛 Test persistInvalidStorageLevel() passed after 2.578 seconds. 􁁛 Test cache() passed after 2.683 seconds. 􁁛 Test isEmpty() passed after 2.778 seconds. 􁁛 Test count() passed after 2.892 seconds. 􁁛 Test persist() passed after 2.903 seconds. 􁁛 Test unpersist() passed after 2.917 seconds. 􁁛 Test limit() passed after 3.068 seconds. 􁁛 Test orderBy() passed after 3.101 seconds. 􁁛 Test sort() passed after 3.102 seconds. 􁁛 Test table() passed after 3.720 seconds. 􁁛 Suite DataFrameTests passed after 3.720 seconds. 􁁛 Test run with 21 tests passed after 3.720 seconds. ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #22 from dongjoon-hyun/SPARK-51560. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 968b77c commit 209e93e

File tree

4 files changed

+108
-0
lines changed

4 files changed

+108
-0
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,4 +245,43 @@ public actor DataFrame: Sendable {
245245
/// Reports whether this `DataFrame` has no rows.
///
/// Counts the rows of `select().limit(1)` and compares that count with zero.
/// - Returns: `true` when the count is zero, otherwise `false`.
/// - Throws: Any error raised while building or counting the limited selection.
public func isEmpty() async throws -> Bool {
  let sampledRowCount = try await select().limit(1).count()
  return sampledRowCount == 0
}
248+
249+
/// Persists this `DataFrame` using the default arguments of `persist()`.
///
/// This is a convenience wrapper that forwards to `persist()` without arguments.
/// - Returns: The `DataFrame` returned by `persist()`, so calls can be chained.
///   The result may be ignored when only the caching side effect is wanted.
/// - Throws: Any error thrown by `persist()`.
@discardableResult
public func cache() async throws -> DataFrame {
  return try await persist()
}
252+
253+
/// Sends a persist request for this `DataFrame`'s plan to the Spark Connect server.
///
/// The request is built by `SparkConnectClient.getPersist` and sent through `analyzePlan`
/// over a plaintext HTTP/2 gRPC transport targeting the session client's host and port.
/// The server's response is discarded.
/// - Parameters:
///   - useDisk: Value forwarded to the storage level's `useDisk` field.
///   - useMemory: Value forwarded to the storage level's `useMemory` field.
///   - useOffHeap: Value forwarded to the storage level's `useOffHeap` field.
///   - deserialized: Value forwarded to the storage level's `deserialized` field.
///   - replication: Value forwarded to the storage level's `replication` field.
/// - Returns: This same `DataFrame`, so calls can be chained. The result may be ignored
///   when only the side effect is wanted.
/// - Throws: Any error raised by `withGRPCClient` or by the `analyzePlan` call.
@discardableResult
public func persist(
  useDisk: Bool = true, useMemory: Bool = true, useOffHeap: Bool = false,
  deserialized: Bool = true, replication: Int32 = 1
) async throws -> DataFrame {
  try await withGRPCClient(
    transport: .http2NIOPosix(
      target: .dns(host: spark.client.host, port: spark.client.port),
      transportSecurity: .plaintext
    )
  ) { client in
    let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
    // Only the request's effect matters here; the analyze response is not used.
    _ = try await service.analyzePlan(
      spark.client.getPersist(
        spark.sessionID, plan, useDisk, useMemory, useOffHeap, deserialized, replication))
  }

  return self
}
273+
274+
/// Sends an unpersist request for this `DataFrame`'s plan to the Spark Connect server.
///
/// The request is built by `SparkConnectClient.getUnpersist` and sent through `analyzePlan`
/// over a plaintext HTTP/2 gRPC transport targeting the session client's host and port.
/// The server's response is discarded.
/// - Parameter blocking: Value forwarded to the request's `blocking` field.
/// - Returns: This same `DataFrame`, so calls can be chained. The result may be ignored
///   when only the side effect is wanted.
/// - Throws: Any error raised by `withGRPCClient` or by the `analyzePlan` call.
@discardableResult
public func unpersist(blocking: Bool = false) async throws -> DataFrame {
  try await withGRPCClient(
    transport: .http2NIOPosix(
      target: .dns(host: spark.client.host, port: spark.client.port),
      transportSecurity: .plaintext
    )
  ) { client in
    let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
    // Only the request's effect matters here; the analyze response is not used.
    _ = try await service.analyzePlan(spark.client.getUnpersist(spark.sessionID, plan, blocking))
  }

  return self
}
248287
}

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,41 @@ public actor SparkConnectClient {
256256
return request
257257
}
258258

259+
/// Builds an `AnalyzePlanRequest` that carries a persist analysis for `plan`.
///
/// The persist payload holds the plan's root relation and a `StorageLevel` filled
/// from the flag arguments; it is handed to `analyze` together with `sessionID`.
/// - Parameters:
///   - sessionID: Session identifier passed through to `analyze`.
///   - plan: Plan whose `root` becomes the payload's relation.
///   - useDisk: Copied into the storage level's `useDisk` field.
///   - useMemory: Copied into the storage level's `useMemory` field.
///   - useOffHeap: Copied into the storage level's `useOffHeap` field.
///   - deserialized: Copied into the storage level's `deserialized` field.
///   - replication: Copied into the storage level's `replication` field.
/// - Returns: The request produced by `analyze`.
func getPersist(
  _ sessionID: String, _ plan: Plan, _ useDisk: Bool = true, _ useMemory: Bool = true,
  _ useOffHeap: Bool = false, _ deserialized: Bool = true, _ replication: Int32 = 1
) async -> AnalyzePlanRequest {
  return analyze(
    sessionID,
    {
      var storage = StorageLevel()
      storage.replication = replication
      storage.deserialized = deserialized
      storage.useOffHeap = useOffHeap
      storage.useMemory = useMemory
      storage.useDisk = useDisk

      var payload = AnalyzePlanRequest.Persist()
      payload.relation = plan.root
      payload.storageLevel = storage
      return OneOf_Analyze.persist(payload)
    })
}
280+
281+
/// Builds an `AnalyzePlanRequest` that carries an unpersist analysis for `plan`.
///
/// - Parameters:
///   - sessionID: Session identifier passed through to `analyze`.
///   - plan: Plan whose `root` becomes the payload's relation.
///   - blocking: Copied into the payload's `blocking` field.
/// - Returns: The request produced by `analyze`.
func getUnpersist(_ sessionID: String, _ plan: Plan, _ blocking: Bool = false) async
  -> AnalyzePlanRequest
{
  return analyze(
    sessionID,
    {
      var payload = AnalyzePlanRequest.Unpersist()
      payload.blocking = blocking
      payload.relation = plan.root
      return OneOf_Analyze.unpersist(payload)
    })
}
293+
259294
static func getProject(_ child: Relation, _ cols: [String]) -> Plan {
260295
var project = Project()
261296
project.input = child

Sources/SparkConnect/TypeAliases.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@ typealias Range = Spark_Connect_Range
3030
typealias Relation = Spark_Connect_Relation
3131
typealias SparkConnectService = Spark_Connect_SparkConnectService
3232
typealias Sort = Spark_Connect_Sort
33+
typealias StorageLevel = Spark_Connect_StorageLevel
3334
typealias UserContext = Spark_Connect_UserContext
3435
typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,5 +193,38 @@ struct DataFrameTests {
193193
try await spark.sql("DROP TABLE IF EXISTS t").show()
194194
await spark.stop()
195195
}
196+
197+
/// Checks that a cached 10-row range still reports a count of 10.
@Test
func cache() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let cached = try await spark.range(10).cache()
  #expect(try await cached.count() == 10)
  await spark.stop()
}
203+
204+
/// Checks that persisted ranges keep their row counts, with default arguments and
/// with `useDisk` turned off.
@Test
func persist() async throws {
  let spark = try await SparkSession.builder.getOrCreate()

  let withDefaults = try await spark.range(20).persist()
  #expect(try await withDefaults.count() == 20)

  let withoutDisk = try await spark.range(21).persist(useDisk: false)
  #expect(try await withoutDisk.count() == 21)

  await spark.stop()
}
211+
212+
/// Checks that persisting with `replication: 0` and then counting throws an error.
@Test
func persistInvalidStorageLevel() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  try await #require(throws: Error.self) {
    let persisted = try await spark.range(9999).persist(replication: 0)
    _ = try await persisted.count()
  }
  await spark.stop()
}
220+
221+
/// Checks that a 30-row range reports a count of 30 both after `persist()` and
/// after a following `unpersist()`.
@Test
func unpersist() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let df = try await spark.range(30)

  let persisted = try await df.persist()
  #expect(try await persisted.count() == 30)

  let released = try await df.unpersist()
  #expect(try await released.count() == 30)

  await spark.stop()
}
196229
#endif
197230
}

0 commit comments

Comments
 (0)