
Commit 210915c (1 parent: c68684f)

Use integer gradients in gpu_hist split evaluation (dmlc#8274)

12 files changed: +224 -292 lines

include/xgboost/base.h

Lines changed: 2 additions & 2 deletions
@@ -264,8 +264,8 @@ using GradientPairPrecise = detail::GradientPairInternal<double>;
  * we don't accidentally use it in gain calculations.*/
 class GradientPairInt64 {
   using T = int64_t;
-  T grad_;
-  T hess_;
+  T grad_ = 0;
+  T hess_ = 0;

  public:
   using ValueT = T;
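
Note on the hunk above: GradientPairInt64 stores gradients as 64-bit fixed-point integers so that histogram sums are exact and associative, and the new default member initializers guarantee a zero-valued pair even when it is default-constructed in device code. A minimal sketch of the idea, with hypothetical names rather than xgboost's actual implementation:

#include <cstdint>

// Toy fixed-point gradient pair: floats are mapped onto int64 by a
// per-batch scale, summed exactly as integers, and converted back to
// floating point only for gain/weight calculations.
struct FixedPointPair {
  int64_t grad = 0;  // zero-initialized, as in the change above
  int64_t hess = 0;
  FixedPointPair operator+(const FixedPointPair &o) const {
    // Integer addition is exact and associative, so any summation order
    // (warp reduce, block scan, atomics) produces identical results.
    return {grad + o.grad, hess + o.hess};
  }
};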

src/tree/gpu_hist/evaluate_splits.cu

Lines changed: 79 additions & 66 deletions
@@ -15,17 +15,20 @@ namespace xgboost {
 namespace tree {

 // With constraints
-XGBOOST_DEVICE float LossChangeMissing(const GradientPairPrecise &scan,
-                                       const GradientPairPrecise &missing,
-                                       const GradientPairPrecise &parent_sum,
+XGBOOST_DEVICE float LossChangeMissing(const GradientPairInt64 &scan,
+                                       const GradientPairInt64 &missing,
+                                       const GradientPairInt64 &parent_sum,
                                        const GPUTrainingParam &param, bst_node_t nidx,
                                        bst_feature_t fidx,
                                        TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
-                                       bool &missing_left_out) {  // NOLINT
+                                       bool &missing_left_out, const GradientQuantiser& quantiser) {  // NOLINT
   const auto left_sum = scan + missing;
-  float missing_left_gain =
-      evaluator.CalcSplitGain(param, nidx, fidx, left_sum, parent_sum - left_sum);
-  float missing_right_gain = evaluator.CalcSplitGain(param, nidx, fidx, scan, parent_sum - scan);
+  float missing_left_gain = evaluator.CalcSplitGain(
+      param, nidx, fidx, quantiser.ToFloatingPoint(left_sum),
+      quantiser.ToFloatingPoint(parent_sum - left_sum));
+  float missing_right_gain = evaluator.CalcSplitGain(
+      param, nidx, fidx, quantiser.ToFloatingPoint(scan),
+      quantiser.ToFloatingPoint(parent_sum - scan));

   missing_left_out = missing_left_gain > missing_right_gain;
   return missing_left_out ? missing_left_gain : missing_right_gain;
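
The hunk above decides the direction for missing values by evaluating the gain twice, once with the missing bucket merged into the left child and once with it on the right, keeping the larger; the quantised sums are converted back to floating point only here, through the new quantiser argument. A scalar sketch of the decision, with a hypothetical gain callable standing in for evaluator.CalcSplitGain after ToFloatingPoint:

// Scalar sketch (hypothetical, simplified to plain floats).
template <typename GainFn>
float BestMissingDirection(float scan, float missing, float parent,
                           bool *missing_left, GainFn gain) {
  float left_gain = gain(scan + missing, parent - (scan + missing));  // missing goes left
  float right_gain = gain(scan, parent - scan);                       // missing goes right
  *missing_left = left_gain > right_gain;
  return *missing_left ? left_gain : right_gain;
}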
@@ -42,9 +45,9 @@ template <int kBlockSize>
 class EvaluateSplitAgent {
  public:
   using ArgMaxT = cub::KeyValuePair<int, float>;
-  using BlockScanT = cub::BlockScan<GradientPairPrecise, kBlockSize>;
+  using BlockScanT = cub::BlockScan<GradientPairInt64, kBlockSize>;
   using MaxReduceT = cub::WarpReduce<ArgMaxT>;
-  using SumReduceT = cub::WarpReduce<GradientPairPrecise>;
+  using SumReduceT = cub::WarpReduce<GradientPairInt64>;

   struct TempStorage {
     typename BlockScanT::TempStorage scan;
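
Switching the cub::BlockScan and cub::WarpReduce value type from GradientPairPrecise to GradientPairInt64 keeps the prefix sums exact and order-independent: floating-point addition is not associative, so a parallel scan's result depends on the shape of its reduction tree, while int64 addition gives the same answer in every order. A small host-side demonstration of the difference:

#include <cstdio>

int main() {
  // Float addition is order dependent...
  float a = 1e20f, b = -1e20f, c = 1.0f;
  std::printf("%g vs %g\n", (a + b) + c, a + (b + c));  // prints 1 vs 0
  // ...while integer addition is exact in any order.
  long long x = 1LL << 60, y = -(1LL << 60), z = 1;
  std::printf("%lld vs %lld\n", (x + y) + z, x + (y + z));  // prints 1 vs 1
  return 0;
}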
@@ -59,67 +62,67 @@ class EvaluateSplitAgent {
   const uint32_t gidx_end;  // end bin for i^th feature
   const dh::LDGIterator<float> feature_values;
   const GradientPairInt64 *node_histogram;
-  const GradientQuantizer &rounding;
-  const GradientPairPrecise parent_sum;
-  const GradientPairPrecise missing;
+  const GradientQuantiser &rounding;
+  const GradientPairInt64 parent_sum;
+  const GradientPairInt64 missing;
   const GPUTrainingParam &param;
   const TreeEvaluator::SplitEvaluator<GPUTrainingParam> &evaluator;
   TempStorage *temp_storage;
-  SumCallbackOp<GradientPairPrecise> prefix_op;
+  SumCallbackOp<GradientPairInt64> prefix_op;
   static float constexpr kNullGain = -std::numeric_limits<bst_float>::infinity();

-  __device__ EvaluateSplitAgent(TempStorage *temp_storage, int fidx,
-                                const EvaluateSplitInputs &inputs,
-                                const EvaluateSplitSharedInputs &shared_inputs,
-                                const TreeEvaluator::SplitEvaluator<GPUTrainingParam> &evaluator)
-      : temp_storage(temp_storage),
-        nidx(inputs.nidx),
-        fidx(fidx),
+  __device__ EvaluateSplitAgent(
+      TempStorage *temp_storage, int fidx, const EvaluateSplitInputs &inputs,
+      const EvaluateSplitSharedInputs &shared_inputs,
+      const TreeEvaluator::SplitEvaluator<GPUTrainingParam> &evaluator)
+      : temp_storage(temp_storage), nidx(inputs.nidx), fidx(fidx),
         min_fvalue(__ldg(shared_inputs.min_fvalue.data() + fidx)),
         gidx_begin(__ldg(shared_inputs.feature_segments.data() + fidx)),
         gidx_end(__ldg(shared_inputs.feature_segments.data() + fidx + 1)),
         feature_values(shared_inputs.feature_values.data()),
         node_histogram(inputs.gradient_histogram.data()),
         rounding(shared_inputs.rounding),
-        parent_sum(dh::LDGIterator<GradientPairPrecise>(&inputs.parent_sum)[0]),
-        param(shared_inputs.param),
-        evaluator(evaluator),
+        parent_sum(dh::LDGIterator<GradientPairInt64>(&inputs.parent_sum)[0]),
+        param(shared_inputs.param), evaluator(evaluator),
         missing(parent_sum - ReduceFeature()) {
-    static_assert(kBlockSize == 32,
-                  "This kernel relies on the assumption block_size == warp_size");
+    static_assert(
+        kBlockSize == 32,
+        "This kernel relies on the assumption block_size == warp_size");
+    // There should be no missing value gradients for a dense matrix
+    KERNEL_CHECK(!shared_inputs.is_dense || missing.GetQuantisedHess() == 0);
   }
-  __device__ GradientPairPrecise ReduceFeature() {
-    GradientPairPrecise local_sum;
-    for (int idx = gidx_begin + threadIdx.x; idx < gidx_end; idx += kBlockSize) {
+  __device__ GradientPairInt64 ReduceFeature() {
+    GradientPairInt64 local_sum;
+    for (int idx = gidx_begin + threadIdx.x; idx < gidx_end;
+         idx += kBlockSize) {
       local_sum += LoadGpair(node_histogram + idx);
     }
     local_sum = SumReduceT(temp_storage->sum_reduce).Sum(local_sum);
     // Broadcast result from thread 0
-    return {__shfl_sync(0xffffffff, local_sum.GetGrad(), 0),
-            __shfl_sync(0xffffffff, local_sum.GetHess(), 0)};
+    return {__shfl_sync(0xffffffff, local_sum.GetQuantisedGrad(), 0),
+            __shfl_sync(0xffffffff, local_sum.GetQuantisedHess(), 0)};
   }

   // Load using efficient 128 vector load instruction
-  __device__ __forceinline__ GradientPairPrecise LoadGpair(const GradientPairInt64 *ptr) {
+  __device__ __forceinline__ GradientPairInt64 LoadGpair(const GradientPairInt64 *ptr) {
     float4 tmp = *reinterpret_cast<const float4 *>(ptr);
-    auto gpair_int = *reinterpret_cast<const GradientPairInt64 *>(&tmp);
-    static_assert(sizeof(decltype(gpair_int)) == sizeof(float4),
+    auto gpair = *reinterpret_cast<const GradientPairInt64 *>(&tmp);
+    static_assert(sizeof(decltype(gpair)) == sizeof(float4),
                   "Vector type size does not match gradient pair size.");
-    return rounding.ToFloatingPoint(gpair_int);
+    return gpair;
   }

   __device__ __forceinline__ void Numerical(DeviceSplitCandidate *__restrict__ best_split) {
     for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += kBlockSize) {
       bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
-      GradientPairPrecise bin = thread_active ? LoadGpair(node_histogram + scan_begin + threadIdx.x)
-                                              : GradientPairPrecise();
+      GradientPairInt64 bin = thread_active ? LoadGpair(node_histogram + scan_begin + threadIdx.x)
                                             : GradientPairInt64();
       BlockScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
       // Whether the gradient of missing values is put to the left side.
       bool missing_left = true;
       float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
-                                                     evaluator, missing_left)
+                                                     evaluator, missing_left, rounding)
                                  : kNullGain;
-
       // Find thread with best gain
       auto best = MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
       // This reduce result is only valid in thread 0
@@ -132,10 +135,10 @@ class EvaluateSplitAgent {
         int split_gidx = (scan_begin + threadIdx.x) - 1;
         float fvalue =
             split_gidx < static_cast<int>(gidx_begin) ? min_fvalue : feature_values[split_gidx];
-        GradientPairPrecise left = missing_left ? bin + missing : bin;
-        GradientPairPrecise right = parent_sum - left;
+        GradientPairInt64 left = missing_left ? bin + missing : bin;
+        GradientPairInt64 right = parent_sum - left;
         best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
-                           false, param);
+                           false, param, rounding);
       }
     }
   }
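
LoadGpair above keeps the float4 reinterpret trick so the 16-byte GradientPairInt64 is fetched with a single 128-bit load, but it now returns the raw integer pair instead of dequantising on load; conversion to floating point is deferred to the gain calculation. A sketch of the load pattern for a .cu file, with an assumed struct layout:

#include <cstdint>

struct alignas(16) PairInt64 { std::int64_t grad, hess; };  // 16 bytes, like GradientPairInt64

// float4 is also 16 bytes, so routing the access through it lets the
// compiler emit one 128-bit load (LDG.128) instead of two 64-bit loads.
// Only the bits are reinterpreted; no numeric conversion happens.
__device__ PairInt64 Load128(const PairInt64 *ptr) {
  float4 tmp = *reinterpret_cast<const float4 *>(ptr);
  return *reinterpret_cast<const PairInt64 *>(&tmp);
}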
@@ -145,12 +148,12 @@ class EvaluateSplitAgent {
       bool thread_active = (scan_begin + threadIdx.x) < gidx_end;

       auto rest = thread_active ? LoadGpair(node_histogram + scan_begin + threadIdx.x)
-                                : GradientPairPrecise();
-      GradientPairPrecise bin = parent_sum - rest - missing;
+                                : GradientPairInt64();
+      GradientPairInt64 bin = parent_sum - rest - missing;
       // Whether the gradient of missing values is put to the left side.
       bool missing_left = true;
       float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
-                                                     evaluator, missing_left)
+                                                     evaluator, missing_left, rounding)
                                  : kNullGain;

       // Find thread with best gain
@@ -162,10 +165,10 @@ class EvaluateSplitAgent {
       if (threadIdx.x == best_thread) {
         int32_t split_gidx = (scan_begin + threadIdx.x);
         float fvalue = feature_values[split_gidx];
-        GradientPairPrecise left = missing_left ? bin + missing : bin;
-        GradientPairPrecise right = parent_sum - left;
+        GradientPairInt64 left = missing_left ? bin + missing : bin;
+        GradientPairInt64 right = parent_sum - left;
         best_split->UpdateCat(gain, missing_left ? kLeftDir : kRightDir,
-                              static_cast<bst_cat_t>(fvalue), fidx, left, right, param);
+                              static_cast<bst_cat_t>(fvalue), fidx, left, right, param, rounding);
       }
     }
   }
@@ -174,11 +177,13 @@ class EvaluateSplitAgent {
    */
   __device__ __forceinline__ void PartitionUpdate(bst_bin_t scan_begin, bool thread_active,
                                                   bool missing_left, bst_bin_t it,
-                                                  GradientPairPrecise const &left_sum,
-                                                  GradientPairPrecise const &right_sum,
+                                                  GradientPairInt64 const &left_sum,
+                                                  GradientPairInt64 const &right_sum,
                                                   DeviceSplitCandidate *__restrict__ best_split) {
-    auto gain =
-        thread_active ? evaluator.CalcSplitGain(param, nidx, fidx, left_sum, right_sum) : kNullGain;
+    auto gain = thread_active
+                    ? evaluator.CalcSplitGain(param, nidx, fidx, rounding.ToFloatingPoint(left_sum),
+                                              rounding.ToFloatingPoint(right_sum))
+                    : kNullGain;

     // Find thread with best gain
     auto best = MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
@@ -191,7 +196,7 @@ class EvaluateSplitAgent {
       // index of best threshold inside a feature.
       auto best_thresh = it - gidx_begin;
       best_split->UpdateCat(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left_sum,
-                            right_sum, param);
+                            right_sum, param, rounding);
     }
   }
   /**
@@ -213,28 +218,28 @@ class EvaluateSplitAgent {
       bool thread_active = it < it_end;

       auto right_sum = thread_active ? LoadGpair(node_histogram + sorted_idx[it] - node_offset)
-                                     : GradientPairPrecise();
+                                     : GradientPairInt64();
       // No min value for cat feature, use inclusive scan.
       BlockScanT(temp_storage->scan).InclusiveSum(right_sum, right_sum, prefix_op);
-      GradientPairPrecise left_sum = parent_sum - right_sum;
+      GradientPairInt64 left_sum = parent_sum - right_sum;

       PartitionUpdate(scan_begin, thread_active, true, it, left_sum, right_sum, best_split);
     }

     // backward
     it_begin = gidx_end - 1;
     it_end = it_begin - n_bins + 1;
-    prefix_op = SumCallbackOp<GradientPairPrecise>{};  // reset
+    prefix_op = SumCallbackOp<GradientPairInt64>{};  // reset

     for (bst_bin_t scan_begin = it_begin; scan_begin > it_end; scan_begin -= kBlockSize) {
       auto it = scan_begin - static_cast<bst_bin_t>(threadIdx.x);
       bool thread_active = it > it_end;

       auto left_sum = thread_active ? LoadGpair(node_histogram + sorted_idx[it] - node_offset)
-                                    : GradientPairPrecise();
+                                    : GradientPairInt64();
       // No min value for cat feature, use inclusive scan.
       BlockScanT(temp_storage->scan).InclusiveSum(left_sum, left_sum, prefix_op);
-      GradientPairPrecise right_sum = parent_sum - left_sum;
+      GradientPairInt64 right_sum = parent_sum - left_sum;

       PartitionUpdate(scan_begin, thread_active, false, it, left_sum, right_sum, best_split);
     }
@@ -399,22 +404,30 @@ void GPUHistEvaluator::EvaluateSplits(
     auto const input = d_inputs[i];
     auto &split = out_splits[i];
     // Subtract parent gain here
-    // As it is constant, this is more efficient than doing it during every split evaluation
-    float parent_gain = CalcGain(shared_inputs.param, input.parent_sum);
+    // As it is constant, this is more efficient than doing it during every
+    // split evaluation
+    float parent_gain =
+        CalcGain(shared_inputs.param,
+                 shared_inputs.rounding.ToFloatingPoint(input.parent_sum));
     split.loss_chg -= parent_gain;
     auto fidx = out_splits[i].findex;

     if (split.is_cat) {
       SetCategoricalSplit(shared_inputs, d_sorted_idx, fidx, i,
-                          device_cats_accessor.GetNodeCatStorage(input.nidx), &out_splits[i]);
+                          device_cats_accessor.GetNodeCatStorage(input.nidx),
+                          &out_splits[i]);
     }

-    float base_weight = evaluator.CalcWeight(input.nidx, shared_inputs.param,
-                                             GradStats{split.left_sum + split.right_sum});
-    float left_weight =
-        evaluator.CalcWeight(input.nidx, shared_inputs.param, GradStats{split.left_sum});
-    float right_weight =
-        evaluator.CalcWeight(input.nidx, shared_inputs.param, GradStats{split.right_sum});
+    float base_weight =
+        evaluator.CalcWeight(input.nidx, shared_inputs.param,
+                             shared_inputs.rounding.ToFloatingPoint(
+                                 split.left_sum + split.right_sum));
+    float left_weight = evaluator.CalcWeight(
+        input.nidx, shared_inputs.param,
+        shared_inputs.rounding.ToFloatingPoint(split.left_sum));
+    float right_weight = evaluator.CalcWeight(
+        input.nidx, shared_inputs.param,
+        shared_inputs.rounding.ToFloatingPoint(split.right_sum));

     d_entries[i] = GPUExpandEntry{input.nidx, input.depth, out_splits[i],
                                   base_weight, left_weight, right_weight};
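
At the host-side end of EvaluateSplits, gain and node weights are computed from the quantised sums by converting back through rounding.ToFloatingPoint. The round trip behaves roughly as below; the fixed scale here is assumed purely for illustration, while the real quantiser derives its scale from the gradient statistics:

#include <cstdint>
#include <cstdio>

int main() {
  const double scale = 1e6;  // assumed scale, illustration only
  const double g[3] = {0.25, -0.125, 0.5};
  std::int64_t q = 0;
  for (double v : g) q += static_cast<std::int64_t>(v * scale);  // exact integer accumulation
  std::printf("dequantised sum = %f\n", static_cast<double>(q) / scale);  // 0.625
  return 0;
}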

src/tree/gpu_hist/evaluate_splits.cuh

Lines changed: 3 additions & 2 deletions
@@ -23,19 +23,20 @@ namespace tree {
 struct EvaluateSplitInputs {
   int nidx;
   int depth;
-  GradientPairPrecise parent_sum;
+  GradientPairInt64 parent_sum;
   common::Span<const bst_feature_t> feature_set;
   common::Span<const GradientPairInt64> gradient_histogram;
 };

 // Inputs necessary for all nodes
 struct EvaluateSplitSharedInputs {
   GPUTrainingParam param;
-  GradientQuantizer rounding;
+  GradientQuantiser rounding;
   common::Span<FeatureType const> feature_types;
   common::Span<const uint32_t> feature_segments;
   common::Span<const float> feature_values;
   common::Span<const float> min_fvalue;
+  bool is_dense;
   XGBOOST_DEVICE auto Features() const { return feature_segments.size() - 1; }
   __device__ auto FeatureBins(bst_feature_t fidx) const {
     return feature_segments[fidx + 1] - feature_segments[fidx];
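
The new is_dense flag feeds the KERNEL_CHECK added to the EvaluateSplitAgent constructor: on a dense matrix every row contributes to every feature's histogram, so the missing residual (parent_sum minus the feature's bin sum) must be exactly zero. That equality test is only trustworthy because the hessian is now an exactly-summed integer; with floating-point sums the residual could be a small nonzero rounding error.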

src/tree/gpu_hist/expand_entry.cuh

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ struct GPUExpandEntry {
         left_weight{left}, right_weight{right} {}
   bool IsValid(const TrainParam& param, int num_leaves) const {
     if (split.loss_chg <= kRtEps) return false;
-    if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
+    if (split.left_sum.GetQuantisedHess() == 0 || split.right_sum.GetQuantisedHess() == 0) {
       return false;
     }
     if (split.loss_chg < param.min_split_loss) {
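
Using GetQuantisedHess() keeps this emptiness test exact: a child with no rows has an integer hessian of exactly zero, whereas a floating-point sum could drift away from zero through rounding and make the comparison unreliable.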

src/tree/gpu_hist/histogram.cu

Lines changed: 5 additions & 5 deletions
@@ -72,7 +72,7 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
   }
 };

-GradientQuantizer::GradientQuantizer(common::Span<GradientPair const> gpair) {
+GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair) {
   using GradientSumT = GradientPairPrecise;
   using T = typename GradientSumT::ValueT;
   dh::XGBCachingDeviceAllocator<char> alloc;
@@ -153,14 +153,14 @@ class HistogramAgent {
   const EllpackDeviceAccessor& matrix_;
   const int feature_stride_;
   const std::size_t n_elements_;
-  const GradientQuantizer& rounding_;
+  const GradientQuantiser& rounding_;

  public:
   __device__ HistogramAgent(GradientPairInt64* smem_arr,
                             GradientPairInt64* __restrict__ d_node_hist, const FeatureGroup& group,
                             const EllpackDeviceAccessor& matrix,
                             common::Span<const RowPartitioner::RowIndexT> d_ridx,
-                            const GradientQuantizer& rounding, const GradientPair* d_gpair)
+                            const GradientQuantiser& rounding, const GradientPair* d_gpair)
       : smem_arr_(smem_arr),
         d_node_hist_(d_node_hist),
         d_ridx_(d_ridx.data()),
@@ -254,7 +254,7 @@ __global__ void __launch_bounds__(kBlockThreads)
     common::Span<const RowPartitioner::RowIndexT> d_ridx,
     GradientPairInt64* __restrict__ d_node_hist,
     const GradientPair* __restrict__ d_gpair,
-    GradientQuantizer const rounding) {
+    GradientQuantiser const rounding) {
   extern __shared__ char smem[];
   const FeatureGroup group = feature_groups[blockIdx.y];
   auto smem_arr = reinterpret_cast<GradientPairInt64*>(smem);
@@ -272,7 +272,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
                             common::Span<GradientPair const> gpair,
                             common::Span<const uint32_t> d_ridx,
                             common::Span<GradientPairInt64> histogram,
-                            GradientQuantizer rounding, bool force_global_memory) {
+                            GradientQuantiser rounding, bool force_global_memory) {
   // decide whether to use shared memory
   int device = 0;
   dh::safe_cuda(cudaGetDevice(&device));
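
GradientQuantiser (renamed from GradientQuantizer throughout this file) is constructed from the full gradient span; conceptually it must pick a scale small enough that summing every quantised gradient cannot overflow int64. A toy host-side sketch of that scale choice, simplified and hypothetical (the real constructor reduces over the device span, e.g. via the Clip functor above, and handles grad and hess separately):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

struct ToyQuantiser {
  double scale;
  explicit ToyQuantiser(const std::vector<double>& grads) {
    double max_abs = 0.0;
    for (double g : grads) max_abs = std::max(max_abs, std::fabs(g));
    // Leave headroom so even summing all n values at max magnitude fits in int64.
    scale = static_cast<double>(std::numeric_limits<std::int64_t>::max()) /
            (max_abs * static_cast<double>(grads.size()) + 1.0);
  }
  std::int64_t ToFixedPoint(double g) const { return static_cast<std::int64_t>(g * scale); }
  double ToFloatingPoint(std::int64_t q) const { return static_cast<double>(q) / scale; }
};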
