From e2ad6b3edfcc4a750ce5234c6ed20ac0da48c6f1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 15 May 2025 16:09:37 -0700 Subject: [PATCH 1/2] [SPARK-52172] Add `checkpoint` and `localCheckpoint` for `DataFrame` --- Sources/SparkConnect/DataFrame.swift | 37 +++++++++++++++++++ Sources/SparkConnect/SparkConnectClient.swift | 26 +++++++++++++ Tests/SparkConnectTests/DataFrameTests.swift | 20 ++++++++++ 3 files changed, 83 insertions(+) diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index e7edfca..75f555f 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -151,6 +151,8 @@ import Synchronization /// /// ### Persistence /// - ``cache()`` +/// - ``checkpoint(_:_:_:)`` +/// - ``localCheckpoint(_:_:)`` /// - ``persist(storageLevel:)`` /// - ``unpersist(blocking:)`` /// - ``storageLevel`` @@ -1407,6 +1409,41 @@ public actor DataFrame: Sendable { try await spark.client.createTempView(self.plan.root, viewName, replace: replace, isGlobal: global) } + /// Eagerly checkpoint a ``DataFrame`` and return the new ``DataFrame``. + /// Checkpointing can be used to truncate the logical plan of this ``DataFrame``, + /// which is especially useful in iterative algorithms where the plan may grow exponentially. + /// It will be saved to files inside the checkpoint directory. + /// - Parameters: + /// - eager: Whether to checkpoint this dataframe immediately + /// - reliableCheckpoint: Whether to create a reliable checkpoint saved to files inside the checkpoint directory. + /// If false creates a local checkpoint using the caching subsystem + /// - storageLevel: StorageLevel with which to checkpoint the data. + /// - Returns: A ``DataFrame``. + public func checkpoint( + _ eager: Bool = true, + _ reliableCheckpoint: Bool = true, + _ storageLevel: StorageLevel? = nil + ) async throws -> DataFrame { + let plan = try await spark.client.getCheckpoint(self.plan.root, eager, reliableCheckpoint, storageLevel) + return DataFrame(spark: self.spark, plan: plan) + } + + /// Locally checkpoints a ``DataFrame`` and return the new ``DataFrame``. + /// Checkpointing can be used to truncate the logical plan of this ``DataFrame``, + /// which is especially useful in iterative algorithms where the plan may grow exponentially. + /// Local checkpoints are written to executor storage and despite potentially faster they + /// are unreliable and may compromise job completion. + /// - Parameters: + /// - eager: Whether to checkpoint this dataframe immediately + /// - storageLevel: StorageLevel with which to checkpoint the data. + /// - Returns: A ``DataFrame``. + public func localCheckpoint( + _ eager: Bool = true, + _ storageLevel: StorageLevel? = nil + ) async throws -> DataFrame { + try await checkpoint(eager, false, storageLevel) + } + /// Returns a ``DataFrameWriter`` that can be used to write non-streaming data. public var write: DataFrameWriter { get { diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index 16c4f62..c69888f 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -996,6 +996,32 @@ public actor SparkConnectClient { return plan } + func getCheckpoint( + _ child: Relation, + _ eager: Bool, + _ reliableCheckpoint: Bool, + _ storageLevel: StorageLevel? + ) async throws -> Plan { + var checkpointCommand = Spark_Connect_CheckpointCommand() + checkpointCommand.eager = eager + checkpointCommand.local = !reliableCheckpoint + checkpointCommand.relation = child + if let storageLevel { + checkpointCommand.storageLevel = storageLevel.toSparkConnectStorageLevel + } + + var command = Spark_Connect_Command() + command.checkpointCommand = checkpointCommand + let response = try await execute(self.sessionID!, command) + let cachedRemoteRelation = response.first!.checkpointCommandResult.relation + + var relation = Relation() + relation.cachedRemoteRelation = cachedRemoteRelation + var plan = Plan() + plan.opType = .root(relation) + return plan + } + func createTempView( _ child: Relation, _ viewName: String, replace: Bool, isGlobal: Bool ) async throws { diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index aefc733..32b0f17 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -481,6 +481,26 @@ struct DataFrameTests { await spark.stop() } + @Test + func checkpoint() async throws { + let spark = try await SparkSession.builder.getOrCreate() + // By default, reliable checkpoint location is required. + try await #require(throws: Error.self) { + try await spark.range(10).checkpoint() + } + // Checkpointing with unreliable checkpoint + let df = try await spark.range(10).checkpoint(true, false) + #expect(try await df.count() == 10) + await spark.stop() + } + + @Test + func localCheckpoint() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(10).localCheckpoint().count() == 10) + await spark.stop() + } + @Test func persist() async throws { let spark = try await SparkSession.builder.getOrCreate() From 1024edb67781b8991c2d17605eac51ec76e559f3 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 15 May 2025 17:26:03 -0700 Subject: [PATCH 2/2] skip for Spark 3 --- Tests/SparkConnectTests/DataFrameTests.swift | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index 32b0f17..fd15496 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -484,20 +484,24 @@ struct DataFrameTests { @Test func checkpoint() async throws { let spark = try await SparkSession.builder.getOrCreate() - // By default, reliable checkpoint location is required. - try await #require(throws: Error.self) { - try await spark.range(10).checkpoint() + if await spark.version >= "4.0.0" { + // By default, reliable checkpoint location is required. + try await #require(throws: Error.self) { + try await spark.range(10).checkpoint() + } + // Checkpointing with unreliable checkpoint + let df = try await spark.range(10).checkpoint(true, false) + #expect(try await df.count() == 10) } - // Checkpointing with unreliable checkpoint - let df = try await spark.range(10).checkpoint(true, false) - #expect(try await df.count() == 10) await spark.stop() } @Test func localCheckpoint() async throws { let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(10).localCheckpoint().count() == 10) + if await spark.version >= "4.0.0" { + #expect(try await spark.range(10).localCheckpoint().count() == 10) + } await spark.stop() }