diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index e7edfca..75f555f 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -151,6 +151,8 @@ import Synchronization /// /// ### Persistence /// - ``cache()`` +/// - ``checkpoint(_:_:_:)`` +/// - ``localCheckpoint(_:_:)`` /// - ``persist(storageLevel:)`` /// - ``unpersist(blocking:)`` /// - ``storageLevel`` @@ -1407,6 +1409,41 @@ public actor DataFrame: Sendable { try await spark.client.createTempView(self.plan.root, viewName, replace: replace, isGlobal: global) } + /// Eagerly checkpoint a ``DataFrame`` and return the new ``DataFrame``. + /// Checkpointing can be used to truncate the logical plan of this ``DataFrame``, + /// which is especially useful in iterative algorithms where the plan may grow exponentially. + /// It will be saved to files inside the checkpoint directory. + /// - Parameters: + /// - eager: Whether to checkpoint this dataframe immediately + /// - reliableCheckpoint: Whether to create a reliable checkpoint saved to files inside the checkpoint directory. + /// If false creates a local checkpoint using the caching subsystem + /// - storageLevel: StorageLevel with which to checkpoint the data. + /// - Returns: A ``DataFrame``. + public func checkpoint( + _ eager: Bool = true, + _ reliableCheckpoint: Bool = true, + _ storageLevel: StorageLevel? = nil + ) async throws -> DataFrame { + let plan = try await spark.client.getCheckpoint(self.plan.root, eager, reliableCheckpoint, storageLevel) + return DataFrame(spark: self.spark, plan: plan) + } + + /// Locally checkpoints a ``DataFrame`` and return the new ``DataFrame``. + /// Checkpointing can be used to truncate the logical plan of this ``DataFrame``, + /// which is especially useful in iterative algorithms where the plan may grow exponentially. + /// Local checkpoints are written to executor storage and despite potentially faster they + /// are unreliable and may compromise job completion. + /// - Parameters: + /// - eager: Whether to checkpoint this dataframe immediately + /// - storageLevel: StorageLevel with which to checkpoint the data. + /// - Returns: A ``DataFrame``. + public func localCheckpoint( + _ eager: Bool = true, + _ storageLevel: StorageLevel? = nil + ) async throws -> DataFrame { + try await checkpoint(eager, false, storageLevel) + } + /// Returns a ``DataFrameWriter`` that can be used to write non-streaming data. public var write: DataFrameWriter { get { diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index 16c4f62..c69888f 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -996,6 +996,32 @@ public actor SparkConnectClient { return plan } + func getCheckpoint( + _ child: Relation, + _ eager: Bool, + _ reliableCheckpoint: Bool, + _ storageLevel: StorageLevel? + ) async throws -> Plan { + var checkpointCommand = Spark_Connect_CheckpointCommand() + checkpointCommand.eager = eager + checkpointCommand.local = !reliableCheckpoint + checkpointCommand.relation = child + if let storageLevel { + checkpointCommand.storageLevel = storageLevel.toSparkConnectStorageLevel + } + + var command = Spark_Connect_Command() + command.checkpointCommand = checkpointCommand + let response = try await execute(self.sessionID!, command) + let cachedRemoteRelation = response.first!.checkpointCommandResult.relation + + var relation = Relation() + relation.cachedRemoteRelation = cachedRemoteRelation + var plan = Plan() + plan.opType = .root(relation) + return plan + } + func createTempView( _ child: Relation, _ viewName: String, replace: Bool, isGlobal: Bool ) async throws { diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index aefc733..fd15496 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -481,6 +481,30 @@ struct DataFrameTests { await spark.stop() } + @Test + func checkpoint() async throws { + let spark = try await SparkSession.builder.getOrCreate() + if await spark.version >= "4.0.0" { + // By default, reliable checkpoint location is required. + try await #require(throws: Error.self) { + try await spark.range(10).checkpoint() + } + // Checkpointing with unreliable checkpoint + let df = try await spark.range(10).checkpoint(true, false) + #expect(try await df.count() == 10) + } + await spark.stop() + } + + @Test + func localCheckpoint() async throws { + let spark = try await SparkSession.builder.getOrCreate() + if await spark.version >= "4.0.0" { + #expect(try await spark.range(10).localCheckpoint().count() == 10) + } + await spark.stop() + } + @Test func persist() async throws { let spark = try await SparkSession.builder.getOrCreate()