@@ -3,18 +3,19 @@
  *
  * \brief CUDA implementation of lambdarank.
  */
+#include <dmlc/registry.h>                      // for DMLC_REGISTRY_FILE_TAG
 #include <thrust/fill.h>                        // for fill_n
 #include <thrust/for_each.h>                    // for for_each_n
 #include <thrust/iterator/counting_iterator.h>  // for make_counting_iterator
 #include <thrust/iterator/zip_iterator.h>       // for make_zip_iterator
 #include <thrust/tuple.h>                       // for make_tuple, tuple, tie, get

-#include <algorithm>  // for min
-#include <cassert>    // for assert
-#include <cmath>      // for abs, log2, isinf
-#include <cstddef>    // for size_t
-#include <cstdint>    // for int32_t
-#include <memory>     // for shared_ptr
+#include <algorithm>                            // for min
+#include <cassert>                              // for assert
+#include <cmath>                                // for abs, log2, isinf
+#include <cstddef>                              // for size_t
+#include <cstdint>                              // for int32_t
+#include <memory>                               // for shared_ptr
 #include <utility>

 #include "../common/algorithm.cuh"       // for SegmentedArgSort
@@ -31,7 +32,7 @@
 #include "xgboost/host_device_vector.h"  // for HostDeviceVector
 #include "xgboost/linalg.h"              // for VectorView, Range, Vector
 #include "xgboost/logging.h"
-#include "xgboost/span.h"  // for Span
+#include "xgboost/span.h"                // for Span

 namespace xgboost::obj {
 DMLC_REGISTRY_FILE_TAG(lambdarank_obj_cu);
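
Note on the new include: `<dmlc/registry.h>` provides the `DMLC_REGISTRY_FILE_TAG` macro used just above. As a rough sketch of how dmlc's tagging convention is typically consumed (an illustration, not part of this diff), some always-linked translation unit references the tag so the linker cannot discard this object file and the registrations it carries:

```cpp
// Hypothetical always-linked .cc file (illustrative only).
#include <dmlc/registry.h>

namespace xgboost::obj {
// Pulls in the translation unit that declared DMLC_REGISTRY_FILE_TAG(lambdarank_obj_cu).
DMLC_REGISTRY_LINK_TAG(lambdarank_obj_cu);
}  // namespace xgboost::obj
```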
@@ -82,7 +83,7 @@ struct GetGradOp {
   MakePairsOp<has_truncation> make_pair;
   Delta delta;

-  bool need_update;
+  bool const need_update;

   auto __device__ operator()(std::size_t idx) -> GradCostNorm {
     auto const& args = make_pair.args;
@@ -95,6 +96,7 @@ struct GetGradOp {
     auto g_predt = args.predts.subspan(data_group_begin, n_data);
     auto g_gpair = args.gpairs.Slice(linalg::Range(data_group_begin, data_group_begin + n_data));
     auto g_rank = args.d_sorted_idx.subspan(data_group_begin, n_data);
+    auto n_pairs = args.n_pairs;

     auto [i, j] = make_pair(idx, g);

@@ -108,7 +110,9 @@ struct GetGradOp {

     double cost{0};

-    auto delta_op = [&](auto const&... args) { return delta(args..., g); };
+    auto delta_op = [&](auto const&... args) {
+      return delta(args..., g);
+    };
     GradientPair pg =
         LambdaGrad<unbiased, norm_by_diff>(g_label, g_predt, g_rank, rank_high, rank_low, delta_op,
                                            args.ti_plus, args.tj_minus, &cost);
@@ -118,7 +122,6 @@ struct GetGradOp {

     if (need_update) {
       // second run, update the gradient
-
       auto ng = Repulse(pg);

       auto gr = args.d_roundings(g);
@@ -153,6 +156,7 @@ struct GetGradOp {
         }
       }
     }
+
     return thrust::make_tuple(GradientPair{std::abs(pg.GetGrad()), std::abs(pg.GetHess())},
                               std::abs(cost), -2.0 * static_cast<double>(pg.GetGrad()));
   }
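
For orientation in the hunks above: each thread handles one candidate pair and returns a `GradCostNorm` triple. Inferring from the `make_tuple` call and the `thrust::get<>` accesses in the reduction below, the type is presumably shaped like this (a sketch, not the verbatim definition):

```cpp
#include <thrust/tuple.h>

#include "xgboost/base.h"  // for GradientPair

// get<0>: |gradient| / |hessian| magnitudes, max-reduced per group, presumably
//         to size the fixed-point rounding factors used on the second pass.
// get<1>: |cost|, max-reduced for the cost rounding factor.
// get<2>: -2 * gradient, sum-reduced into sum_lambda for normalization.
using GradCostNorm = thrust::tuple<xgboost::GradientPair, double, double>;
```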
@@ -215,12 +219,12 @@ void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptr<ltr::Ran
     auto hess = std::max(lg.GetHess(), rg.GetHess());
     auto cost = std::max(thrust::get<1>(l), thrust::get<1>(r));
     double sum_lambda = thrust::get<2>(l) + thrust::get<2>(r);
-    return thrust::make_tuple(GradientPair{std::abs(grad), std::abs(hess)}, cost, sum_lambda);
+    return thrust::make_tuple(GradientPair{grad, hess}, cost, sum_lambda);
   };
   auto init = thrust::make_tuple(GradientPair{0.0f, 0.0f}, 0.0, 0.0);
   common::Span<GradCostNorm> d_max_lambdas = p_cache->MaxLambdas<GradCostNorm>(ctx, n_groups);
   CHECK_EQ(n_groups * sizeof(GradCostNorm), d_max_lambdas.size_bytes());
-
+  // Reduce by group.
   std::size_t bytes;
   cub::DeviceSegmentedReduce::Reduce(nullptr, bytes, val_it, d_max_lambdas.data(), n_groups,
                                      d_threads_group_ptr.data(), d_threads_group_ptr.data() + 1,
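
The `Reduce` call above follows CUB's two-phase convention: the first invocation with a null temp-storage pointer only computes `bytes`; after allocating, the identical call performs the segmented reduction over CSR-style offsets. A self-contained sketch of that pattern (the names and the max-reduction here are illustrative, not the code above):

```cpp
#include <cub/cub.cuh>
#include <thrust/device_vector.h>

#include <cstddef>
#include <limits>

// Per-segment max over `values`, with segments given by CSR-style offsets.
void SegmentedMax(thrust::device_vector<float> const& values,
                  thrust::device_vector<int> const& seg_ptr,
                  thrust::device_vector<float>* out) {
  int n_segments = static_cast<int>(seg_ptr.size()) - 1;
  std::size_t bytes = 0;
  float init = -std::numeric_limits<float>::infinity();
  // Phase 1: null temp storage; only computes `bytes`.
  cub::DeviceSegmentedReduce::Reduce(nullptr, bytes, values.data().get(), out->data().get(),
                                     n_segments, seg_ptr.data().get(), seg_ptr.data().get() + 1,
                                     cub::Max{}, init);
  thrust::device_vector<char> tmp(bytes);
  // Phase 2: identical arguments, now with real scratch space.
  cub::DeviceSegmentedReduce::Reduce(tmp.data().get(), bytes, values.data().get(),
                                     out->data().get(), n_segments, seg_ptr.data().get(),
                                     seg_ptr.data().get() + 1, cub::Max{}, init);
}
```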
@@ -267,22 +271,35 @@ void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptr<ltr::Ran
   */
   auto d_weights = common::MakeOptionalWeights(ctx, info.weights_);
   auto w_norm = p_cache->WeightNorm();
-  auto norm = p_cache->Param().lambdarank_normalization;
+  auto need_norm = p_cache->Param().lambdarank_normalization;
+  auto n_pairs = p_cache->Param().NumPair();
+  bool is_mean = p_cache->Param().IsMean();
+  CHECK_EQ(is_mean, !has_truncation);
   thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_gpair.Size(),
                      [=] XGBOOST_DEVICE(std::size_t i) mutable {
                        auto g = dh::SegmentId(d_gptr, i);
-                       auto sum_lambda = thrust::get<2>(d_max_lambdas[g]);
-                       // Normalization
-                       if (sum_lambda > 0.0 && norm) {
-                         double norm = std::log2(1.0 + sum_lambda) / sum_lambda;
+                       if (need_norm) {
+                         double norm = 1.0;
+                         if (has_truncation) {
+                           // Normalize using gradient for top-k.
+                           auto sum_lambda = thrust::get<2>(d_max_lambdas[g]);
+                           if (sum_lambda > 0.0) {
+                             norm = std::log2(1.0 + sum_lambda) / sum_lambda;
+                           }
+                         } else {
+                           // Normalize using the number of pairs for mean.
+                           double scale = 1.0 / static_cast<double>(n_pairs);
+                           norm = scale;
+                         }
                          d_gpair(i, 0) *= norm;
                        }
+
                        d_gpair(i, 0) *= (d_weights[g] * w_norm);
                      });
 }

 /**
- * \brief Handles boilerplate code like getting device span.
+ * @brief Handles boilerplate code like getting device spans.
  */
 template <bool norm_by_diff, typename Delta>
 void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector<float> const& preds,
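
Restating the normalization logic added in the device lambda above: when `lambdarank_normalization` is on, the top-k (truncation) path damps each group by `log2(1 + sum_lambda) / sum_lambda`, where `sum_lambda` is the group's accumulated `-2 * gradient` from the reduction, while the mean path simply divides by the configured pair count. A host-side paraphrase (illustrative helper, not part of the patch):

```cpp
#include <cmath>

// Scale applied to each gradient before the weight normalization.
double NormScale(bool has_truncation, double sum_lambda, double n_pairs) {
  if (has_truncation) {
    // Top-k: shrink groups with a large accumulated lambda sum.
    return sum_lambda > 0.0 ? std::log2(1.0 + sum_lambda) / sum_lambda : 1.0;
  }
  // Mean pair method: average over the number of sampled pairs per document.
  return 1.0 / n_pairs;
}
```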
@@ -302,7 +319,6 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector<float> const
   out_gpair->Reshape(preds.Size(), 1);

   CHECK(p_cache);
-
   auto d_rounding = p_cache->CUDARounding(ctx);
   auto d_cost_rounding = p_cache->CUDACostRounding(ctx);

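`CUDARounding`/`CUDACostRounding` supply per-group rounding factors so the pairwise gradient accumulation stays deterministic regardless of thread ordering. The underlying idea, sketched with illustrative names (the real xgboost helper also accounts for the mantissa width of the accumulator type):

```cpp
#include <cmath>
#include <cstddef>

// Pick a power-of-two factor no smaller than the worst-case sum n * max|v|.
template <typename T>
T MakeRoundingFactor(T max_abs, std::size_t n) {
  return std::exp2(std::ceil(std::log2(max_abs * static_cast<T>(n))));
}

// Adding then subtracting the factor snaps v onto a fixed grid, so parallel
// (atomic) additions of snapped values commute bit-exactly.
template <typename T>
__host__ __device__ T TruncateWithRounding(T rounding, T v) {
  return (rounding + v) - rounding;
}
```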
@@ -325,9 +341,10 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector<float> const
     d_y_sorted_idx = SortY(ctx, info, rank_idx, p_cache);
   }

-  KernelInputs args{ti_plus,  tj_minus, li,     lj,         d_gptr,     d_threads_group_ptr,
-                    rank_idx, label,    predts, gpairs,     d_rounding, d_cost_rounding.data(),
-                    d_y_sorted_idx, iter};
+  auto n_pairs = p_cache->Param().NumPair();
+  KernelInputs args{ti_plus,  tj_minus, li,     lj,     d_gptr,     d_threads_group_ptr,
+                    rank_idx, label,    predts, gpairs, d_rounding, d_cost_rounding.data(),
+                    n_pairs,  d_y_sorted_idx, iter};

   // dispatch based on unbiased and truncation
   if (p_cache->Param().HasTruncation()) {
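
The trailing context shows the dispatch: `HasTruncation()` turns a runtime flag into the `has_truncation` template parameter, so `MakePairsOp<true>` (top-k pair enumeration) versus `MakePairsOp<false>` (the mean method, whose invariant the new `CHECK_EQ(is_mean, !has_truncation)` guards) is chosen at compile time. A reduced illustration of the pattern (with a hypothetical `LaunchImpl`; the real code also branches on the unbiased flag):

```cpp
#include <cstdio>

// Runtime bool -> compile-time template parameter, mirroring the dispatch above.
template <bool has_truncation>
void LaunchImpl() {
  std::printf(has_truncation ? "top-k pair enumeration\n" : "mean / sampled pairs\n");
}

void Dispatch(bool has_truncation) {
  if (has_truncation) {
    LaunchImpl<true>();
  } else {
    LaunchImpl<false>();
  }
}
```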