Skip to content

Commit 83b96f6

Browse files
zhengruifeng authored and srowen committed
[SPARK-28117][ML] LDA and BisectingKMeans cache the input dataset if necessary
## What changes were proposed in this pull request?

Cache the dataset in BisectingKMeans; cache the dataset in LDA if the Online solver is chosen.

## How was this patch tested?

Existing tests.

Closes apache#24920 from zhengruifeng/bikm_cache. Authored-by: zhengruifeng <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent c397b06 commit 83b96f6

File tree

4 files changed

+36
-7
lines changed

4 files changed

+36
-7
lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._
3232
import org.apache.spark.sql.{DataFrame, Dataset}
3333
import org.apache.spark.sql.functions.udf
3434
import org.apache.spark.sql.types.{IntegerType, StructType}
35+
import org.apache.spark.storage.StorageLevel
3536

3637

3738
/**
@@ -248,7 +249,12 @@ class BisectingKMeans @Since("2.0.0") (
248249
@Since("2.0.0")
249250
override def fit(dataset: Dataset[_]): BisectingKMeansModel = instrumented { instr =>
250251
transformSchema(dataset.schema, logging = true)
252+
253+
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
251254
val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
255+
if (handlePersistence) {
256+
rdd.persist(StorageLevel.MEMORY_AND_DISK)
257+
}
252258

253259
instr.logPipelineStage(this)
254260
instr.logDataset(dataset)
@@ -263,6 +269,10 @@ class BisectingKMeans @Since("2.0.0") (
263269
.setDistanceMeasure($(distanceMeasure))
264270
val parentModel = bkm.run(rdd, Some(instr))
265271
val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
272+
if (handlePersistence) {
273+
rdd.unpersist()
274+
}
275+
266276
val summary = new BisectingKMeansSummary(
267277
model.transform(dataset),
268278
$(predictionCol),

mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ import org.apache.spark.mllib.util.MLUtils
4444
import org.apache.spark.rdd.RDD
4545
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
4646
import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf}
47-
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, StructType}
47+
import org.apache.spark.sql.types.StructType
48+
import org.apache.spark.storage.StorageLevel
4849
import org.apache.spark.util.PeriodicCheckpointer
4950
import org.apache.spark.util.VersionUtils
5051

@@ -904,6 +905,18 @@ class LDA @Since("1.6.0") (
904905
checkpointInterval, keepLastCheckpoint, optimizeDocConcentration, topicConcentration,
905906
learningDecay, optimizer, learningOffset, seed)
906907

908+
val oldData = LDA.getOldDataset(dataset, $(featuresCol))
909+
910+
// The EM solver will transform this oldData to a graph, and use an internal graphCheckpointer
911+
// to update and cache the graph, so we do not need to cache it.
912+
// The Online solver directly performs sampling on the oldData and updates the model.
913+
// However, Online solver will not cache the dataset internally.
914+
val handlePersistence = dataset.storageLevel == StorageLevel.NONE &&
915+
getOptimizer.toLowerCase(Locale.ROOT) == "online"
916+
if (handlePersistence) {
917+
oldData.persist(StorageLevel.MEMORY_AND_DISK)
918+
}
919+
907920
val oldLDA = new OldLDA()
908921
.setK($(k))
909922
.setDocConcentration(getOldDocConcentration)
@@ -912,15 +925,17 @@ class LDA @Since("1.6.0") (
912925
.setSeed($(seed))
913926
.setCheckpointInterval($(checkpointInterval))
914927
.setOptimizer(getOldOptimizer)
915-
// TODO: persist here, or in old LDA?
916-
val oldData = LDA.getOldDataset(dataset, $(featuresCol))
928+
917929
val oldModel = oldLDA.run(oldData)
918930
val newModel = oldModel match {
919931
case m: OldLocalLDAModel =>
920932
new LocalLDAModel(uid, m.vocabSize, m, dataset.sparkSession)
921933
case m: OldDistributedLDAModel =>
922934
new DistributedLDAModel(uid, m.vocabSize, m, dataset.sparkSession, None)
923935
}
936+
if (handlePersistence) {
937+
oldData.unpersist()
938+
}
924939

925940
instr.logNumFeatures(newModel.vocabSize)
926941
copyValues(newModel).setParent(this)
@@ -940,8 +955,8 @@ object LDA extends MLReadable[LDA] {
940955
dataset: Dataset[_],
941956
featuresCol: String): RDD[(Long, OldVector)] = {
942957
dataset
943-
.withColumn("docId", monotonically_increasing_id())
944-
.select(col("docId"), DatasetUtils.columnToVector(dataset, featuresCol))
958+
.select(monotonically_increasing_id(),
959+
DatasetUtils.columnToVector(dataset, featuresCol))
945960
.rdd
946961
.map { case Row(docId: Long, features: Vector) =>
947962
(docId, OldVectors.fromML(features))

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging
3030
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
3131
import org.apache.spark.rdd.RDD
3232
import org.apache.spark.storage.StorageLevel
33-
import org.apache.spark.util.Utils
33+
3434

3535
/**
3636
* :: DeveloperApi ::
@@ -437,6 +437,10 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
437437
this.randomGenerator = new Random(lda.getSeed)
438438

439439
this.docs = docs
440+
if (this.docs.getStorageLevel == StorageLevel.NONE) {
441+
logWarning("The input data is not directly cached, which may hurt performance if its"
442+
+ " parent RDDs are also uncached.")
443+
}
440444

441445
// Initialize the variational distribution q(beta|lambda)
442446
this.lambda = getGammaMatrix(k, vocabSize)

mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest {
7171
rows =>
7272
val numClusters = rows.distinct.length
7373
// Verify we hit the edge case
74-
assert(numClusters < k && numClusters > 1)
74+
assert(numClusters > 1)
7575
}
7676
}
7777

0 commit comments

Comments
 (0)