Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit c5f9b89

Browse files
zhengruifengjkbradley
authored andcommitted
[SPARK-18608][ML] Fix double caching
## What changes were proposed in this pull request? `df.rdd.getStorageLevel` => `df.storageLevel` using cmd `find . -name '*.scala' | xargs -i bash -c 'egrep -in "\.rdd\.getStorageLevel" {} && echo {}'` to make sure all algs involved in this issue are fixed. Previous discussion in other PRs: apache#19107, apache#17014 ## How was this patch tested? existing tests Author: Zheng RuiFeng <[email protected]> Closes apache#19197 from zhengruifeng/double_caching.
1 parent b9b54b1 commit c5f9b89

File tree

6 files changed

+7
-7
lines changed

6 files changed

+7
-7
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ class LogisticRegression @Since("1.2.0") (
484484
}
485485

486486
override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
487-
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
487+
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
488488
train(dataset, handlePersistence)
489489
}
490490

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ final class OneVsRestModel private[ml] (
165165
val newDataset = dataset.withColumn(accColName, initUDF())
166166

167167
// persist if underlying dataset is not persistent.
168-
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
168+
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
169169
if (handlePersistence) {
170170
newDataset.persist(StorageLevel.MEMORY_AND_DISK)
171171
}
@@ -358,7 +358,7 @@ final class OneVsRest @Since("1.4.0") (
358358
}
359359

360360
// persist if underlying dataset is not persistent.
361-
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
361+
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
362362
if (handlePersistence) {
363363
multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
364364
}

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ class KMeans @Since("1.5.0") (
304304
override def fit(dataset: Dataset[_]): KMeansModel = {
305305
transformSchema(dataset.schema, logging = true)
306306

307-
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
307+
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
308308
val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
309309
case Row(point: Vector) => OldVectors.fromML(point)
310310
}

mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
213213
override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
214214
transformSchema(dataset.schema, logging = true)
215215
val instances = extractAFTPoints(dataset)
216-
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
216+
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
217217
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
218218

219219
val featuresSummarizer = {

mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
165165
transformSchema(dataset.schema, logging = true)
166166
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
167167
val instances = extractWeightedLabeledPoints(dataset)
168-
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
168+
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
169169
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
170170

171171
val instr = Instrumentation.create(this, dataset)

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
251251
return lrModel
252252
}
253253

254-
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
254+
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
255255
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
256256

257257
val (featuresSummarizer, ySummarizer) = {

0 commit comments

Comments
 (0)