33 *
44 * \brief CUDA implementation of lambdarank.
55 */
6+ #include < dmlc/registry.h> // for DMLC_REGISTRY_FILE_TAG
67#include < thrust/fill.h> // for fill_n
78#include < thrust/for_each.h> // for for_each_n
89#include < thrust/iterator/counting_iterator.h> // for make_counting_iterator
910#include < thrust/iterator/zip_iterator.h> // for make_zip_iterator
1011#include < thrust/tuple.h> // for make_tuple, tuple, tie, get
1112
12- #include < algorithm> // for min
13- #include < cassert> // for assert
14- #include < cmath> // for abs, log2, isinf
15- #include < cstddef> // for size_t
16- #include < cstdint> // for int32_t
17- #include < memory> // for shared_ptr
13+ #include < algorithm> // for min
14+ #include < cassert> // for assert
15+ #include < cmath> // for abs, log2, isinf
16+ #include < cstddef> // for size_t
17+ #include < cstdint> // for int32_t
18+ #include < memory> // for shared_ptr
1819#include < utility>
1920
2021#include " ../common/algorithm.cuh" // for SegmentedArgSort
3132#include " xgboost/host_device_vector.h" // for HostDeviceVector
3233#include " xgboost/linalg.h" // for VectorView, Range, Vector
3334#include " xgboost/logging.h"
34- #include " xgboost/span.h" // for Span
35+ #include " xgboost/span.h" // for Span
3536
3637namespace xgboost ::obj {
3738DMLC_REGISTRY_FILE_TAG (lambdarank_obj_cu);
@@ -82,7 +83,7 @@ struct GetGradOp {
8283 MakePairsOp<has_truncation> make_pair;
8384 Delta delta;
8485
85- bool need_update;
86+ bool const need_update;
8687
8788 auto __device__ operator ()(std::size_t idx) -> GradCostNorm {
8889 auto const & args = make_pair.args ;
@@ -95,6 +96,7 @@ struct GetGradOp {
9596 auto g_predt = args.predts .subspan (data_group_begin, n_data);
9697 auto g_gpair = args.gpairs .Slice (linalg::Range (data_group_begin, data_group_begin + n_data));
9798 auto g_rank = args.d_sorted_idx .subspan (data_group_begin, n_data);
99+ auto n_pairs = args.n_pairs ;
98100
99101 auto [i, j] = make_pair (idx, g);
100102
@@ -108,7 +110,9 @@ struct GetGradOp {
108110
109111 double cost{0 };
110112
111- auto delta_op = [&](auto const &... args) { return delta (args..., g); };
113+ auto delta_op = [&](auto const &... args) {
114+ return delta (args..., g);
115+ };
112116 GradientPair pg =
113117 LambdaGrad<unbiased, norm_by_diff>(g_label, g_predt, g_rank, rank_high, rank_low, delta_op,
114118 args.ti_plus , args.tj_minus , &cost);
@@ -118,7 +122,6 @@ struct GetGradOp {
118122
119123 if (need_update) {
120124 // second run, update the gradient
121-
122125 auto ng = Repulse (pg);
123126
124127 auto gr = args.d_roundings (g);
@@ -153,6 +156,7 @@ struct GetGradOp {
153156 }
154157 }
155158 }
159+
156160 return thrust::make_tuple (GradientPair{std::abs (pg.GetGrad ()), std::abs (pg.GetHess ())},
157161 std::abs (cost), -2.0 * static_cast <double >(pg.GetGrad ()));
158162 }
@@ -215,12 +219,12 @@ void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptr<ltr::Ran
215219 auto hess = std::max (lg.GetHess (), rg.GetHess ());
216220 auto cost = std::max (thrust::get<1 >(l), thrust::get<1 >(r));
217221 double sum_lambda = thrust::get<2 >(l) + thrust::get<2 >(r);
218- return thrust::make_tuple (GradientPair{std::abs ( grad), std::abs ( hess) }, cost, sum_lambda);
222+ return thrust::make_tuple (GradientPair{grad, hess}, cost, sum_lambda);
219223 };
220224 auto init = thrust::make_tuple (GradientPair{0 .0f , 0 .0f }, 0.0 , 0.0 );
221225 common::Span<GradCostNorm> d_max_lambdas = p_cache->MaxLambdas <GradCostNorm>(ctx, n_groups);
222226 CHECK_EQ (n_groups * sizeof (GradCostNorm), d_max_lambdas.size_bytes ());
223-
227+ // Reduce by group.
224228 std::size_t bytes;
225229 cub::DeviceSegmentedReduce::Reduce (nullptr , bytes, val_it, d_max_lambdas.data (), n_groups,
226230 d_threads_group_ptr.data (), d_threads_group_ptr.data () + 1 ,
@@ -267,22 +271,35 @@ void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptr<ltr::Ran
267271 */
268272 auto d_weights = common::MakeOptionalWeights (ctx, info.weights_ );
269273 auto w_norm = p_cache->WeightNorm ();
270- auto norm = p_cache->Param ().lambdarank_normalization ;
274+ auto need_norm = p_cache->Param ().lambdarank_normalization ;
275+ auto n_pairs = p_cache->Param ().NumPair ();
276+ bool is_mean = p_cache->Param ().IsMean ();
277+ CHECK_EQ (is_mean, !has_truncation);
271278 thrust::for_each_n (ctx->CUDACtx ()->CTP (), thrust::make_counting_iterator (0ul ), d_gpair.Size (),
272279 [=] XGBOOST_DEVICE (std::size_t i) mutable {
273280 auto g = dh::SegmentId (d_gptr, i);
274- auto sum_lambda = thrust::get<2 >(d_max_lambdas[g]);
275- // Normalization
276- if (sum_lambda > 0.0 && norm) {
277- double norm = std::log2 (1.0 + sum_lambda) / sum_lambda;
281+ if (need_norm) {
282+ double norm = 1.0 ;
283+ if (has_truncation) {
284+ // Normalize using gradient for top-k.
285+ auto sum_lambda = thrust::get<2 >(d_max_lambdas[g]);
286+ if (sum_lambda > 0.0 ) {
287+ norm = std::log2 (1.0 + sum_lambda) / sum_lambda;
288+ }
289+ } else {
290+ // Normalize using the number of pairs for mean.
291+ double scale = 1.0 / static_cast <double >(n_pairs);
292+ norm = scale;
293+ }
278294 d_gpair (i, 0 ) *= norm;
279295 }
296+
280297 d_gpair (i, 0 ) *= (d_weights[g] * w_norm);
281298 });
282299}
283300
284301/* *
285- * \ brief Handles boilerplate code like getting device span .
302+ * @brief Handles boilerplate code like getting device spans.
286303 */
287304template <bool norm_by_diff, typename Delta>
288305void Launch (Context const * ctx, std::int32_t iter, HostDeviceVector<float > const & preds,
@@ -302,7 +319,6 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector<float> const
302319 out_gpair->Reshape (preds.Size (), 1 );
303320
304321 CHECK (p_cache);
305-
306322 auto d_rounding = p_cache->CUDARounding (ctx);
307323 auto d_cost_rounding = p_cache->CUDACostRounding (ctx);
308324
@@ -325,9 +341,10 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector<float> const
325341 d_y_sorted_idx = SortY (ctx, info, rank_idx, p_cache);
326342 }
327343
328- KernelInputs args{ti_plus, tj_minus, li, lj, d_gptr, d_threads_group_ptr,
329- rank_idx, label, predts, gpairs, d_rounding, d_cost_rounding.data (),
330- d_y_sorted_idx, iter};
344+ auto n_pairs = p_cache->Param ().NumPair ();
345+ KernelInputs args{ti_plus, tj_minus, li, lj, d_gptr, d_threads_group_ptr,
346+ rank_idx, label, predts, gpairs, d_rounding, d_cost_rounding.data (),
347+ n_pairs, d_y_sorted_idx, iter};
331348
332349 // dispatch based on unbiased and truncation
333350 if (p_cache->Param ().HasTruncation ()) {
0 commit comments