Skip to content

Commit b3b62ba

Browse files
imatiach-msftsrowen
authored andcommitted
[SPARK-19591][ML][MLLIB][FOLLOWUP] Add sample weights to decision trees - fix tolerance
This is a follow-up to PR: apache#21632 ## What changes were proposed in this pull request? This PR tunes the tolerance used for deciding whether to add zero feature values to a value-count map (where the key is the feature value and the value is the weighted count of those feature values). In the previous PR the tolerance scaled by the square of the unweighted number of samples, which is too aggressive for a large number of unweighted samples. Unfortunately using just "Utils.EPSILON * unweightedNumSamples" is not enough either, so I multiplied that by a factor tuned by the testing procedure below. ## How was this patch tested? This involved manually running the sample weight tests for decision tree regressor to see whether the tolerance was large enough to exclude zero feature values. Eg in SBT: ``` ./build/sbt > project mllib > testOnly *DecisionTreeRegressorSuite -- -z "training with sample weights" ``` For validation, I added a print inside the if in the code below and validated that the tolerance was large enough so that we would not include zero features (which don't exist in that test): ``` val valueCountMap = if (weightedNumSamples - partNumSamples > tolerance) { print("should not print this") partValueCountMap + (0.0 -> (weightedNumSamples - partNumSamples)) } else { partValueCountMap } ``` Closes apache#23682 from imatiach-msft/ilmat/sample-weights-tol. Authored-by: Ilya Matiach <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent bc6f191 commit b3b62ba

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1050,8 +1050,11 @@ private[spark] object RandomForest extends Logging with Serializable {
10501050
// Calculate the expected number of samples for finding splits
10511051
val weightedNumSamples = samplesFractionForFindSplits(metadata) *
10521052
metadata.weightedNumExamples
1053+
// scale tolerance by number of samples with constant factor
1054+
// Note: constant factor was tuned by running some tests where there were no zero
1055+
// feature values and validating we are never within tolerance
1056+
val tolerance = Utils.EPSILON * unweightedNumSamples * 100
10531057
// add expected zero value count and get complete statistics
1054-
val tolerance = Utils.EPSILON * unweightedNumSamples * unweightedNumSamples
10551058
val valueCountMap = if (weightedNumSamples - partNumSamples > tolerance) {
10561059
partValueCountMap + (0.0 -> (weightedNumSamples - partNumSamples))
10571060
} else {

0 commit comments

Comments
 (0)