Skip to content

Commit 1bb60ab

Browse files
zhengruifengsrowen
authored andcommitted
[SPARK-26153][ML] GBT & RandomForest avoid unnecessary first job to compute numFeatures
## What changes were proposed in this pull request? use base models' `numFeature` instead of `first` job ## How was this patch tested? existing tests Closes apache#23123 from zhengruifeng/avoid_first_job. Authored-by: zhengruifeng <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent 6bb60b3 commit 1bb60ab

File tree

4 files changed

+9
-6
lines changed

4 files changed

+9
-6
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,6 @@ class GBTClassifier @Since("1.4.0") (
180180
(convert2LabeledPoint(dataset), null)
181181
}
182182

183-
val numFeatures = trainDataset.first().features.size
184183
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
185184

186185
val numClasses = 2
@@ -196,7 +195,6 @@ class GBTClassifier @Since("1.4.0") (
196195
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
197196
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy,
198197
validationIndicatorCol)
199-
instr.logNumFeatures(numFeatures)
200198
instr.logNumClasses(numClasses)
201199

202200
val (baseLearners, learnerWeights) = if (withValidation) {
@@ -206,6 +204,9 @@ class GBTClassifier @Since("1.4.0") (
206204
GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy))
207205
}
208206

207+
val numFeatures = baseLearners.head.numFeatures
208+
instr.logNumFeatures(numFeatures)
209+
209210
new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
210211
}
211212

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class RandomForestClassifier @Since("1.4.0") (
142142
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
143143
.map(_.asInstanceOf[DecisionTreeClassificationModel])
144144

145-
val numFeatures = oldDataset.first().features.size
145+
val numFeatures = trees.head.numFeatures
146146
instr.logNumClasses(numClasses)
147147
instr.logNumFeatures(numFeatures)
148148
new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)

mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,15 +165,13 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
165165
} else {
166166
(extractLabeledPoints(dataset), null)
167167
}
168-
val numFeatures = trainDataset.first().features.size
169168
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
170169

171170
instr.logPipelineStage(this)
172171
instr.logDataset(dataset)
173172
instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType,
174173
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
175174
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
176-
instr.logNumFeatures(numFeatures)
177175

178176
val (baseLearners, learnerWeights) = if (withValidation) {
179177
GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
@@ -182,6 +180,10 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
182180
GradientBoostedTrees.run(trainDataset, boostingStrategy,
183181
$(seed), $(featureSubsetStrategy))
184182
}
183+
184+
val numFeatures = baseLearners.head.numFeatures
185+
instr.logNumFeatures(numFeatures)
186+
185187
new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
186188
}
187189

mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
133133
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
134134
.map(_.asInstanceOf[DecisionTreeRegressionModel])
135135

136-
val numFeatures = oldDataset.first().features.size
136+
val numFeatures = trees.head.numFeatures
137137
instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures)
138138
new RandomForestRegressionModel(uid, trees, numFeatures)
139139
}

0 commit comments

Comments
 (0)