Skip to content

Commit b14e89d

Browse files
rstzcopybara-github
authored andcommitted
[YDF] Fix computation in LabelHessianNumericalBucket
PiperOrigin-RevId: 869726553
1 parent 0719492 commit b14e89d

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

yggdrasil_decision_forests/learner/decision_tree/splitter_accumulator.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1706,15 +1706,17 @@ struct LabelHessianNumericalBucket {
17061706
if constexpr (weighted) {
17071707
acc->Add(content.sum_gradient, content.sum_hessian, content.sum_weight);
17081708
} else {
1709-
acc->Add(content.sum_gradient, content.sum_hessian, 1.f);
1709+
acc->Add(content.sum_gradient, content.sum_hessian,
1710+
static_cast<float>(count));
17101711
}
17111712
}
17121713

17131714
void SubToScoreAcc(LabelHessianNumericalScoreAccumulator* acc) const {
17141715
if constexpr (weighted) {
17151716
acc->Sub(content.sum_gradient, content.sum_hessian, content.sum_weight);
17161717
} else {
1717-
acc->Sub(content.sum_gradient, content.sum_hessian, 1.f);
1718+
acc->Sub(content.sum_gradient, content.sum_hessian,
1719+
static_cast<float>(count));
17181720
}
17191721
}
17201722

0 commit comments

Comments
 (0)