Skip to content

Commit 82bba31

Browse files
t-enstrivialfis
andauthored
Correct computation of metrics in multi-quantile regression. (dmlc#11279)
--------- Co-authored-by: Jiaming Yuan <[email protected]>
1 parent 8c94e75 commit 82bba31

File tree

3 files changed

+35
-4
lines changed

3 files changed

+35
-4
lines changed

python-package/xgboost/testing/metrics.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,21 @@ def check_quantile_error(tree_method: str) -> None:
7878
predt = booster.inplace_predict(X)
7979
loss = mean_pinball_loss(y, predt, alpha=0.3)
8080
np.testing.assert_allclose(evals_result["Train"]["quantile"][-1], loss)
81+
82+
alpha = [0.25, 0.5, 0.75]
83+
booster = xgb.train(
84+
{
85+
"tree_method": tree_method,
86+
"eval_metric": "quantile",
87+
"quantile_alpha": alpha,
88+
"objective": "reg:quantileerror",
89+
},
90+
Xy,
91+
evals=[(Xy, "Train")],
92+
evals_result=evals_result,
93+
)
94+
predt = booster.inplace_predict(X)
95+
loss = np.mean(
96+
[mean_pinball_loss(y, predt[:, i], alpha=alpha[i]) for i in range(3)]
97+
)
98+
np.testing.assert_allclose(evals_result["Train"]["quantile"][-1], loss)

src/metric/elementwise_metric.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,16 @@ namespace {
4343
* applying the weights. A tuple of {error_i, weight_i} is expected as return.
4444
*/
4545
template <typename Fn>
46-
PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) {
46+
PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss,
47+
size_t num_preds = 1) {
4748
PackedReduceResult result;
4849
// This function doesn't have sycl-specific implementation yet.
4950
// For that reason we transfer data to host in case of sycl is used for propper execution.
5051
auto labels = info.labels.View(ctx->Device().IsSycl() ? DeviceOrd::CPU() : ctx->Device());
5152
if (ctx->IsCUDA()) {
5253
#if defined(XGBOOST_USE_CUDA)
5354
thrust::counting_iterator<size_t> begin(0);
54-
thrust::counting_iterator<size_t> end = begin + labels.Size();
55+
thrust::counting_iterator<size_t> end = begin + labels.Size() * num_preds;
5556
result = thrust::transform_reduce(
5657
ctx->CUDACtx()->CTP(), begin, end,
5758
[=] XGBOOST_DEVICE(size_t i) {
@@ -76,7 +77,7 @@ PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) {
7677
// - sqrt(1/w(sum_t0 + sum_t1 + ... + sum_tm)) // multi-target
7778
// - sqrt(avg_t0) + sqrt(avg_t1) + ... sqrt(avg_tm) // distributed
7879

79-
auto size = info.labels.Size();
80+
auto size = info.labels.Size() * num_preds;
8081
auto const kBlockSize = 2048;
8182
auto n_blocks = size / kBlockSize + 1;
8283

@@ -491,7 +492,7 @@ class QuantileError : public MetricNoCache {
491492
auto l =
492493
loss(y_predt(sample_id, quantile_id, target_id), y_true(sample_id, target_id)) * w;
493494
return std::make_tuple(l, w);
494-
});
495+
}, alpha_.Size());
495496
std::array<double, 2> dat{result.Residue(), result.Weights()};
496497
auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(dat.data(), dat.size()));
497498
collective::SafeColl(rc);

tests/cpp/metric/test_elementwise_metric.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,11 @@ inline void VerifyQuantile(DataSplitMode data_split_mode, DeviceOrd device) {
348348
std::unique_ptr<Metric> metric{Metric::Create("quantile", &ctx)};
349349

350350
HostDeviceVector<float> predts{0.1f, 0.9f, 0.1f, 0.9f};
351+
HostDeviceVector<float> predts_2{0.2f, 0.6f, 0.4f, 0.6f, 0.5f, 1.2f, 0.0f, 0.4f};
352+
HostDeviceVector<float> predts_3{0.2f, 0.4f, 0.6f,
353+
0.4f, 0.5f, 0.6f,
354+
0.5f, 0.8f, 1.2f,
355+
0.0f, 0.3f, 0.4f};
351356
std::vector<float> labels{0.5f, 0.5f, 0.9f, 0.1f};
352357
std::vector<float> weights{0.2f, 0.4f, 0.6f, 0.8f};
353358

@@ -377,5 +382,12 @@ inline void VerifyQuantile(DataSplitMode data_split_mode, DeviceOrd device) {
377382
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, {}, {}, data_split_mode), 0.3f, 0.001f);
378383
metric->Configure(Args{{"quantile_alpha", "[1.0]"}});
379384
EXPECT_NEAR(GetMetricEval(metric.get(), predts, labels, {}, {}, data_split_mode), 0.3f, 0.001f);
385+
386+
metric->Configure(Args{{"quantile_alpha", "[0.2, 0.8]"}});
387+
EXPECT_NEAR(GetMetricEval(metric.get(), predts_2, labels, {}, {}, data_split_mode), 0.0425f,
388+
0.0001f);
389+
metric->Configure(Args{{"quantile_alpha", "[0.2, 0.5, 0.8]"}});
390+
EXPECT_NEAR(GetMetricEval(metric.get(), predts_3, labels, {}, {}, data_split_mode), 0.0450f,
391+
0.0001f);
380392
}
381393
} // namespace xgboost::metric

0 commit comments

Comments
 (0)