@@ -1,40 +1,40 @@
 /**
- * Copyright 2023-2024, XGBoost contributors
+ * Copyright 2023-2025, XGBoost contributors
  */
 #include "lambdarank_obj.h"
 
-#include <dmlc/registry.h>  // for DMLC_REGISTRY_FILE_TAG
-
-#include <algorithm>    // for transform, copy, fill_n, min, max
-#include <cmath>        // for pow, log2
-#include <cstddef>      // for size_t
-#include <cstdint>      // for int32_t
-#include <map>          // for operator!=
-#include <memory>       // for shared_ptr, __shared_ptr_access, allocator
-#include <ostream>      // for operator<<, basic_ostream
-#include <string>       // for char_traits, operator<, basic_string, string
-#include <tuple>        // for apply, make_tuple
-#include <type_traits>  // for is_floating_point
-#include <utility>      // for pair, swap
-#include <vector>       // for vector
-
-#include "../common/error_msg.h"         // for GroupWeight, LabelScoreSize
-#include "../common/linalg_op.h"         // for begin, cbegin, cend
-#include "../common/optional_weight.h"   // for MakeOptionalWeights, OptionalWeights
-#include "../common/ranking_utils.h"     // for RankingCache, LambdaRankParam, MAPCache, NDCGC...
-#include "../common/threading_utils.h"   // for ParallelFor, Sched
-#include "init_estimation.h"             // for FitIntercept
-#include "xgboost/base.h"                // for bst_group_t, GradientPair, kRtEps, GradientPai...
-#include "xgboost/context.h"             // for Context
-#include "xgboost/data.h"                // for MetaInfo
-#include "xgboost/host_device_vector.h"  // for HostDeviceVector
-#include "xgboost/json.h"                // for Json, get, Value, ToJson, F32Array, FromJson, IsA
-#include "xgboost/linalg.h"              // for Vector, Range, TensorView, VectorView, All
-#include "xgboost/logging.h"             // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_LE, CHE...
-#include "xgboost/objective.h"           // for ObjFunctionReg, XGBOOST_REGISTER_OBJECTIVE
-#include "xgboost/span.h"                // for Span, operator!=
-#include "xgboost/string_view.h"         // for operator<<, StringView
-#include "xgboost/task.h"                // for ObjInfo
+#include <dmlc/registry.h>  // for DMLC_REGISTRY_FILE_TAG
+
+#include <algorithm>    // for transform, copy, fill_n, min, max
+#include <cmath>        // for pow, log2
+#include <cstddef>      // for size_t
+#include <cstdint>      // for int32_t
+#include <map>          // for operator!=
+#include <memory>       // for shared_ptr, __shared_ptr_access, allocator
+#include <ostream>      // for operator<<, basic_ostream
+#include <string>       // for char_traits, operator<, basic_string, string
+#include <tuple>        // for apply, make_tuple
+#include <type_traits>  // for is_floating_point
+#include <utility>      // for pair, swap
+#include <vector>       // for vector
+
+#include "../common/error_msg.h"         // for GroupWeight, LabelScoreSize
+#include "../common/linalg_op.h"         // for begin, cbegin, cend
+#include "../common/optional_weight.h"   // for MakeOptionalWeights, OptionalWeights
+#include "../common/ranking_utils.h"     // for RankingCache, LambdaRankParam, MAPCache, NDCGC...
+#include "../common/threading_utils.h"   // for ParallelFor, Sched
+#include "init_estimation.h"             // for FitIntercept
+#include "xgboost/base.h"                // for bst_group_t, GradientPair, kRtEps, GradientPai...
+#include "xgboost/context.h"             // for Context
+#include "xgboost/data.h"                // for MetaInfo
+#include "xgboost/host_device_vector.h"  // for HostDeviceVector
+#include "xgboost/json.h"                // for Json, get, Value, ToJson, F32Array, FromJson, IsA
+#include "xgboost/linalg.h"              // for Vector, Range, TensorView, VectorView, All
+#include "xgboost/logging.h"             // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_LE, CHE...
+#include "xgboost/objective.h"           // for ObjFunctionReg, XGBOOST_REGISTER_OBJECTIVE
+#include "xgboost/span.h"                // for Span, operator!=
+#include "xgboost/string_view.h"         // for operator<<, StringView
+#include "xgboost/task.h"                // for ObjInfo
 
 namespace xgboost::obj {
 namespace cpu_impl {
@@ -115,9 +115,8 @@ class LambdaRankObj : public FitIntercept {
       // This function doesn't have a sycl-specific implementation yet.
       // For that reason we transfer data to the host when sycl is used, for proper execution.
       auto device = ctx_->Device().IsSycl() ? DeviceOrd::CPU() : ctx_->Device();
-      cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(device),
-                                             lj_full_.View(device), &ti_plus_, &tj_minus_,
-                                             &li_, &lj_, p_cache_);
+      cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(device), lj_full_.View(device),
+                                             &ti_plus_, &tj_minus_, &li_, &lj_, p_cache_);
     }
 
     li_full_.Data()->Fill(0.0);
@@ -163,7 +162,7 @@ class LambdaRankObj : public FitIntercept {
   }
 
   // Calculate lambda gradient for each group on CPU.
-  template <bool unbiased, typename Delta>
+  template <bool unbiased, bool norm_by_diff, typename Delta>
   void CalcLambdaForGroup(std::int32_t iter, common::Span<float const> g_predt,
                           linalg::VectorView<float const> g_label, float w,
                           common::Span<std::size_t const> g_rank, bst_group_t g, Delta delta,
@@ -180,7 +179,9 @@ class LambdaRankObj : public FitIntercept {
     // https://github.com/microsoft/LightGBM/pull/2331#issuecomment-523259298
     double sum_lambda{0.0};
 
-    auto delta_op = [&](auto const&... args) { return delta(args..., g); };
+    auto delta_op = [&](auto const&... args) {
+      return delta(args..., g);
+    };
 
     auto loop = [&](std::size_t i, std::size_t j) {
       // higher/lower on the target ranked list
@@ -193,8 +194,8 @@ class LambdaRankObj : public FitIntercept {
       }
 
       double cost;
-      auto pg = LambdaGrad<unbiased>(g_label, g_predt, g_rank, rank_high, rank_low, delta_op,
-                                     ti_plus, tj_minus, &cost);
+      auto pg = LambdaGrad<unbiased, norm_by_diff>(g_label, g_predt, g_rank, rank_high, rank_low,
+                                                   delta_op, ti_plus, tj_minus, &cost);
       auto ng = Repulse(pg);
 
       std::size_t idx_high = g_rank[rank_high];
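
Note: the new `norm_by_diff` flag reaches `LambdaGrad` as a compile-time switch, but the commit does not show `LambdaGrad` itself. As a rough sketch, assuming the normalization mirrors LightGBM's `norm` option (the feature this follows), the flag would damp each pair's |Δmetric| weight by the current score gap:

```cpp
#include <cmath>  // for std::abs

// Hypothetical sketch, not the actual XGBoost implementation: with
// norm_by_diff enabled, the pairwise weight |delta metric| is divided by the
// current score gap, so pairs the model already separates widely contribute
// smaller gradients. LightGBM guards the division with a 0.01 offset.
double PairWeight(double delta_metric, double s_high, double s_low, bool norm_by_diff) {
  double w = std::abs(delta_metric);
  if (norm_by_diff) {
    w /= (0.01 + std::abs(s_high - s_low));
  }
  return w;
}
```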
@@ -349,7 +350,14 @@ class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
       static_assert(std::is_floating_point_v<decltype(y_high)>);
       return DeltaNDCG<exp_gain>(y_high, y_low, rank_high, rank_low, inv_IDCG(g), discount);
     };
-    this->CalcLambdaForGroup<unbiased>(iter, g_predt, g_label, w, g_rank, g, delta, g_gpair);
+
+    if (this->param_.lambdarank_score_normalization) {
+      this->CalcLambdaForGroup<unbiased, true>(iter, g_predt, g_label, w, g_rank, g, delta,
+                                               g_gpair);
+    } else {
+      this->CalcLambdaForGroup<unbiased, false>(iter, g_predt, g_label, w, g_rank, g, delta,
+                                                g_gpair);
+    }
   }
 
   void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt,
@@ -372,7 +380,9 @@
     auto h_predt = predt.ConstHostSpan();
     auto h_label = info.labels.HostView();
     auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);
-    auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); };
+    auto make_range = [&](bst_group_t g) {
+      return linalg::Range(gptr[g], gptr[g + 1]);
+    };
 
     auto dct = GetCache()->Discount(ctx_);
     auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
@@ -496,7 +506,9 @@ class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> {
     auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
     auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);
 
-    auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); };
+    auto make_range = [&](bst_group_t g) {
+      return linalg::Range(gptr[g], gptr[g + 1]);
+    };
 
     cpu_impl::MAPStat(ctx_, h_label, rank_idx, GetCache());
     auto n_rel = GetCache()->NumRelevant(ctx_);
@@ -528,9 +540,17 @@
       auto args = std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g, delta_map, g_gpair);
 
       if (param_.lambdarank_unbiased) {
-        std::apply(&LambdaRankMAP::CalcLambdaForGroup<true, D>, args);
+        if (this->param_.lambdarank_score_normalization) {
+          std::apply(&LambdaRankMAP::CalcLambdaForGroup<true, true, D>, args);
+        } else {
+          std::apply(&LambdaRankMAP::CalcLambdaForGroup<true, false, D>, args);
+        }
       } else {
-        std::apply(&LambdaRankMAP::CalcLambdaForGroup<false, D>, args);
+        if (this->param_.lambdarank_score_normalization) {
+          std::apply(&LambdaRankMAP::CalcLambdaForGroup<false, true, D>, args);
+        } else {
+          std::apply(&LambdaRankMAP::CalcLambdaForGroup<false, false, D>, args);
+        }
       }
     });
   }
@@ -583,10 +603,14 @@ class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::Ranking
     auto h_predt = predt.ConstHostSpan();
     auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);
 
-    auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); };
+    auto make_range = [&](bst_group_t g) {
+      return linalg::Range(gptr[g], gptr[g + 1]);
+    };
     auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
 
-    auto delta = [](auto...) { return 1.0; };
+    auto delta = [](auto...) {
+      return 1.0;
+    };
     using D = decltype(delta);
 
     common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) {
@@ -599,9 +623,17 @@ class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::Ranking
 
       auto args = std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g, delta, g_gpair);
       if (param_.lambdarank_unbiased) {
-        std::apply(&LambdaRankPairwise::CalcLambdaForGroup<true, D>, args);
+        if (this->param_.lambdarank_score_normalization) {
+          std::apply(&LambdaRankPairwise::CalcLambdaForGroup<true, true, D>, args);
+        } else {
+          std::apply(&LambdaRankPairwise::CalcLambdaForGroup<true, false, D>, args);
+        }
       } else {
-        std::apply(&LambdaRankPairwise::CalcLambdaForGroup<false, D>, args);
+        if (this->param_.lambdarank_score_normalization) {
+          std::apply(&LambdaRankPairwise::CalcLambdaForGroup<false, true, D>, args);
+        } else {
+          std::apply(&LambdaRankPairwise::CalcLambdaForGroup<false, false, D>, args);
+        }
       }
     });
   }
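
`lambdarank_score_normalization` is read at runtime, while `norm_by_diff` is a template parameter, so each call site branches once to pick an instantiation. A minimal standalone sketch of this bool-to-template dispatch pattern (names here are illustrative, not XGBoost's):

```cpp
#include <cstdio>

// Stand-in for CalcLambdaForGroup: both flags are compile-time constants,
// so the per-pair inner loop compiles without runtime branching on them.
template <bool unbiased, bool norm_by_diff>
void Calc() {
  std::printf("unbiased=%d norm_by_diff=%d\n", unbiased, norm_by_diff);
}

// Lift two runtime bools into template arguments, one branch per flag.
void Dispatch(bool unbiased, bool norm) {
  if (unbiased) {
    if (norm) { Calc<true, true>(); } else { Calc<true, false>(); }
  } else {
    if (norm) { Calc<false, true>(); } else { Calc<false, false>(); }
  }
}
```

The trade-off is four instantiations of `CalcLambdaForGroup` per objective, which is why the same if/else ladder repeats in the NDCG, MAP, and pairwise objectives above.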