Skip to content

Commit 06487d3

Browse files
authored
[backport] Fix GPU categorical split memory allocation. (dmlc#9529) (dmlc#9535)
1 parent e50ccc4 commit 06487d3

File tree

5 files changed: +67 additions, -64 deletions

5 files changed: +67 additions, -64 deletions

src/common/categorical.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ inline XGBOOST_DEVICE bool InvalidCat(float cat) {
5252
*
5353
* Go to left if it's NOT the matching category, which matches one-hot encoding.
5454
*/
55-
inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat) {
55+
inline XGBOOST_DEVICE bool Decision(common::Span<CatBitField::value_type const> cats, float cat) {
5656
KCatBitField const s_cats(cats);
5757
if (XGBOOST_EXPECT(InvalidCat(cat), false)) {
5858
return true;

src/tree/gpu_hist/evaluate_splits.cu

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
1-
/*!
2-
* Copyright 2020-2022 by XGBoost Contributors
1+
/**
2+
* Copyright 2020-2023, XGBoost Contributors
33
*/
44
#include <algorithm> // std::max
55
#include <vector>
@@ -11,9 +11,7 @@
1111
#include "evaluate_splits.cuh"
1212
#include "expand_entry.cuh"
1313

14-
namespace xgboost {
15-
namespace tree {
16-
14+
namespace xgboost::tree {
1715
// With constraints
1816
XGBOOST_DEVICE float LossChangeMissing(const GradientPairInt64 &scan,
1917
const GradientPairInt64 &missing,
@@ -315,11 +313,11 @@ __device__ void SetCategoricalSplit(const EvaluateSplitSharedInputs &shared_inpu
315313
common::Span<common::CatBitField::value_type> out,
316314
DeviceSplitCandidate *p_out_split) {
317315
auto &out_split = *p_out_split;
318-
out_split.split_cats = common::CatBitField{out};
316+
auto out_cats = common::CatBitField{out};
319317

320318
// Simple case for one hot split
321319
if (common::UseOneHot(shared_inputs.FeatureBins(fidx), shared_inputs.param.max_cat_to_onehot)) {
322-
out_split.split_cats.Set(common::AsCat(out_split.thresh));
320+
out_cats.Set(common::AsCat(out_split.thresh));
323321
return;
324322
}
325323

@@ -339,7 +337,7 @@ __device__ void SetCategoricalSplit(const EvaluateSplitSharedInputs &shared_inpu
339337
assert(partition > 0 && "Invalid partition.");
340338
thrust::for_each(thrust::seq, beg, beg + partition, [&](size_t c) {
341339
auto cat = shared_inputs.feature_values[c - node_offset];
342-
out_split.SetCat(cat);
340+
out_cats.Set(common::AsCat(cat));
343341
});
344342
}
345343

@@ -427,8 +425,7 @@ void GPUHistEvaluator::EvaluateSplits(
427425

428426
if (split.is_cat) {
429427
SetCategoricalSplit(shared_inputs, d_sorted_idx, fidx, i,
430-
device_cats_accessor.GetNodeCatStorage(input.nidx),
431-
&out_splits[i]);
428+
device_cats_accessor.GetNodeCatStorage(input.nidx), &out_splits[i]);
432429
}
433430

434431
float base_weight =
@@ -460,6 +457,4 @@ GPUExpandEntry GPUHistEvaluator::EvaluateSingleSplit(
460457
cudaMemcpyDeviceToHost));
461458
return root_entry;
462459
}
463-
464-
} // namespace tree
465-
} // namespace xgboost
460+
} // namespace xgboost::tree

src/tree/gpu_hist/evaluate_splits.cuh

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -37,8 +37,8 @@ struct EvaluateSplitSharedInputs {
3737
common::Span<const float> feature_values;
3838
common::Span<const float> min_fvalue;
3939
bool is_dense;
40-
XGBOOST_DEVICE auto Features() const { return feature_segments.size() - 1; }
41-
__device__ auto FeatureBins(bst_feature_t fidx) const {
40+
[[nodiscard]] XGBOOST_DEVICE auto Features() const { return feature_segments.size() - 1; }
41+
[[nodiscard]] __device__ std::uint32_t FeatureBins(bst_feature_t fidx) const {
4242
return feature_segments[fidx + 1] - feature_segments[fidx];
4343
}
4444
};
@@ -102,7 +102,7 @@ class GPUHistEvaluator {
102102
}
103103

104104
/**
105-
* \brief Get device category storage of nidx for internal calculation.
105+
* @brief Get device category storage of nidx for internal calculation.
106106
*/
107107
auto DeviceCatStorage(const std::vector<bst_node_t> &nidx) {
108108
if (!has_categoricals_) return CatAccessor{};
@@ -117,8 +117,8 @@ class GPUHistEvaluator {
117117
/**
118118
* \brief Get sorted index storage based on the left node of inputs.
119119
*/
120-
auto SortedIdx(int num_nodes, bst_feature_t total_bins) {
121-
if(!need_sort_histogram_) return common::Span<bst_feature_t>();
120+
auto SortedIdx(int num_nodes, bst_bin_t total_bins) {
121+
if (!need_sort_histogram_) return common::Span<bst_feature_t>{};
122122
cat_sorted_idx_.resize(num_nodes * total_bins);
123123
return dh::ToSpan(cat_sorted_idx_);
124124
}
@@ -142,12 +142,22 @@ class GPUHistEvaluator {
142142
* \brief Get host category storage for nidx. Different from the internal version, this
143143
* returns strictly 1 node.
144144
*/
145-
common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const {
145+
[[nodiscard]] common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const {
146146
copy_stream_.View().Sync();
147147
auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(
148148
nidx * node_categorical_storage_size_, node_categorical_storage_size_);
149149
return cats_out;
150150
}
151+
152+
[[nodiscard]] auto GetDeviceNodeCats(bst_node_t nidx) {
153+
copy_stream_.View().Sync();
154+
if (has_categoricals_) {
155+
CatAccessor accessor = {dh::ToSpan(split_cats_), node_categorical_storage_size_};
156+
return common::KCatBitField{accessor.GetNodeCatStorage(nidx)};
157+
} else {
158+
return common::KCatBitField{};
159+
}
160+
}
151161
/**
152162
* \brief Add a split to the internal tree evaluator.
153163
*/

src/tree/updater_gpu_common.cuh

Lines changed: 15 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -64,20 +64,13 @@ struct DeviceSplitCandidate {
6464
// split.
6565
bst_cat_t thresh{-1};
6666

67-
common::CatBitField split_cats;
6867
bool is_cat { false };
6968

7069
GradientPairInt64 left_sum;
7170
GradientPairInt64 right_sum;
7271

7372
XGBOOST_DEVICE DeviceSplitCandidate() {} // NOLINT
7473

75-
template <typename T>
76-
XGBOOST_DEVICE void SetCat(T c) {
77-
this->split_cats.Set(common::AsCat(c));
78-
fvalue = std::max(this->fvalue, static_cast<float>(c));
79-
}
80-
8174
XGBOOST_DEVICE void Update(float loss_chg_in, DefaultDirection dir_in, float fvalue_in,
8275
int findex_in, GradientPairInt64 left_sum_in,
8376
GradientPairInt64 right_sum_in, bool cat,
@@ -100,22 +93,23 @@ struct DeviceSplitCandidate {
10093
*/
10194
XGBOOST_DEVICE void UpdateCat(float loss_chg_in, DefaultDirection dir_in, bst_cat_t thresh_in,
10295
bst_feature_t findex_in, GradientPairInt64 left_sum_in,
103-
GradientPairInt64 right_sum_in, GPUTrainingParam const& param, const GradientQuantiser& quantiser) {
104-
if (loss_chg_in > loss_chg &&
105-
quantiser.ToFloatingPoint(left_sum_in).GetHess() >= param.min_child_weight &&
106-
quantiser.ToFloatingPoint(right_sum_in).GetHess() >= param.min_child_weight) {
107-
loss_chg = loss_chg_in;
108-
dir = dir_in;
109-
fvalue = std::numeric_limits<float>::quiet_NaN();
110-
thresh = thresh_in;
111-
is_cat = true;
112-
left_sum = left_sum_in;
113-
right_sum = right_sum_in;
114-
findex = findex_in;
115-
}
96+
GradientPairInt64 right_sum_in, GPUTrainingParam const& param,
97+
const GradientQuantiser& quantiser) {
98+
if (loss_chg_in > loss_chg &&
99+
quantiser.ToFloatingPoint(left_sum_in).GetHess() >= param.min_child_weight &&
100+
quantiser.ToFloatingPoint(right_sum_in).GetHess() >= param.min_child_weight) {
101+
loss_chg = loss_chg_in;
102+
dir = dir_in;
103+
fvalue = std::numeric_limits<float>::quiet_NaN();
104+
thresh = thresh_in;
105+
is_cat = true;
106+
left_sum = left_sum_in;
107+
right_sum = right_sum_in;
108+
findex = findex_in;
109+
}
116110
}
117111

118-
XGBOOST_DEVICE bool IsValid() const { return loss_chg > 0.0f; }
112+
[[nodiscard]] XGBOOST_DEVICE bool IsValid() const { return loss_chg > 0.0f; }
119113

120114
friend std::ostream& operator<<(std::ostream& os, DeviceSplitCandidate const& c) {
121115
os << "loss_chg:" << c.loss_chg << ", "

src/tree/updater_gpu_hist.cu

Lines changed: 27 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,9 @@
77

88
#include <algorithm>
99
#include <cmath>
10-
#include <limits>
11-
#include <memory>
12-
#include <utility>
10+
#include <cstddef> // for size_t
11+
#include <memory> // for unique_ptr, make_unique
12+
#include <utility> // for move
1313
#include <vector>
1414

1515
#include "../collective/communicator-inl.cuh"
@@ -216,9 +216,9 @@ struct GPUHistMakerDevice {
216216
void InitFeatureGroupsOnce() {
217217
if (!feature_groups) {
218218
CHECK(page);
219-
feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense,
220-
dh::MaxSharedMemoryOptin(ctx_->gpu_id),
221-
sizeof(GradientPairPrecise)));
219+
feature_groups = std::make_unique<FeatureGroups>(page->Cuts(), page->is_dense,
220+
dh::MaxSharedMemoryOptin(ctx_->gpu_id),
221+
sizeof(GradientPairPrecise));
222222
}
223223
}
224224

@@ -244,10 +244,10 @@ struct GPUHistMakerDevice {
244244

245245
this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, ctx_->gpu_id);
246246

247-
quantiser.reset(new GradientQuantiser(this->gpair));
247+
quantiser = std::make_unique<GradientQuantiser>(this->gpair);
248248

249249
row_partitioner.reset(); // Release the device memory first before reallocating
250-
row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, sample.sample_rows));
250+
row_partitioner = std::make_unique<RowPartitioner>(ctx_->gpu_id, sample.sample_rows);
251251

252252
// Init histogram
253253
hist.Init(ctx_->gpu_id, page->Cuts().TotalBins());
@@ -294,7 +294,7 @@ struct GPUHistMakerDevice {
294294
dh::TemporaryArray<GPUExpandEntry> entries(2 * candidates.size());
295295
// Store the feature set ptrs so they dont go out of scope before the kernel is called
296296
std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_sets;
297-
for (size_t i = 0; i < candidates.size(); i++) {
297+
for (std::size_t i = 0; i < candidates.size(); i++) {
298298
auto candidate = candidates.at(i);
299299
int left_nidx = tree[candidate.nid].LeftChild();
300300
int right_nidx = tree[candidate.nid].RightChild();
@@ -327,14 +327,13 @@ struct GPUHistMakerDevice {
327327
d_node_inputs.data().get(), h_node_inputs.data(),
328328
h_node_inputs.size() * sizeof(EvaluateSplitInputs), cudaMemcpyDefault));
329329

330-
this->evaluator_.EvaluateSplits(nidx, max_active_features,
331-
dh::ToSpan(d_node_inputs), shared_inputs,
332-
dh::ToSpan(entries));
330+
this->evaluator_.EvaluateSplits(nidx, max_active_features, dh::ToSpan(d_node_inputs),
331+
shared_inputs, dh::ToSpan(entries));
333332
dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(),
334333
entries.data().get(), sizeof(GPUExpandEntry) * entries.size(),
335334
cudaMemcpyDeviceToHost));
336335
dh::DefaultStream().Sync();
337-
}
336+
}
338337

339338
void BuildHist(int nidx) {
340339
auto d_node_hist = hist.GetNodeHistogram(nidx);
@@ -366,31 +365,37 @@ struct GPUHistMakerDevice {
366365
struct NodeSplitData {
367366
RegTree::Node split_node;
368367
FeatureType split_type;
369-
common::CatBitField node_cats;
368+
common::KCatBitField node_cats;
370369
};
371370

372-
void UpdatePosition(const std::vector<GPUExpandEntry>& candidates, RegTree* p_tree) {
373-
if (candidates.empty()) return;
374-
std::vector<int> nidx(candidates.size());
375-
std::vector<int> left_nidx(candidates.size());
376-
std::vector<int> right_nidx(candidates.size());
371+
void UpdatePosition(std::vector<GPUExpandEntry> const& candidates, RegTree* p_tree) {
372+
if (candidates.empty()) {
373+
return;
374+
}
375+
376+
std::vector<bst_node_t> nidx(candidates.size());
377+
std::vector<bst_node_t> left_nidx(candidates.size());
378+
std::vector<bst_node_t> right_nidx(candidates.size());
377379
std::vector<NodeSplitData> split_data(candidates.size());
380+
378381
for (size_t i = 0; i < candidates.size(); i++) {
379-
auto& e = candidates[i];
382+
auto const& e = candidates[i];
380383
RegTree::Node split_node = (*p_tree)[e.nid];
381384
auto split_type = p_tree->NodeSplitType(e.nid);
382385
nidx.at(i) = e.nid;
383386
left_nidx.at(i) = split_node.LeftChild();
384387
right_nidx.at(i) = split_node.RightChild();
385-
split_data.at(i) = NodeSplitData{split_node, split_type, e.split.split_cats};
388+
split_data.at(i) = NodeSplitData{split_node, split_type, evaluator_.GetDeviceNodeCats(e.nid)};
389+
390+
CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
386391
}
387392

388393
auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
389394
row_partitioner->UpdatePositionBatch(
390395
nidx, left_nidx, right_nidx, split_data,
391396
[=] __device__(bst_uint ridx, const NodeSplitData& data) {
392397
// given a row index, returns the node id it belongs to
393-
bst_float cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex());
398+
float cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex());
394399
// Missing value
395400
bool go_left = true;
396401
if (isnan(cut_value)) {
@@ -620,7 +625,6 @@ struct GPUHistMakerDevice {
620625
CHECK(common::CheckNAN(candidate.split.fvalue));
621626
std::vector<common::CatBitField::value_type> split_cats;
622627

623-
CHECK_GT(candidate.split.split_cats.Bits().size(), 0);
624628
auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
625629
auto n_bins_feature = page->Cuts().FeatureBins(candidate.split.findex);
626630
split_cats.resize(common::CatBitField::ComputeStorageSize(n_bins_feature), 0);

0 commit comments

Comments
 (0)