
Commit baa832e

[backport] Fix NDCG metric with non-exp gain. (dmlc#11534) (dmlc#11578)
1 parent fe1596f commit baa832e
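
This backport makes the NDCG metric pick up the `ndcg_exp_gain` setting, so the metric no longer silently assumes the exponential gain when the user asked for the linear one. For reference, and not part of the commit itself, the two gain functions that `ndcg_exp_gain` switches between enter the metric as

    \mathrm{DCG@}k = \sum_{i=1}^{k} \frac{g(\mathrm{rel}_i)}{\log_2(i + 1)},
    \qquad
    g(r) = 2^{r} - 1 \ (\text{exponential gain}), \quad g(r) = r \ (\text{linear gain}),

with \mathrm{NDCG@}k = \mathrm{DCG@}k / \mathrm{IDCG@}k. The exponential gain grows very quickly, which is why XGBoost requires relevance labels below 32 unless `ndcg_exp_gain` is set to false.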

File tree

3 files changed: +52 −10 lines


src/metric/rank_metric.cc

Lines changed: 13 additions & 7 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2023 by XGBoost contributors
+ * Copyright 2020-2025, XGBoost contributors
  */
 #include "rank_metric.h"

@@ -10,7 +10,6 @@
 #include <array>       // for array
 #include <cmath>       // for log, sqrt
 #include <functional>  // for less, greater
-#include <limits>      // for numeric_limits
 #include <map>         // for operator!=, _Rb_tree_const_iterator
 #include <memory>      // for allocator, unique_ptr, shared_ptr, __shared_...
 #include <numeric>     // for accumulate
@@ -22,7 +21,6 @@
 #include "../collective/aggregator.h"   // for ApplyWithLabels
 #include "../common/algorithm.h"        // for ArgSort, Sort
 #include "../common/linalg_op.h"        // for cbegin, cend
-#include "../common/math.h"             // for CmpFirst
 #include "../common/optional_weight.h"  // for OptionalWeights, MakeOptionalWeights
 #include "dmlc/common.h"                // for OMPException
 #include "metric_common.h"              // for MetricNoCache, GPUMetric, PackedReduceResult
@@ -250,10 +248,6 @@ class EvalRankWithCache : public Metric {
     }
     param_.UpdateAllowUnknown(Args{});
   }
-  void Configure(Args const&) override {
-    // do not configure, otherwise the ndcg param will be forced into the same as the one in
-    // objective.
-  }
   void LoadConfig(Json const& in) override {
     if (IsA<Null>(in)) {
       return;
@@ -365,6 +359,18 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
  public:
   using EvalRankWithCache::EvalRankWithCache;

+  void Configure(Args const& args) override {
+    // do not configure, otherwise the ndcg param like top-k will be forced into the same
+    // as the one in objective. The metric has its own syntax for parameter.
+    for (auto const& [key, value] : args) {
+      // Make a special case for the exp gain parameter, which is not exposed in the
+      // metric configuration syntax.
+      if (key == "ndcg_exp_gain") {
+        this->param_.UpdateAllowUnknown(Args{{key, value}});
+      }
+    }
+  }
+
   double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
               std::shared_ptr<ltr::NDCGCache> p_cache) override {
     if (ctx_->IsCUDA()) {
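
The new EvalNDCG::Configure above forwards only `ndcg_exp_gain` from the learner arguments into the metric's parameters; the truncation level is instead taken from the metric name (e.g. "ndcg@8"). A minimal sketch of how this is exercised from the Python API follows; the data, query ids, and hyper-parameters below are made up for illustration and are not part of the commit.

import numpy as np
import xgboost

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 4))
y = rng.integers(0, 40, size=256)            # relevance labels, some of them >= 32
qid = np.sort(rng.integers(0, 8, size=256))  # query ids sorted so groups are contiguous

# With the exponential gain, labels must stay below 32; for larger labels the
# linear gain has to be requested with ndcg_exp_gain=False. After this fix the
# NDCG metric honours the flag as well, instead of only the objective.
ranker = xgboost.XGBRanker(
    objective="rank:ndcg",
    ndcg_exp_gain=False,   # use the linear gain in both objective and metric
    eval_metric="ndcg@8",  # the metric's top-k is part of the metric name
    n_estimators=4,
)
ranker.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid])
print(ranker.evals_result())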

tests/cpp/common/test_parameter.cc

Lines changed: 17 additions & 3 deletions
@@ -1,11 +1,12 @@
-/*!
- * Copyright (c) by Contributors 2019
+/**
+ * Copyright 2019-2025, XGBoost contributors
  */
 #include <gtest/gtest.h>
-
 #include <xgboost/base.h>
 #include <xgboost/parameter.h>

+#include "xgboost/json.h"  // for ToJson, FromJson
+
 enum class Foo : int {
   kBar = 0, kFrog = 1, kCat = 2, kDog = 3
 };
@@ -103,3 +104,16 @@ TEST(XGBoostParameter, Update) {
   a.UpdateAllowUnknown(xgboost::Args{{"f", "2.71828"}});
   ASSERT_NE(a.f, b.f);
 }
+namespace xgboost {
+TEST(XGBoostParameter, Json) {
+  UpdatableParam a, b;
+  a.UpdateAllowUnknown(Args{{"f", "1024"}, {"d", "2048"}});
+  auto ja = Json{ToJson(a)};
+
+  UpdatableParam c;
+  FromJson(ja, &c);
+  ASSERT_FLOAT_EQ(a.f, 1024);
+  ASSERT_FLOAT_EQ(c.f, 1024);
+  ASSERT_FLOAT_EQ(b.f, 0);  // Make sure dmlc global variable is not used here.
+}
+}  // namespace xgboost

tests/python/test_ranking.py

Lines changed: 22 additions & 0 deletions
@@ -69,6 +69,28 @@ def ndcg_gain(y: np.ndarray) -> np.ndarray:
     )


+def test_ndcg_non_exp() -> None:
+    # NDCG exp gain must have label smaller than 32
+    X, y, q, w = tm.make_ltr(n_samples=1024, n_features=4, n_query_groups=3, max_rel=44)
+
+    def fit(ltr: xgboost.XGBRanker):
+        ltr.fit(
+            X,
+            y,
+            qid=q,
+            sample_weight=w,
+            eval_set=[(X, y)],
+            eval_qid=(q,),
+            sample_weight_eval_set=(w,),
+        )
+
+    ltr = xgboost.XGBRanker(tree_method="hist", ndcg_exp_gain=True, n_estimators=2)
+    with pytest.raises(ValueError, match="Set `ndcg_exp_gain`"):
+        fit(ltr)
+    ltr = xgboost.XGBRanker(tree_method="hist", ndcg_exp_gain=False, n_estimators=2)
+    fit(ltr)
+
+
 def test_ranking_with_unweighted_data():
     Xrow = np.array([1, 2, 6, 8, 11, 14, 16, 17])
     Xcol = np.array([0, 0, 1, 1, 2, 2, 3, 3])
