Commit aa61fff

[SPARK-52172] Add checkpoint and localCheckpoint for DataFrame

### What changes were proposed in this pull request?

This PR aims to add `checkpoint` and `localCheckpoint` APIs to `DataFrame`.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Pass the CIs. I manually checked the UI.

### Was this patch authored or co-authored using generative AI tooling?

Pass the CIs.

Closes #157 from dongjoon-hyun/SPARK-52172.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent c386e26 commit aa61fff

File tree

3 files changed, +87 -0 lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 37 additions & 0 deletions

```diff
@@ -151,6 +151,8 @@ import Synchronization
 ///
 /// ### Persistence
 /// - ``cache()``
+/// - ``checkpoint(_:_:_:)``
+/// - ``localCheckpoint(_:_:)``
 /// - ``persist(storageLevel:)``
 /// - ``unpersist(blocking:)``
 /// - ``storageLevel``
@@ -1407,6 +1409,41 @@ public actor DataFrame: Sendable {
     try await spark.client.createTempView(self.plan.root, viewName, replace: replace, isGlobal: global)
   }
 
+  /// Eagerly checkpoint a ``DataFrame`` and return the new ``DataFrame``.
+  /// Checkpointing can be used to truncate the logical plan of this ``DataFrame``,
+  /// which is especially useful in iterative algorithms where the plan may grow exponentially.
+  /// The data will be saved to files inside the checkpoint directory.
+  /// - Parameters:
+  ///   - eager: Whether to checkpoint this ``DataFrame`` immediately.
+  ///   - reliableCheckpoint: Whether to create a reliable checkpoint saved to files inside the checkpoint directory.
+  ///     If false, creates a local checkpoint using the caching subsystem.
+  ///   - storageLevel: StorageLevel with which to checkpoint the data.
+  /// - Returns: A ``DataFrame``.
+  public func checkpoint(
+    _ eager: Bool = true,
+    _ reliableCheckpoint: Bool = true,
+    _ storageLevel: StorageLevel? = nil
+  ) async throws -> DataFrame {
+    let plan = try await spark.client.getCheckpoint(self.plan.root, eager, reliableCheckpoint, storageLevel)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Locally checkpoint a ``DataFrame`` and return the new ``DataFrame``.
+  /// Checkpointing can be used to truncate the logical plan of this ``DataFrame``,
+  /// which is especially useful in iterative algorithms where the plan may grow exponentially.
+  /// Local checkpoints are written to executor storage and, while potentially faster,
+  /// they are unreliable and may compromise job completion.
+  /// - Parameters:
+  ///   - eager: Whether to checkpoint this ``DataFrame`` immediately.
+  ///   - storageLevel: StorageLevel with which to checkpoint the data.
+  /// - Returns: A ``DataFrame``.
+  public func localCheckpoint(
+    _ eager: Bool = true,
+    _ storageLevel: StorageLevel? = nil
+  ) async throws -> DataFrame {
+    try await checkpoint(eager, false, storageLevel)
+  }
+
   /// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
   public var write: DataFrameWriter {
     get {
```
Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 26 additions & 0 deletions

```diff
@@ -996,6 +996,32 @@ public actor SparkConnectClient {
     return plan
   }
 
+  func getCheckpoint(
+    _ child: Relation,
+    _ eager: Bool,
+    _ reliableCheckpoint: Bool,
+    _ storageLevel: StorageLevel?
+  ) async throws -> Plan {
+    var checkpointCommand = Spark_Connect_CheckpointCommand()
+    checkpointCommand.eager = eager
+    checkpointCommand.local = !reliableCheckpoint
+    checkpointCommand.relation = child
+    if let storageLevel {
+      checkpointCommand.storageLevel = storageLevel.toSparkConnectStorageLevel
+    }
+
+    var command = Spark_Connect_Command()
+    command.checkpointCommand = checkpointCommand
+    let response = try await execute(self.sessionID!, command)
+    let cachedRemoteRelation = response.first!.checkpointCommandResult.relation
+
+    var relation = Relation()
+    relation.cachedRemoteRelation = cachedRemoteRelation
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return plan
+  }
+
   func createTempView(
     _ child: Relation, _ viewName: String, replace: Bool, isGlobal: Bool
   ) async throws {
```

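One note on the plumbing above, with a small illustrative sketch (values inferred from `getCheckpoint`, not part of the commit): the public `reliableCheckpoint` flag is inverted into the protobuf `local` field, the storage level is only attached when the caller supplies one, and the returned `Plan` is rooted at the server's cached remote relation, which is what truncates the logical plan.

```swift
// Illustrative only: roughly the command getCheckpoint builds for an eager,
// local (i.e. non-reliable) checkpoint with an explicit storage level,
// as triggered by DataFrame.localCheckpoint(true, StorageLevel.MEMORY_ONLY).
var checkpointCommand = Spark_Connect_CheckpointCommand()
checkpointCommand.eager = true
checkpointCommand.local = true  // local == !reliableCheckpoint
checkpointCommand.storageLevel = StorageLevel.MEMORY_ONLY.toSparkConnectStorageLevel
// checkpointCommand.relation is set to the DataFrame's current plan root before execution.
```
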
Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 24 additions & 0 deletions

```diff
@@ -481,6 +481,30 @@ struct DataFrameTests {
     await spark.stop()
   }
 
+  @Test
+  func checkpoint() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    if await spark.version >= "4.0.0" {
+      // By default, a reliable checkpoint location is required.
+      try await #require(throws: Error.self) {
+        try await spark.range(10).checkpoint()
+      }
+      // Checkpointing with an unreliable (local) checkpoint
+      let df = try await spark.range(10).checkpoint(true, false)
+      #expect(try await df.count() == 10)
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func localCheckpoint() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    if await spark.version >= "4.0.0" {
+      #expect(try await spark.range(10).localCheckpoint().count() == 10)
+    }
+    await spark.stop()
+  }
+
   @Test
   func persist() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
```

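A possible follow-up test (hypothetical, not part of this commit) could cover the explicit storage-level path of `localCheckpoint`; it assumes a `StorageLevel.DISK_ONLY` constant is available in the package and mirrors the style of the tests above.

```swift
import Testing
import SparkConnect

// Hypothetical addition to DataFrameTests: exercise localCheckpoint
// with an explicit storage level.
@Test
func localCheckpointWithStorageLevel() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  if await spark.version >= "4.0.0" {
    let df = try await spark.range(10).localCheckpoint(true, StorageLevel.DISK_ONLY)
    #expect(try await df.count() == 10)
  }
  await spark.stop()
}
```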