Skip to content

Commit c83b3dd

Browse files
zhengruifeng authored and srowen committed
[SPARK-28154][ML][FOLLOWUP] GMM fix double caching
## What changes were proposed in this pull request? If the input dataset is already cached, then we do not need to cache the internal RDD (as KMeans does). ## How was this patch tested? Existing tests. Closes apache#24919 from zhengruifeng/gmm_fix_double_caching. Authored-by: zhengruifeng <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent 83b96f6 commit c83b3dd

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import org.apache.spark.rdd.RDD
3636
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
3737
import org.apache.spark.sql.functions.udf
3838
import org.apache.spark.sql.types.{IntegerType, StructType}
39+
import org.apache.spark.storage.StorageLevel
3940

4041

4142
/**
@@ -330,10 +331,15 @@ class GaussianMixture @Since("2.0.0") (
330331
val sc = dataset.sparkSession.sparkContext
331332
val numClusters = $(k)
332333

334+
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
333335
val instances = dataset
334336
.select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map {
335337
case Row(features: Vector) => features
336-
}.cache()
338+
}
339+
340+
if (handlePersistence) {
341+
instances.persist(StorageLevel.MEMORY_AND_DISK)
342+
}
337343

338344
// Extract the number of features.
339345
val numFeatures = instances.first().size
@@ -410,8 +416,10 @@ class GaussianMixture @Since("2.0.0") (
410416
logLikelihood = sums.logLikelihood // this is the freshly computed log-likelihood
411417
iter += 1
412418
}
419+
if (handlePersistence) {
420+
instances.unpersist()
421+
}
413422

414-
instances.unpersist()
415423
val gaussianDists = gaussians.map { case (mean, covVec) =>
416424
val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values)
417425
new MultivariateGaussian(mean, cov)

0 commit comments

Comments
 (0)