
Commit be83eb6

Optional normalization for the ranknet loss. (dmlc#11272)
1 parent 337ee78 commit be83eb6

11 files changed: +244 −80 lines


R-package/R/xgb.train.R

Lines changed: 16 additions & 0 deletions
@@ -769,6 +769,21 @@ xgb.train <- function(params = xgb.params(), data, nrounds, evals = list(),
 #' Whether to normalize the leaf value by lambda gradient. This can sometimes cause the training progress to stagnate.
 #'
 #' Version added: 2.1.0
+#'
+#' @param lambdarank_score_normalization
+#'
+#' Whether to normalize the delta metric by the difference of prediction scores. This can
+#' sometimes cause the training progress to stagnate. With pairwise ranking, we can normalize
+#' the gradient using the difference between the two samples in each pair to reduce the
+#' influence of pairs that have a large difference in ranking scores. This helps regularize
+#' the model to reduce bias and prevent overfitting. As with other regularization
+#' techniques, it might also prevent training from converging.
+#'
+#' There was no normalization before 2.0; in 2.0 and later versions it is applied by
+#' default. In 3.0, it became an option that users can disable.
+#'
+#' Version added: 3.0.0
+#'
 #' @param lambdarank_unbiased (for learning to rank (`"rank:ndcg"`, `"rank:map"`, `"rank:pairwise"`)) (default = `FALSE`)
 #' Specify whether we need to debias input click data.
 #' @param lambdarank_bias_norm (for learning to rank (`"rank:ndcg"`, `"rank:map"`, `"rank:pairwise"`)) (default = 2.0)
@@ -833,6 +848,7 @@ xgb.params <- function(
   lambdarank_pair_method = NULL,
   lambdarank_num_pair_per_sample = NULL,
   lambdarank_normalization = NULL,
+  lambdarank_score_normalization = NULL,
   lambdarank_unbiased = NULL,
   lambdarank_bias_norm = NULL,
   ndcg_exp_gain = NULL
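
For intuition, here is a small sketch of what this option changes in the pair gradient. The RankNet form and the `eps` guard below are illustrative assumptions for exposition, not lines from the XGBoost sources:

    import math

    def ranknet_pair_grad(s_high: float, s_low: float, delta_metric: float,
                          score_normalization: bool = True, eps: float = 0.01) -> float:
        """Sketch of one LambdaRank pair gradient (illustrative, not XGBoost's code)."""
        if score_normalization:
            # Down-weight pairs whose predicted scores are already far apart, so a few
            # extreme pairs do not dominate the update (the regularizing effect
            # described in the parameter docs above).
            delta_metric /= eps + abs(s_high - s_low)
        # RankNet sigmoid: large when the pair is mis-ordered (s_high < s_low),
        # shrinking as the model separates the pair correctly.
        rho = 1.0 / (1.0 + math.exp(s_high - s_low))
        return delta_metric * rho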

R-package/man/xgb.params.Rd

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default.

doc/parameter.rst

Lines changed: 14 additions & 0 deletions
@@ -540,6 +540,20 @@ These are parameters specific to learning to rank task. See :doc:`Learning to Ra
 
   Whether to normalize the leaf value by lambda gradient. This can sometimes cause the training progress to stagnate.
 
+* ``lambdarank_score_normalization`` [default = ``true``]
+
+  .. versionadded:: 3.0.0
+
+  Whether to normalize the delta metric by the difference of prediction scores. This can
+  sometimes cause the training progress to stagnate. With pairwise ranking, we can normalize
+  the gradient using the difference between the two samples in each pair to reduce the
+  influence of pairs that have a large difference in ranking scores. This helps regularize
+  the model to reduce bias and prevent overfitting. As with other regularization
+  techniques, it might also prevent training from converging.
+
+  There was no normalization before 2.0; in 2.0 and later versions it is applied by
+  default. In 3.0, it became an option that users can disable.
+
 * ``lambdarank_unbiased`` [default = ``false``]
 
   Specify whether we need to debias input click data.
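
Using the flag through the native interface works like any other ranking parameter; a minimal sketch with synthetic data (shapes and values here are illustrative):

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 10))
    y = rng.integers(0, 4, size=100)   # graded relevance labels
    qid = np.repeat([0, 1], 50)        # two query groups

    dtrain = xgb.DMatrix(X, y, qid=qid)
    booster = xgb.train(
        {"objective": "rank:ndcg", "lambdarank_score_normalization": False},
        dtrain,
        num_boost_round=8,
    )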

doc/tutorials/learning_to_rank.rst

Lines changed: 28 additions & 0 deletions
@@ -186,6 +186,34 @@ For a longer explanation, assuming the pairwise ranking method is used, we calcu
 
 However, it's possible that a distributed framework shuffles the data during map reduce and splits every query group into multiple workers. In that case, the performance would be disastrous. As a result, it depends on the data and the framework for whether a sorted groupby is needed.
 
+**********************************
+Comparing Results with Version 1.7
+**********************************
+
+The learning to rank implementation has been significantly updated in 2.0 with added hyper-parameters and training strategies. To obtain results similar to those of the 1.7 :py:class:`xgboost.XGBRanker`, the following parameters should be used:
+
+.. code-block:: python
+
+    params = {
+        # 1.7 only supports sampling, while 2.0 and later use top-k as the default.
+        # See the sections above for the trade-off.
+        "lambdarank_pair_method": "mean",
+        # Normalization was added in 2.0.
+        "lambdarank_normalization": False,
+        # 1.7 uses the ranknet loss, while later versions use the NDCG-weighted loss.
+        "objective": "rank:pairwise",
+        # 1.7 doesn't have this normalization.
+        "lambdarank_score_normalization": False,
+        "base_score": 0.5,
+        # The default tree method has been changed from approx to hist.
+        "tree_method": "approx",
+        # The default for the `mean` pair method is one pair per sample, which is the
+        # default in 1.7 as well. You can leave it unset.
+        "lambdarank_num_pair_per_sample": 1,
+    }
+
+The results still differ due to the change of random seed, but the overall training strategy is the same for ``rank:pairwise``.
+
 *******************
 Reproducible Result
 *******************
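
A minimal end-to-end sketch of applying this compatibility dictionary (the synthetic data and model size are illustrative, not part of the tutorial text):

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(7)
    X = rng.normal(size=(512, 8))
    y = rng.integers(0, 4, size=512)
    qid = np.sort(rng.integers(0, 16, size=512))  # query ids must be grouped together

    # `params` is the 1.7-compatible dictionary from the snippet above.
    ranker = xgb.XGBRanker(n_estimators=8, **params)
    ranker.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid])
    print(ranker.evals_result()["validation_0"])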

python-package/xgboost/testing/ranking.py

Lines changed: 27 additions & 0 deletions
@@ -118,3 +118,30 @@ def run_normalization(device: str) -> None:
     ltr.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid])
     e1 = ltr.evals_result()
     assert e1["validation_0"]["ndcg@32"][-1] > e0["validation_0"]["ndcg@32"][-1]
+
+
+def run_score_normalization(device: str, objective: str) -> None:
+    """Test normalization by score differences."""
+    if objective == "rank:map":
+        # Binary relevance
+        X, y, qid, _ = tm.make_ltr(4096, 4, 64, max_rel=1)
+    else:
+        X, y, qid, _ = tm.make_ltr(4096, 4, 64, 3)
+    ltr = xgb.XGBRanker(objective=objective, n_estimators=4, device=device)
+    ltr.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid])
+    e0 = ltr.evals_result()
+
+    ltr = xgb.XGBRanker(
+        objective=objective,
+        n_estimators=4,
+        device=device,
+        lambdarank_score_normalization=False,
+    )
+    ltr.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid])
+    e1 = ltr.evals_result()
+
+    m0, m1 = (
+        list(e0["validation_0"].values())[-1][-1],
+        list(e1["validation_0"].values())[-1][-1],
+    )
+    assert m0 != m1
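
A hypothetical driver for this helper, mirroring how such device-parametrized helpers are usually invoked from the test suite (the test name and the objective list are assumptions):

    import pytest

    from xgboost.testing.ranking import run_score_normalization

    @pytest.mark.parametrize("objective", ["rank:ndcg", "rank:map", "rank:pairwise"])
    def test_score_normalization(objective: str) -> None:
        run_score_normalization("cpu", objective)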

src/common/ranking_utils.h

Lines changed: 5 additions & 0 deletions
@@ -79,6 +79,7 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
   // unbiased
   bool lambdarank_unbiased{false};
   bool lambdarank_normalization{true};
+  bool lambdarank_score_normalization{true};
   double lambdarank_bias_norm{1.0};
   // ndcg
   bool ndcg_exp_gain{true};
@@ -88,6 +89,7 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
            lambdarank_num_pair_per_sample == that.lambdarank_num_pair_per_sample &&
            lambdarank_unbiased == that.lambdarank_unbiased &&
            lambdarank_normalization == that.lambdarank_normalization &&
+           lambdarank_score_normalization == that.lambdarank_score_normalization &&
            lambdarank_bias_norm == that.lambdarank_bias_norm && ndcg_exp_gain == that.ndcg_exp_gain;
   }
   bool operator!=(LambdaRankParam const& that) const { return !(*this == that); }
@@ -139,6 +141,9 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
     DMLC_DECLARE_FIELD(lambdarank_normalization)
         .set_default(true)
         .describe("Whether to normalize the leaf value for lambda rank.");
+    DMLC_DECLARE_FIELD(lambdarank_score_normalization)
+        .set_default(true)
+        .describe("Whether to normalize the delta by prediction score difference.");
     DMLC_DECLARE_FIELD(lambdarank_bias_norm)
         .set_default(1.0)
         .set_lower_bound(0.0)
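
Because the new flag is a regular field of `LambdaRankParam`, it should participate in booster configuration like the existing flags. A sketch for inspecting it from Python (the exact location in the JSON config tree is an assumption):

    import json

    import numpy as np
    import xgboost as xgb

    X = np.random.default_rng(0).normal(size=(64, 4))
    y = np.random.default_rng(1).integers(0, 2, size=64)
    dtrain = xgb.DMatrix(X, y, qid=np.repeat([0, 1], 32))

    booster = xgb.train(
        {"objective": "rank:ndcg", "lambdarank_score_normalization": False},
        dtrain,
        num_boost_round=1,
    )
    config = json.loads(booster.save_config())
    # The lambdarank fields are expected under the objective section.
    print(json.dumps(config["learner"]["objective"], indent=2))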

src/objective/lambdarank_obj.cc

Lines changed: 81 additions & 49 deletions
@@ -1,40 +1,40 @@
 /**
- * Copyright 2023-2024, XGBoost contributors
+ * Copyright 2023-2025, XGBoost contributors
  */
 #include "lambdarank_obj.h"
 
-#include <dmlc/registry.h>  // for DMLC_REGISTRY_FILE_TAG
-
-#include <algorithm>  // for transform, copy, fill_n, min, max
-#include <cmath>  // for pow, log2
-#include <cstddef>  // for size_t
-#include <cstdint>  // for int32_t
-#include <map>  // for operator!=
-#include <memory>  // for shared_ptr, __shared_ptr_access, allocator
-#include <ostream>  // for operator<<, basic_ostream
-#include <string>  // for char_traits, operator<, basic_string, string
-#include <tuple>  // for apply, make_tuple
-#include <type_traits>  // for is_floating_point
-#include <utility>  // for pair, swap
-#include <vector>  // for vector
-
-#include "../common/error_msg.h"  // for GroupWeight, LabelScoreSize
-#include "../common/linalg_op.h"  // for begin, cbegin, cend
-#include "../common/optional_weight.h"  // for MakeOptionalWeights, OptionalWeights
-#include "../common/ranking_utils.h"  // for RankingCache, LambdaRankParam, MAPCache, NDCGC...
-#include "../common/threading_utils.h"  // for ParallelFor, Sched
-#include "init_estimation.h"  // for FitIntercept
-#include "xgboost/base.h"  // for bst_group_t, GradientPair, kRtEps, GradientPai...
-#include "xgboost/context.h"  // for Context
-#include "xgboost/data.h"  // for MetaInfo
-#include "xgboost/host_device_vector.h"  // for HostDeviceVector
-#include "xgboost/json.h"  // for Json, get, Value, ToJson, F32Array, FromJson, IsA
-#include "xgboost/linalg.h"  // for Vector, Range, TensorView, VectorView, All
-#include "xgboost/logging.h"  // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_LE, CHE...
-#include "xgboost/objective.h"  // for ObjFunctionReg, XGBOOST_REGISTER_OBJECTIVE
-#include "xgboost/span.h"  // for Span, operator!=
-#include "xgboost/string_view.h"  // for operator<<, StringView
-#include "xgboost/task.h"  // for ObjInfo
+#include <dmlc/registry.h>  // for DMLC_REGISTRY_FILE_TAG
+
+#include <algorithm>  // for transform, copy, fill_n, min, max
+#include <cmath>  // for pow, log2
+#include <cstddef>  // for size_t
+#include <cstdint>  // for int32_t
+#include <map>  // for operator!=
+#include <memory>  // for shared_ptr, __shared_ptr_access, allocator
+#include <ostream>  // for operator<<, basic_ostream
+#include <string>  // for char_traits, operator<, basic_string, string
+#include <tuple>  // for apply, make_tuple
+#include <type_traits>  // for is_floating_point
+#include <utility>  // for pair, swap
+#include <vector>  // for vector
+
+#include "../common/error_msg.h"  // for GroupWeight, LabelScoreSize
+#include "../common/linalg_op.h"  // for begin, cbegin, cend
+#include "../common/optional_weight.h"  // for MakeOptionalWeights, OptionalWeights
+#include "../common/ranking_utils.h"  // for RankingCache, LambdaRankParam, MAPCache, NDCGC...
+#include "../common/threading_utils.h"  // for ParallelFor, Sched
+#include "init_estimation.h"  // for FitIntercept
+#include "xgboost/base.h"  // for bst_group_t, GradientPair, kRtEps, GradientPai...
+#include "xgboost/context.h"  // for Context
+#include "xgboost/data.h"  // for MetaInfo
+#include "xgboost/host_device_vector.h"  // for HostDeviceVector
+#include "xgboost/json.h"  // for Json, get, Value, ToJson, F32Array, FromJson, IsA
+#include "xgboost/linalg.h"  // for Vector, Range, TensorView, VectorView, All
+#include "xgboost/logging.h"  // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_LE, CHE...
+#include "xgboost/objective.h"  // for ObjFunctionReg, XGBOOST_REGISTER_OBJECTIVE
+#include "xgboost/span.h"  // for Span, operator!=
+#include "xgboost/string_view.h"  // for operator<<, StringView
+#include "xgboost/task.h"  // for ObjInfo
 
 namespace xgboost::obj {
 namespace cpu_impl {
@@ -115,9 +115,8 @@ class LambdaRankObj : public FitIntercept {
     // This function doesn't have a sycl-specific implementation yet.
     // For that reason we transfer data to the host when sycl is used, for proper execution.
     auto device = ctx_->Device().IsSycl() ? DeviceOrd::CPU() : ctx_->Device();
-    cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(device),
-                                           lj_full_.View(device), &ti_plus_, &tj_minus_,
-                                           &li_, &lj_, p_cache_);
+    cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(device), lj_full_.View(device),
+                                           &ti_plus_, &tj_minus_, &li_, &lj_, p_cache_);
   }
 
   li_full_.Data()->Fill(0.0);
@@ -163,7 +162,7 @@ class LambdaRankObj : public FitIntercept {
   }
 
   // Calculate lambda gradient for each group on CPU.
-  template <bool unbiased, typename Delta>
+  template <bool unbiased, bool norm_by_diff, typename Delta>
   void CalcLambdaForGroup(std::int32_t iter, common::Span<float const> g_predt,
                           linalg::VectorView<float const> g_label, float w,
                           common::Span<std::size_t const> g_rank, bst_group_t g, Delta delta,
@@ -180,7 +179,9 @@ class LambdaRankObj : public FitIntercept {
     // https://github.com/microsoft/LightGBM/pull/2331#issuecomment-523259298
     double sum_lambda{0.0};
 
-    auto delta_op = [&](auto const&... args) { return delta(args..., g); };
+    auto delta_op = [&](auto const&... args) {
+      return delta(args..., g);
+    };
 
     auto loop = [&](std::size_t i, std::size_t j) {
       // higher/lower on the target ranked list
@@ -193,8 +194,8 @@ class LambdaRankObj : public FitIntercept {
       }
 
       double cost;
-      auto pg = LambdaGrad<unbiased>(g_label, g_predt, g_rank, rank_high, rank_low, delta_op,
-                                     ti_plus, tj_minus, &cost);
+      auto pg = LambdaGrad<unbiased, norm_by_diff>(g_label, g_predt, g_rank, rank_high, rank_low,
+                                                   delta_op, ti_plus, tj_minus, &cost);
       auto ng = Repulse(pg);
 
       std::size_t idx_high = g_rank[rank_high];
@@ -349,7 +350,14 @@ class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
       static_assert(std::is_floating_point_v<decltype(y_high)>);
       return DeltaNDCG<exp_gain>(y_high, y_low, rank_high, rank_low, inv_IDCG(g), discount);
     };
-    this->CalcLambdaForGroup<unbiased>(iter, g_predt, g_label, w, g_rank, g, delta, g_gpair);
+
+    if (this->param_.lambdarank_score_normalization) {
+      this->CalcLambdaForGroup<unbiased, true>(iter, g_predt, g_label, w, g_rank, g, delta,
+                                               g_gpair);
+    } else {
+      this->CalcLambdaForGroup<unbiased, false>(iter, g_predt, g_label, w, g_rank, g, delta,
+                                                g_gpair);
+    }
   }
 
   void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt,
@@ -372,7 +380,9 @@ class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
     auto h_predt = predt.ConstHostSpan();
     auto h_label = info.labels.HostView();
    auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);
-    auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); };
+    auto make_range = [&](bst_group_t g) {
+      return linalg::Range(gptr[g], gptr[g + 1]);
+    };
 
     auto dct = GetCache()->Discount(ctx_);
     auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
@@ -496,7 +506,9 @@ class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> {
     auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
     auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);
 
-    auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); };
+    auto make_range = [&](bst_group_t g) {
+      return linalg::Range(gptr[g], gptr[g + 1]);
+    };
 
     cpu_impl::MAPStat(ctx_, h_label, rank_idx, GetCache());
     auto n_rel = GetCache()->NumRelevant(ctx_);
@@ -528,9 +540,17 @@ class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> {
       auto args = std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g, delta_map, g_gpair);
 
       if (param_.lambdarank_unbiased) {
-        std::apply(&LambdaRankMAP::CalcLambdaForGroup<true, D>, args);
+        if (this->param_.lambdarank_score_normalization) {
+          std::apply(&LambdaRankMAP::CalcLambdaForGroup<true, true, D>, args);
+        } else {
+          std::apply(&LambdaRankMAP::CalcLambdaForGroup<true, false, D>, args);
+        }
       } else {
-        std::apply(&LambdaRankMAP::CalcLambdaForGroup<false, D>, args);
+        if (this->param_.lambdarank_score_normalization) {
+          std::apply(&LambdaRankMAP::CalcLambdaForGroup<false, true, D>, args);
+        } else {
+          std::apply(&LambdaRankMAP::CalcLambdaForGroup<false, false, D>, args);
+        }
       }
     });
   }
@@ -583,10 +603,14 @@ class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::Ranking
     auto h_predt = predt.ConstHostSpan();
     auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);
 
-    auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); };
+    auto make_range = [&](bst_group_t g) {
+      return linalg::Range(gptr[g], gptr[g + 1]);
+    };
     auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
 
-    auto delta = [](auto...) { return 1.0; };
+    auto delta = [](auto...) {
+      return 1.0;
+    };
     using D = decltype(delta);
 
     common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) {
@@ -599,9 +623,17 @@ class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::Ranking
 
       auto args = std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g, delta, g_gpair);
       if (param_.lambdarank_unbiased) {
-        std::apply(&LambdaRankPairwise::CalcLambdaForGroup<true, D>, args);
+        if (this->param_.lambdarank_score_normalization) {
+          std::apply(&LambdaRankPairwise::CalcLambdaForGroup<true, true, D>, args);
+        } else {
+          std::apply(&LambdaRankPairwise::CalcLambdaForGroup<true, false, D>, args);
+        }
       } else {
-        std::apply(&LambdaRankPairwise::CalcLambdaForGroup<false, D>, args);
+        if (this->param_.lambdarank_score_normalization) {
+          std::apply(&LambdaRankPairwise::CalcLambdaForGroup<false, true, D>, args);
+        } else {
+          std::apply(&LambdaRankPairwise::CalcLambdaForGroup<false, false, D>, args);
+        }
       }
    });
  }
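
The nested if/else blocks exist because `lambdarank_score_normalization`, like `lambdarank_unbiased`, becomes a compile-time template parameter: each flag combination selects a distinct `CalcLambdaForGroup` instantiation, so the per-pair inner loop carries no runtime branches. A rough Python analogue of this decide-once-then-loop pattern (a sketch, not XGBoost code):

    import math
    from typing import Callable, Iterable, Tuple

    PairLoop = Callable[[Iterable[Tuple[float, float, float]]], float]

    def make_pair_loop(norm_by_diff: bool) -> PairLoop:
        # Pick the specialization once, outside the hot loop, the way the C++
        # code selects a template instantiation per group.
        if norm_by_diff:
            def scale(delta: float, s_high: float, s_low: float) -> float:
                # The 0.01 guard against near-identical scores is illustrative.
                return delta / (0.01 + abs(s_high - s_low))
        else:
            def scale(delta: float, s_high: float, s_low: float) -> float:
                return delta

        def loop(pairs: Iterable[Tuple[float, float, float]]) -> float:
            sum_lambda = 0.0
            for s_high, s_low, delta_metric in pairs:
                rho = 1.0 / (1.0 + math.exp(s_high - s_low))  # RankNet sigmoid
                sum_lambda += scale(delta_metric, s_high, s_low) * rho
            return sum_lambda

        return loop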
