This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 13268a5

ferdonline authored and gatorsmile committed
[SPARK-22649][PYTHON][SQL] Adding localCheckpoint to Dataset API
## What changes were proposed in this pull request?

This change adds local checkpoint support to Dataset, and the corresponding binding in the Python DataFrame API. When reliability requirements can be relaxed in favor of performance, as when a few quick transformations are followed by a reliable save, `localCheckpoint()` fits very well. Furthermore, reliable checkpoints currently still incur double computation (see apache#9428). In general it also makes the API more complete.

## How was this patch tested?

A quick use case from the Python side:

```python
>>> from time import sleep
>>> from pyspark.sql import types as T
>>> from pyspark.sql import functions as F
>>> def f(x):
...     sleep(1)
...     return x*2
...
>>> df1 = spark.range(30, numPartitions=6)
>>> df2 = df1.select(F.udf(f, T.LongType())("id"))
>>> %time _ = df2.collect()
CPU times: user 7.79 ms, sys: 5.84 ms, total: 13.6 ms
Wall time: 12.2 s
>>> %time df3 = df2.localCheckpoint()
CPU times: user 2.38 ms, sys: 2.3 ms, total: 4.68 ms
Wall time: 10.3 s
>>> %time _ = df3.collect()
CPU times: user 5.09 ms, sys: 410 µs, total: 5.5 ms
Wall time: 148 ms
>>> sc.setCheckpointDir(".")
>>> %time df3 = df2.checkpoint()
CPU times: user 4.04 ms, sys: 1.63 ms, total: 5.67 ms
Wall time: 20.3 s
```

Author: Fernando Pereira <[email protected]>

Closes apache#19805 from ferdonline/feature_dataset_localCheckpoint.
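For comparison, a minimal sketch of the same pattern on the Scala side of the new API (not part of the patch; it assumes a running `SparkSession` named `spark`, and the slow UDF is purely illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

val spark = SparkSession.builder.getOrCreate()
import spark.implicits._

// Illustrative slow transformation, analogous to f(x) in the Python demo.
val slowDouble = udf((x: Long) => { Thread.sleep(1000); x * 2 })

val df1 = spark.range(0, 30, 1, 6)      // 30 rows in 6 partitions
val df2 = df1.select(slowDouble($"id"))

// Eagerly materializes df2 to executor storage via the caching subsystem;
// unlike df2.checkpoint(), no checkpoint directory is needed.
val df3 = df2.localCheckpoint()

df3.collect()  // fast: reads back the local checkpoint, no UDF re-run
```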
1 parent 6e36d8d commit 13268a5

File tree

3 files changed: +121 −49 lines


python/pyspark/sql/dataframe.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -368,6 +368,20 @@ def checkpoint(self, eager=True):
         jdf = self._jdf.checkpoint(eager)
         return DataFrame(jdf, self.sql_ctx)
 
+    @since(2.3)
+    def localCheckpoint(self, eager=True):
+        """Returns a locally checkpointed version of this Dataset. Checkpointing can be used to
+        truncate the logical plan of this DataFrame, which is especially useful in iterative
+        algorithms where the plan may grow exponentially. Local checkpoints are stored in the
+        executors using the caching subsystem and therefore they are not reliable.
+
+        :param eager: Whether to checkpoint this DataFrame immediately
+
+        .. note:: Experimental
+        """
+        jdf = self._jdf.localCheckpoint(eager)
+        return DataFrame(jdf, self.sql_ctx)
+
     @since(2.1)
     def withWatermark(self, eventTime, delayThreshold):
         """Defines an event time watermark for this :class:`DataFrame`. A watermark tracks a point
```

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 46 additions & 3 deletions
```diff
@@ -527,7 +527,7 @@ class Dataset[T] private[sql](
    */
   @Experimental
   @InterfaceStability.Evolving
-  def checkpoint(): Dataset[T] = checkpoint(eager = true)
+  def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true)
 
   /**
    * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the
@@ -540,9 +540,52 @@ class Dataset[T] private[sql](
    */
   @Experimental
   @InterfaceStability.Evolving
-  def checkpoint(eager: Boolean): Dataset[T] = {
+  def checkpoint(eager: Boolean): Dataset[T] = checkpoint(eager = eager, reliableCheckpoint = true)
+
+  /**
+   * Eagerly locally checkpoints a Dataset and return the new Dataset. Checkpointing can be
+   * used to truncate the logical plan of this Dataset, which is especially useful in iterative
+   * algorithms where the plan may grow exponentially. Local checkpoints are written to executor
+   * storage and despite potentially faster they are unreliable and may compromise job completion.
+   *
+   * @group basic
+   * @since 2.3.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false)
+
+  /**
+   * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to truncate
+   * the logical plan of this Dataset, which is especially useful in iterative algorithms where the
+   * plan may grow exponentially. Local checkpoints are written to executor storage and despite
+   * potentially faster they are unreliable and may compromise job completion.
+   *
+   * @group basic
+   * @since 2.3.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def localCheckpoint(eager: Boolean): Dataset[T] = checkpoint(
+    eager = eager,
+    reliableCheckpoint = false
+  )
+
+  /**
+   * Returns a checkpointed version of this Dataset.
+   *
+   * @param eager Whether to checkpoint this dataframe immediately
+   * @param reliableCheckpoint Whether to create a reliable checkpoint saved to files inside the
+   *                           checkpoint directory. If false creates a local checkpoint using
+   *                           the caching subsystem
+   */
+  private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
     val internalRdd = queryExecution.toRdd.map(_.copy())
-    internalRdd.checkpoint()
+    if (reliableCheckpoint) {
+      internalRdd.checkpoint()
+    } else {
+      internalRdd.localCheckpoint()
+    }
 
     if (eager) {
       internalRdd.count()
```
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 61 additions & 46 deletions
```diff
@@ -1156,67 +1156,82 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   }
 
   Seq(true, false).foreach { eager =>
-    def testCheckpointing(testName: String)(f: => Unit): Unit = {
-      test(s"Dataset.checkpoint() - $testName (eager = $eager)") {
-        withTempDir { dir =>
-          val originalCheckpointDir = spark.sparkContext.checkpointDir
-
-          try {
-            spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
+    Seq(true, false).foreach { reliable =>
+      def testCheckpointing(testName: String)(f: => Unit): Unit = {
+        test(s"Dataset.checkpoint() - $testName (eager = $eager, reliable = $reliable)") {
+          if (reliable) {
+            withTempDir { dir =>
+              val originalCheckpointDir = spark.sparkContext.checkpointDir
+
+              try {
+                spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
+                f
+              } finally {
+                // Since the original checkpointDir can be None, we need
+                // to set the variable directly.
+                spark.sparkContext.checkpointDir = originalCheckpointDir
+              }
+            }
+          } else {
+            // Local checkpoints dont require checkpoint_dir
             f
-          } finally {
-            // Since the original checkpointDir can be None, we need
-            // to set the variable directly.
-            spark.sparkContext.checkpointDir = originalCheckpointDir
           }
         }
       }
-    }
 
-    testCheckpointing("basic") {
-      val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc)
-      val cp = ds.checkpoint(eager)
+      testCheckpointing("basic") {
+        val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc)
+        val cp = if (reliable) ds.checkpoint(eager) else ds.localCheckpoint(eager)
 
-      val logicalRDD = cp.logicalPlan match {
-        case plan: LogicalRDD => plan
-        case _ =>
-          val treeString = cp.logicalPlan.treeString(verbose = true)
-          fail(s"Expecting a LogicalRDD, but got\n$treeString")
-      }
+        val logicalRDD = cp.logicalPlan match {
+          case plan: LogicalRDD => plan
+          case _ =>
+            val treeString = cp.logicalPlan.treeString(verbose = true)
+            fail(s"Expecting a LogicalRDD, but got\n$treeString")
+        }
 
-      val dsPhysicalPlan = ds.queryExecution.executedPlan
-      val cpPhysicalPlan = cp.queryExecution.executedPlan
+        val dsPhysicalPlan = ds.queryExecution.executedPlan
+        val cpPhysicalPlan = cp.queryExecution.executedPlan
 
-      assertResult(dsPhysicalPlan.outputPartitioning) { logicalRDD.outputPartitioning }
-      assertResult(dsPhysicalPlan.outputOrdering) { logicalRDD.outputOrdering }
+        assertResult(dsPhysicalPlan.outputPartitioning) {
+          logicalRDD.outputPartitioning
+        }
+        assertResult(dsPhysicalPlan.outputOrdering) {
+          logicalRDD.outputOrdering
+        }
 
-      assertResult(dsPhysicalPlan.outputPartitioning) { cpPhysicalPlan.outputPartitioning }
-      assertResult(dsPhysicalPlan.outputOrdering) { cpPhysicalPlan.outputOrdering }
+        assertResult(dsPhysicalPlan.outputPartitioning) {
+          cpPhysicalPlan.outputPartitioning
+        }
+        assertResult(dsPhysicalPlan.outputOrdering) {
+          cpPhysicalPlan.outputOrdering
+        }
 
-      // For a lazy checkpoint() call, the first check also materializes the checkpoint.
-      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+        // For a lazy checkpoint() call, the first check also materializes the checkpoint.
+        checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
 
-      // Reads back from checkpointed data and check again.
-      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
-    }
+        // Reads back from checkpointed data and check again.
+        checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+      }
 
-    testCheckpointing("should preserve partitioning information") {
-      val ds = spark.range(10).repartition('id % 2)
-      val cp = ds.checkpoint(eager)
+      testCheckpointing("should preserve partitioning information") {
+        val ds = spark.range(10).repartition('id % 2)
+        val cp = if (reliable) ds.checkpoint(eager) else ds.localCheckpoint(eager)
 
-      val agg = cp.groupBy('id % 2).agg(count('id))
+        val agg = cp.groupBy('id % 2).agg(count('id))
 
-      agg.queryExecution.executedPlan.collectFirst {
-        case ShuffleExchangeExec(_, _: RDDScanExec, _) =>
-        case BroadcastExchangeExec(_, _: RDDScanExec) =>
-      }.foreach { _ =>
-        fail(
-          "No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " +
-          "preserves partitioning information:\n\n" + agg.queryExecution
-        )
-      }
+        agg.queryExecution.executedPlan.collectFirst {
+          case ShuffleExchangeExec(_, _: RDDScanExec, _) =>
+          case BroadcastExchangeExec(_, _: RDDScanExec) =>
+        }.foreach { _ =>
+          fail(
+            "No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " +
+            "preserves partitioning information:\n\n" + agg.queryExecution
+          )
+        }
 
-      checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
+        checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
+      }
     }
   }
 }
```
