
Commit bdc5a26

Fix init with scale pos weight. (dmlc#11280)
1 parent be83eb6

3 files changed: +20 -1 lines changed

src/objective/init_estimation.h

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ class FitIntercept : public ObjFunction {
   void InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const override;
 };
 
-class FitInterceptGlmLike : public ObjFunction {
+class FitInterceptGlmLike : public FitIntercept {
  public:
   void InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const override;
 };
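
Switching the base class of FitInterceptGlmLike from ObjFunction to FitIntercept is what enables the next hunk: RegLossObj derives from FitInterceptGlmLike, so it now also has FitIntercept as an indirect base and can call FitIntercept::InitEstimation by qualified name.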

src/objective/regression_obj.cu

Lines changed: 10 additions & 0 deletions
@@ -187,6 +187,16 @@ class RegLossObj : public FitInterceptGlmLike {
         .Eval(io_preds);
   }
 
+  void InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const override {
+    if (std::abs(this->param_.scale_pos_weight - 1.0f) > kRtEps) {
+      // Use newton method if `scale_pos_weight` is present. The alternative is to use
+      // weighted mean, but we also need to take sample weight into account.
+      FitIntercept::InitEstimation(info, base_score);
+    } else {
+      FitInterceptGlmLike::InitEstimation(info, base_score);
+    }
+  }
+
   [[nodiscard]] float ProbToMargin(float base_score) const override {
     return Loss::ProbToMargin(base_score);
   }
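
The branch falls back to the Newton-based FitIntercept::InitEstimation whenever scale_pos_weight differs from 1, because the mean used by the GLM-like path does not account for the extra weight placed on positive samples. As a rough sketch of the quantity the Newton iteration should settle on, the intercept that zeroes the weighted logistic gradient is the rebalanced positive fraction; the helper below is hypothetical and is not XGBoost's actual routine:

import numpy as np

def rebalanced_intercept(y, scale_pos_weight, sample_weight=None):
    # Probability p solving sum_i w_i * (p - y_i) = 0 for the logistic loss,
    # where positive samples carry an extra scale_pos_weight factor.
    # Illustrative sketch only, not the library's implementation.
    w = np.ones(len(y)) if sample_weight is None else np.asarray(sample_weight, dtype=float)
    w = np.where(np.asarray(y) == 1, w * scale_pos_weight, w)
    return float(np.sum(w * y) / np.sum(w))

y = np.array([1] * 20 + [0] * 80)      # ~20% positives, similar to the test below
print(rebalanced_intercept(y, 1.0))    # 0.2: the plain mean, i.e. the GLM-like path
print(rebalanced_intercept(y, 4.0))    # 0.5: what the Newton path should converge to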

tests/python/test_objectives.py

Lines changed: 9 additions & 0 deletions
@@ -18,3 +18,12 @@ def test_exp_family() -> None:
     )
     # The base score stored in the booster model is un-transformed
     np.testing.assert_allclose([get_basescore(m) for m in (reg, clf, clf1)], y.mean())
+
+    X, y = make_classification(weights=[0.8, 0.2], random_state=2025)
+    clf = xgb.train(
+        {"objective": "binary:logistic", "scale_pos_weight": 4.0},
+        xgb.QuantileDMatrix(X, y),
+        num_boost_round=1,
+    )
+    score = get_basescore(clf)
+    np.testing.assert_allclose(score, 0.5, rtol=1e-3)
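
The asserted value of 0.5 follows from the rebalancing arithmetic: make_classification(weights=[0.8, 0.2]) yields roughly 20% positive samples, and scale_pos_weight=4.0 scales their total weight up to match the negatives, so the fitted intercept in probability space should sit near 4.0 * 0.2 / (4.0 * 0.2 + 0.8) = 0.5, within the test's rtol=1e-3.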
