@@ -1,40 +1,40 @@
 /**
- * Copyright 2023-2024, XGBoost contributors
+ * Copyright 2023-2025, XGBoost contributors
  */
 #include "lambdarank_obj.h"
 
-#include <dmlc/registry.h>               // for DMLC_REGISTRY_FILE_TAG
-
-#include <algorithm>                     // for transform, copy, fill_n, min, max
-#include <cmath>                         // for pow, log2
-#include <cstddef>                       // for size_t
-#include <cstdint>                       // for int32_t
-#include <map>                           // for operator!=
-#include <memory>                        // for shared_ptr, __shared_ptr_access, allocator
-#include <ostream>                       // for operator<<, basic_ostream
-#include <string>                        // for char_traits, operator<, basic_string, string
-#include <tuple>                         // for apply, make_tuple
-#include <type_traits>                   // for is_floating_point
-#include <utility>                       // for pair, swap
-#include <vector>                        // for vector
-
-#include "../common/error_msg.h"         // for GroupWeight, LabelScoreSize
-#include "../common/linalg_op.h"         // for begin, cbegin, cend
-#include "../common/optional_weight.h"   // for MakeOptionalWeights, OptionalWeights
-#include "../common/ranking_utils.h"     // for RankingCache, LambdaRankParam, MAPCache, NDCGC...
-#include "../common/threading_utils.h"   // for ParallelFor, Sched
-#include "init_estimation.h"             // for FitIntercept
-#include "xgboost/base.h"                // for bst_group_t, GradientPair, kRtEps, GradientPai...
-#include "xgboost/context.h"             // for Context
-#include "xgboost/data.h"                // for MetaInfo
-#include "xgboost/host_device_vector.h"  // for HostDeviceVector
-#include "xgboost/json.h"                // for Json, get, Value, ToJson, F32Array, FromJson, IsA
-#include "xgboost/linalg.h"              // for Vector, Range, TensorView, VectorView, All
-#include "xgboost/logging.h"             // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_LE, CHE...
-#include "xgboost/objective.h"           // for ObjFunctionReg, XGBOOST_REGISTER_OBJECTIVE
-#include "xgboost/span.h"                // for Span, operator!=
-#include "xgboost/string_view.h"         // for operator<<, StringView
-#include "xgboost/task.h"                // for ObjInfo
+#include <dmlc/registry.h>               // for DMLC_REGISTRY_FILE_TAG
+
+#include <algorithm>                     // for transform, copy, fill_n, min, max
+#include <cmath>                         // for pow, log2
+#include <cstddef>                       // for size_t
+#include <cstdint>                       // for int32_t
+#include <map>                           // for operator!=
+#include <memory>                        // for shared_ptr, __shared_ptr_access, allocator
+#include <ostream>                       // for operator<<, basic_ostream
+#include <string>                        // for char_traits, operator<, basic_string, string
+#include <tuple>                         // for apply, make_tuple
+#include <type_traits>                   // for is_floating_point
+#include <utility>                       // for pair, swap
+#include <vector>                        // for vector
+
+#include "../common/error_msg.h"         // for GroupWeight, LabelScoreSize
+#include "../common/linalg_op.h"         // for begin, cbegin, cend
+#include "../common/optional_weight.h"   // for MakeOptionalWeights, OptionalWeights
+#include "../common/ranking_utils.h"     // for RankingCache, LambdaRankParam, MAPCache, NDCGC...
+#include "../common/threading_utils.h"   // for ParallelFor, Sched
+#include "init_estimation.h"             // for FitIntercept
+#include "xgboost/base.h"                // for bst_group_t, GradientPair, kRtEps, GradientPai...
+#include "xgboost/context.h"             // for Context
+#include "xgboost/data.h"                // for MetaInfo
+#include "xgboost/host_device_vector.h"  // for HostDeviceVector
+#include "xgboost/json.h"                // for Json, get, Value, ToJson, F32Array, FromJson, IsA
+#include "xgboost/linalg.h"              // for Vector, Range, TensorView, VectorView, All
+#include "xgboost/logging.h"             // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_LE, CHE...
+#include "xgboost/objective.h"           // for ObjFunctionReg, XGBOOST_REGISTER_OBJECTIVE
+#include "xgboost/span.h"                // for Span, operator!=
+#include "xgboost/string_view.h"         // for operator<<, StringView
+#include "xgboost/task.h"                // for ObjInfo
 
 namespace xgboost::obj {
 namespace cpu_impl {
@@ -115,9 +115,8 @@ class LambdaRankObj : public FitIntercept {
       // This function doesn't have a sycl-specific implementation yet.
       // For that reason, we transfer data to host when sycl is used, to ensure proper execution.
       auto device = ctx_->Device().IsSycl() ? DeviceOrd::CPU() : ctx_->Device();
-      cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(device),
-                                             lj_full_.View(device), &ti_plus_, &tj_minus_,
-                                             &li_, &lj_, p_cache_);
+      cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(device), lj_full_.View(device),
+                                             &ti_plus_, &tj_minus_, &li_, &lj_, p_cache_);
     }
 
     li_full_.Data()->Fill(0.0);
@@ -163,7 +162,7 @@ class LambdaRankObj : public FitIntercept {
   }
 
   // Calculate lambda gradient for each group on CPU.
-  template <bool unbiased, typename Delta>
+  template <bool unbiased, bool norm_by_diff, typename Delta>
   void CalcLambdaForGroup(std::int32_t iter, common::Span<float const> g_predt,
                           linalg::VectorView<float const> g_label, float w,
                           common::Span<std::size_t const> g_rank, bst_group_t g, Delta delta,
@@ -180,7 +179,9 @@ class LambdaRankObj : public FitIntercept {
     // https://github.com/microsoft/LightGBM/pull/2331#issuecomment-523259298
     double sum_lambda{0.0};
 
-    auto delta_op = [&](auto const&... args) { return delta(args..., g); };
+    auto delta_op = [&](auto const&... args) {
+      return delta(args..., g);
+    };
 
     auto loop = [&](std::size_t i, std::size_t j) {
       // higher/lower on the target ranked list
@@ -193,8 +194,8 @@ class LambdaRankObj : public FitIntercept {
       }
 
       double cost;
-      auto pg = LambdaGrad<unbiased>(g_label, g_predt, g_rank, rank_high, rank_low, delta_op,
-                                     ti_plus, tj_minus, &cost);
+      auto pg = LambdaGrad<unbiased, norm_by_diff>(g_label, g_predt, g_rank, rank_high, rank_low,
+                                                   delta_op, ti_plus, tj_minus, &cost);
       auto ng = Repulse(pg);
 
       std::size_t idx_high = g_rank[rank_high];
@@ -349,7 +350,14 @@ class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
       static_assert(std::is_floating_point_v<decltype(y_high)>);
       return DeltaNDCG<exp_gain>(y_high, y_low, rank_high, rank_low, inv_IDCG(g), discount);
     };
-    this->CalcLambdaForGroup<unbiased>(iter, g_predt, g_label, w, g_rank, g, delta, g_gpair);
+
+    if (this->param_.lambdarank_score_normalization) {
+      this->CalcLambdaForGroup<unbiased, true>(iter, g_predt, g_label, w, g_rank, g, delta,
+                                               g_gpair);
+    } else {
+      this->CalcLambdaForGroup<unbiased, false>(iter, g_predt, g_label, w, g_rank, g, delta,
+                                                g_gpair);
+    }
   }
 
   void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt,
@@ -372,7 +380,9 @@ class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
     auto h_predt = predt.ConstHostSpan();
     auto h_label = info.labels.HostView();
     auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);
-    auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); };
+    auto make_range = [&](bst_group_t g) {
+      return linalg::Range(gptr[g], gptr[g + 1]);
+    };
 
     auto dct = GetCache()->Discount(ctx_);
     auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
@@ -496,7 +506,9 @@ class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> {
     auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
     auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);
 
-    auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); };
+    auto make_range = [&](bst_group_t g) {
+      return linalg::Range(gptr[g], gptr[g + 1]);
+    };
 
     cpu_impl::MAPStat(ctx_, h_label, rank_idx, GetCache());
     auto n_rel = GetCache()->NumRelevant(ctx_);
@@ -528,9 +540,17 @@ class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> {
       auto args = std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g, delta_map, g_gpair);
 
       if (param_.lambdarank_unbiased) {
-        std::apply(&LambdaRankMAP::CalcLambdaForGroup<true, D>, args);
+        if (this->param_.lambdarank_score_normalization) {
+          std::apply(&LambdaRankMAP::CalcLambdaForGroup<true, true, D>, args);
+        } else {
+          std::apply(&LambdaRankMAP::CalcLambdaForGroup<true, false, D>, args);
+        }
       } else {
-        std::apply(&LambdaRankMAP::CalcLambdaForGroup<false, D>, args);
+        if (this->param_.lambdarank_score_normalization) {
+          std::apply(&LambdaRankMAP::CalcLambdaForGroup<false, true, D>, args);
+        } else {
+          std::apply(&LambdaRankMAP::CalcLambdaForGroup<false, false, D>, args);
+        }
       }
     });
   }
@@ -583,10 +603,14 @@ class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::Ranking
     auto h_predt = predt.ConstHostSpan();
     auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);
 
-    auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); };
+    auto make_range = [&](bst_group_t g) {
+      return linalg::Range(gptr[g], gptr[g + 1]);
+    };
     auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
 
-    auto delta = [](auto...) { return 1.0; };
+    auto delta = [](auto...) {
+      return 1.0;
+    };
     using D = decltype(delta);
 
     common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) {
@@ -599,9 +623,17 @@ class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::Ranking
 
       auto args = std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g, delta, g_gpair);
       if (param_.lambdarank_unbiased) {
-        std::apply(&LambdaRankPairwise::CalcLambdaForGroup<true, D>, args);
+        if (this->param_.lambdarank_score_normalization) {
+          std::apply(&LambdaRankPairwise::CalcLambdaForGroup<true, true, D>, args);
+        } else {
+          std::apply(&LambdaRankPairwise::CalcLambdaForGroup<true, false, D>, args);
+        }
       } else {
-        std::apply(&LambdaRankPairwise::CalcLambdaForGroup<false, D>, args);
+        if (this->param_.lambdarank_score_normalization) {
+          std::apply(&LambdaRankPairwise::CalcLambdaForGroup<false, true, D>, args);
+        } else {
+          std::apply(&LambdaRankPairwise::CalcLambdaForGroup<false, false, D>, args);
+        }
       }
     });
   }
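
A note on the pattern this diff repeats in all three objectives: the new `lambdarank_score_normalization` option arrives as a runtime boolean but is consumed as the compile-time template parameter `norm_by_diff`, so the per-pair inner loop in `CalcLambdaForGroup` never branches on it. The following is a minimal, self-contained sketch of that dispatch technique; the names `CalcForGroup` and `Dispatch` are illustrative stand-ins, not part of the XGBoost API.

#include <cstdio>
#include <tuple>

// The "hot" kernel is instantiated once per flag combination; with the
// flags baked in as template arguments, `if constexpr` drops the untaken
// branch from each instantiation.
template <bool unbiased, bool norm_by_diff>
void CalcForGroup(int group) {
  if constexpr (norm_by_diff) {
    std::printf("group %d: gradient normalized by score difference\n", group);
  } else {
    std::printf("group %d: gradient left unnormalized\n", group);
  }
}

// Runtime flags are folded into template arguments exactly once, outside
// the per-pair loop, mirroring the std::apply/std::make_tuple dispatch
// used by LambdaRankMAP and LambdaRankPairwise above.
void Dispatch(bool unbiased, bool norm_by_diff, int group) {
  auto args = std::make_tuple(group);
  if (unbiased) {
    norm_by_diff ? std::apply(CalcForGroup<true, true>, args)
                 : std::apply(CalcForGroup<true, false>, args);
  } else {
    norm_by_diff ? std::apply(CalcForGroup<false, true>, args)
                 : std::apply(CalcForGroup<false, false>, args);
  }
}

int main() {
  Dispatch(/*unbiased=*/false, /*norm_by_diff=*/true, /*group=*/0);
  return 0;
}

The trade-off is code size for speed: four instantiations of the group kernel instead of one, in exchange for a branch-free loop over document pairs.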