
Commit e1a2c1b

[EM] Merge GPU partitioning with histogram building. (dmlc#10766)
- Stop concatenating pages if there's no subsampling.
- Use a single iteration for histogram build and partitioning.
1 parent 98ac153 commit e1a2c1b
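
At a glance, the patch fuses the row-partitioning pass and the histogram-building pass into one sweep over the ELLPACK pages. The toy program below is a minimal sketch of that idea only; the types and names are invented for illustration and are not the XGBoost internals, which do the same thing per page via `UpdatePositionBatch` and `BuildHist` in the diff.

```cpp
// Toy model of the fused pass: for each data page, update row positions
// for the pending split, then accumulate the child histograms from the
// same read, so every page is visited once per tree level.
#include <iostream>
#include <vector>

struct Row { float feature; float grad; };
using Page = std::vector<Row>;

int main() {
  std::vector<Page> pages = {{{0.1f, 1.0f}, {0.9f, 2.0f}},
                             {{0.4f, 3.0f}, {0.7f, 4.0f}}};
  float const split = 0.5f;       // split condition of the node being expanded
  float hist[2] = {0.0f, 0.0f};   // gradient sums for the left/right child
  std::vector<int> position;      // child assignment per row, in page order

  for (Page const& page : pages) {  // single sweep: partition + histogram
    for (Row const& row : page) {
      int const child = row.feature < split ? 0 : 1;  // partition step
      position.push_back(child);
      hist[child] += row.grad;  // histogram step, no second pass over the page
    }
  }
  std::cout << "left=" << hist[0] << " right=" << hist[1] << "\n";
  return 0;
}
```

Before this change, external-memory training read the pages once to partition rows and again to build histograms; fusing the two halves the page traffic per level.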

File tree: 7 files changed (+118, -159 lines)


python-package/xgboost/testing/updater.py

Lines changed: 6 additions & 4 deletions
@@ -222,10 +222,12 @@ def check_extmem_qdm(
     Xy = xgb.QuantileDMatrix(X, y, weight=w)
     booster = xgb.train({"device": device}, Xy, num_boost_round=8)

-    cut_it = Xy_it.get_quantile_cut()
-    cut = Xy.get_quantile_cut()
-    np.testing.assert_allclose(cut_it[0], cut[0])
-    np.testing.assert_allclose(cut_it[1], cut[1])
+    if device == "cpu":
+        # Getting cuts from ellpack without CPU-GPU interpolation is not yet supported.
+        cut_it = Xy_it.get_quantile_cut()
+        cut = Xy.get_quantile_cut()
+        np.testing.assert_allclose(cut_it[0], cut[0])
+        np.testing.assert_allclose(cut_it[1], cut[1])

     predt_it = booster_it.predict(Xy_it)
     predt = booster.predict(Xy)

src/tree/gpu_hist/gradient_based_sampler.cu

Lines changed: 2 additions & 20 deletions
@@ -158,28 +158,10 @@ GradientBasedSample NoSampling::Sample(Context const*, common::Span<GradientPair
 ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param)
     : batch_param_{std::move(batch_param)} {}

-GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
+GradientBasedSample ExternalMemoryNoSampling::Sample(Context const*,
                                                      common::Span<GradientPair> gpair,
                                                      DMatrix* p_fmat) {
-  std::shared_ptr<EllpackPage> new_page;
-  if (!page_concatenated_) {
-    // Concatenate all the external memory ELLPACK pages into a single in-memory page.
-    bst_idx_t offset = 0;
-    for (auto& batch : p_fmat->GetBatches<EllpackPage>(ctx, batch_param_)) {
-      auto page = batch.Impl();
-      if (!new_page) {
-        new_page = std::make_shared<EllpackPage>();
-        *new_page->Impl() = EllpackPageImpl(ctx, page->CutsShared(), page->is_dense,
-                                            page->row_stride, p_fmat->Info().num_row_);
-      }
-      bst_idx_t num_elements = new_page->Impl()->Copy(ctx, page, offset);
-      offset += num_elements;
-    }
-    page_concatenated_ = true;
-    this->p_fmat_new_ =
-        std::make_unique<data::IterativeDMatrix>(new_page, p_fmat->Info(), batch_param_);
-  }
-  return {this->p_fmat_new_.get(), gpair};
+  return {p_fmat, gpair};
 }

 UniformSampling::UniformSampling(BatchParam batch_param, float subsample)
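
The simplification above is that, without subsampling, there is nothing to materialize: the sampler can hand back the original external-memory DMatrix untouched. A minimal sketch of that contract, with stand-in types rather than the real XGBoost classes:

```cpp
#include <vector>

struct DataRef { /* handle to external-memory pages */ };
struct SampleResult {
  DataRef* data;                  // pages are read in place, zero-copy
  std::vector<float>* gradients;  // gradient pairs, untouched
};

// No subsampling: pass the inputs through instead of concatenating all
// pages into one in-memory copy (what the removed code used to do).
SampleResult PassThroughSample(DataRef* data, std::vector<float>* gradients) {
  return {data, gradients};
}

int main() {
  DataRef data;
  std::vector<float> grad{0.5f, -1.0f};
  SampleResult s = PassThroughSample(&data, &grad);
  return s.gradients->size() == 2 ? 0 : 1;
}
```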

src/tree/gpu_hist/gradient_based_sampler.cuh

Lines changed: 0 additions & 2 deletions
@@ -46,8 +46,6 @@ class ExternalMemoryNoSampling : public SamplingStrategy {

  private:
   BatchParam batch_param_;
-  std::unique_ptr<DMatrix> p_fmat_new_{nullptr};
-  bool page_concatenated_{false};
 };

 /*! \brief Uniform sampling in in-memory mode. */

src/tree/gpu_hist/row_partitioner.cu

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,10 @@ void RowPartitioner::Reset(Context const* ctx, bst_idx_t n_samples, bst_idx_t ba
       NodePositionInfo{Segment{0, static_cast<cuda_impl::RowIndexT>(n_samples)}});

   thrust::sequence(ctx->CUDACtx()->CTP(), ridx_.data(), ridx_.data() + ridx_.size(), base_rowid);
+
+  // Pre-allocate some host memory
+  this->pinned_.GetSpan<std::int32_t>(1 << 11);
+  this->pinned2_.GetSpan<std::int32_t>(1 << 13);
 }

 RowPartitioner::~RowPartitioner() = default;
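
The two `GetSpan` calls warm up pinned (page-locked) host buffers so later device-to-host copies in the update loop don't pay allocation costs. Below is a standalone sketch of the same pattern using the raw CUDA runtime; the buffer sizes mirror the `1 << 11` / `1 << 13` hints above but are otherwise illustrative.

```cpp
#include <cuda_runtime.h>
#include <cstdint>

int main() {
  constexpr std::size_t kElems = 1 << 13;
  // One up-front pinned allocation; cudaMallocHost is expensive and
  // synchronizing, so it should stay out of the per-iteration hot path.
  std::int32_t* staging = nullptr;
  cudaMallocHost(&staging, kElems * sizeof(std::int32_t));

  std::int32_t* d_data = nullptr;
  cudaMalloc(&d_data, kElems * sizeof(std::int32_t));

  for (int iter = 0; iter < 4; ++iter) {
    // Pinned memory keeps this copy truly asynchronous.
    cudaMemcpyAsync(staging, d_data, kElems * sizeof(std::int32_t),
                    cudaMemcpyDeviceToHost);
    cudaDeviceSynchronize();
  }

  cudaFree(d_data);
  cudaFreeHost(staging);
  return 0;
}
```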

src/tree/updater_gpu_hist.cu

Lines changed: 93 additions & 104 deletions
@@ -200,6 +200,7 @@ struct GPUHistMakerDevice {

   // Reset values for each update iteration
   [[nodiscard]] DMatrix* Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* p_fmat) {
+    this->monitor.Start(__func__);
     auto const& info = p_fmat->Info();
     this->column_sampler_->Init(ctx_, p_fmat->Info().num_col_, info.feature_weights.HostVector(),
                                 param.colsample_bynode, param.colsample_bylevel,
@@ -252,7 +253,7 @@ struct GPUHistMakerDevice {
     this->histogram_.Reset(ctx_, this->hist_param_->MaxCachedHistNodes(ctx_->Device()),
                            feature_groups->DeviceAccessor(ctx_->Device()), cuts_->TotalBins(),
                            false);
-
+    this->monitor.Stop(__func__);
     return p_fmat;
   }

@@ -346,6 +347,38 @@ struct GPUHistMakerDevice {
     monitor.Stop(__func__);
   }

+  void ReduceHist(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
+                  std::vector<bst_node_t> const& build_nidx,
+                  std::vector<bst_node_t> const& subtraction_nidx) {
+    if (candidates.empty()) {
+      return;
+    }
+    this->monitor.Start(__func__);
+
+    // Reduce all in one go
+    // This gives much better latency in a distributed setting when processing a large batch
+    this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), build_nidx.at(0), build_nidx.size());
+    // Perform subtraction for sibling nodes
+    auto need_build = this->histogram_.SubtractHist(candidates, build_nidx, subtraction_nidx);
+    if (need_build.empty()) {
+      this->monitor.Stop(__func__);
+      return;
+    }
+
+    // Build the nodes that cannot obtain the histogram using subtraction. This is the slow path.
+    std::int32_t k = 0;
+    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
+      for (auto nidx : need_build) {
+        this->BuildHist(page, k, nidx);
+      }
+      ++k;
+    }
+    for (auto nidx : need_build) {
+      this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), nidx, 1);
+    }
+    this->monitor.Stop(__func__);
+  }
+
   void UpdatePositionColumnSplit(EllpackDeviceAccessor d_matrix,
                                  std::vector<NodeSplitData> const& split_data,
                                  std::vector<bst_node_t> const& nidx,
@@ -434,56 +467,74 @@ struct GPUHistMakerDevice {
     }
   };

-  void UpdatePosition(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
-                      RegTree* p_tree) {
-    if (candidates.empty()) {
+  // Update position and build histogram.
+  void PartitionAndBuildHist(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& expand_set,
+                             std::vector<GPUExpandEntry> const& candidates, RegTree const* p_tree) {
+    if (expand_set.empty()) {
       return;
     }
-
     monitor.Start(__func__);
+    CHECK_LE(candidates.size(), expand_set.size());

-    auto [nidx, left_nidx, right_nidx, split_data] = this->CreatePartitionNodes(p_tree, candidates);
+    // Update all the nodes if working with external memory, this saves us from working
+    // with the finalize position call, which adds an additional iteration and requires
+    // special handling for row index.
+    bool const is_single_block = p_fmat->SingleColBlock();

-    for (size_t i = 0; i < candidates.size(); i++) {
-      auto const& e = candidates[i];
-      RegTree::Node const& split_node = (*p_tree)[e.nid];
-      auto split_type = p_tree->NodeSplitType(e.nid);
-      nidx[i] = e.nid;
-      left_nidx[i] = split_node.LeftChild();
-      right_nidx[i] = split_node.RightChild();
-      split_data[i] = NodeSplitData{split_node, split_type, evaluator_.GetDeviceNodeCats(e.nid)};
+    // Prepare for update partition
+    auto [nidx, left_nidx, right_nidx, split_data] =
+        this->CreatePartitionNodes(p_tree, is_single_block ? candidates : expand_set);

-      CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
-    }
+    // Prepare for build hist
+    std::vector<bst_node_t> build_nidx(candidates.size());
+    std::vector<bst_node_t> subtraction_nidx(candidates.size());
+    auto prefetch_copy =
+        AssignNodes(p_tree, this->quantiser.get(), candidates, build_nidx, subtraction_nidx);

-    CHECK_EQ(p_fmat->NumBatches(), 1);
-    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
+    this->histogram_.AllocateHistograms(ctx_, build_nidx, subtraction_nidx);
+
+    monitor.Start("Partition-BuildHist");
+
+    std::int32_t k{0};
+    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(prefetch_copy))) {
       auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device());
+      auto go_left = GoLeftOp{d_matrix};

+      // Partition rows.
+      monitor.Start("UpdatePositionBatch");
       if (p_fmat->Info().IsColumnSplit()) {
         UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
-        monitor.Stop(__func__);
-        return;
+      } else {
+        partitioners_.at(k)->UpdatePositionBatch(
+            nidx, left_nidx, right_nidx, split_data,
+            [=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
+                           const NodeSplitData& data) { return go_left(ridx, data); });
       }
-      auto go_left = GoLeftOp{d_matrix};
-      partitioners_.front()->UpdatePositionBatch(
-          nidx, left_nidx, right_nidx, split_data,
-          [=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
-                         const NodeSplitData& data) { return go_left(ridx, data); });
+      monitor.Stop("UpdatePositionBatch");
+
+      for (auto nidx : build_nidx) {
+        this->BuildHist(page, k, nidx);
+      }
+
+      ++k;
     }

+    monitor.Stop("Partition-BuildHist");
+
+    this->ReduceHist(p_fmat, candidates, build_nidx, subtraction_nidx);
+
     monitor.Stop(__func__);
   }

   // After tree update is finished, update the position of all training
   // instances to their final leaf. This information is used later to update the
   // prediction cache
-  void FinalisePosition(DMatrix* p_fmat, RegTree const* p_tree, ObjInfo task, bst_idx_t n_samples,
+  void FinalisePosition(DMatrix* p_fmat, RegTree const* p_tree, ObjInfo task,
                         HostDeviceVector<bst_node_t>* p_out_position) {
     if (!p_fmat->SingleColBlock() && task.UpdateTreeLeaf()) {
       LOG(FATAL) << "Current objective function can not be used with external memory.";
     }
-    if (p_fmat->Info().num_row_ != n_samples) {
+    if (static_cast<std::size_t>(p_fmat->NumBatches() + 1) != this->batch_ptr_.size()) {
       // External memory with concatenation. Not supported.
       p_out_position->Resize(0);
       positions_.clear();
@@ -577,60 +628,6 @@ struct GPUHistMakerDevice {
     return true;
   }

-  /**
-   * \brief Build GPU local histograms for the left and right child of some parent node
-   */
-  void BuildHistLeftRight(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
-                          const RegTree& tree) {
-    if (candidates.empty()) {
-      return;
-    }
-    this->monitor.Start(__func__);
-    // Some nodes we will manually compute histograms
-    // others we will do by subtraction
-    std::vector<bst_node_t> hist_nidx(candidates.size());
-    std::vector<bst_node_t> subtraction_nidx(candidates.size());
-    auto prefetch_copy =
-        AssignNodes(&tree, this->quantiser.get(), candidates, hist_nidx, subtraction_nidx);
-
-    std::vector<int> all_new = hist_nidx;
-    all_new.insert(all_new.end(), subtraction_nidx.begin(), subtraction_nidx.end());
-    // Allocate the histograms
-    // Guaranteed contiguous memory
-    histogram_.AllocateHistograms(ctx_, all_new);
-
-    std::int32_t k = 0;
-    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(prefetch_copy))) {
-      for (auto nidx : hist_nidx) {
-        this->BuildHist(page, k, nidx);
-      }
-      ++k;
-    }
-
-    // Reduce all in one go
-    // This gives much better latency in a distributed setting
-    // when processing a large batch
-    this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), hist_nidx.at(0), hist_nidx.size());
-
-    for (size_t i = 0; i < subtraction_nidx.size(); i++) {
-      auto build_hist_nidx = hist_nidx.at(i);
-      auto subtraction_trick_nidx = subtraction_nidx.at(i);
-      auto parent_nidx = candidates.at(i).nid;
-
-      if (!this->histogram_.SubtractionTrick(parent_nidx, build_hist_nidx,
-                                             subtraction_trick_nidx)) {
-        // Calculate other histogram manually
-        std::int32_t k = 0;
-        for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
-          this->BuildHist(page, k, subtraction_trick_nidx);
-          ++k;
-        }
-        this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), subtraction_trick_nidx, 1);
-      }
-    }
-    this->monitor.Stop(__func__);
-  }
-
   void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
     RegTree& tree = *p_tree;

@@ -681,8 +678,9 @@ struct GPUHistMakerDevice {
   }

   GPUExpandEntry InitRoot(DMatrix* p_fmat, RegTree* p_tree) {
-    constexpr bst_node_t kRootNIdx = 0;
-    dh::XGBCachingDeviceAllocator<char> alloc;
+    this->monitor.Start(__func__);
+
+    constexpr bst_node_t kRootNIdx = RegTree::kRoot;
     auto quantiser = *this->quantiser;
     auto gpair_it = dh::MakeTransformIterator<GradientPairInt64>(
         dh::tbegin(gpair),
@@ -697,6 +695,7 @@ struct GPUHistMakerDevice {

     histogram_.AllocateHistograms(ctx_, {kRootNIdx});
     std::int32_t k = 0;
+    CHECK_EQ(p_fmat->NumBatches(), this->partitioners_.size());
     for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
       this->BuildHist(page, k, kRootNIdx);
       ++k;
@@ -712,25 +711,18 @@ struct GPUHistMakerDevice {

     // Generate first split
     auto root_entry = this->EvaluateRootSplit(p_fmat, root_sum_quantised);
+
+    this->monitor.Stop(__func__);
     return root_entry;
   }

   void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo const* task,
                   RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
-    bool const is_single_block = p_fmat->SingleColBlock();
-    bst_idx_t const n_samples = p_fmat->Info().num_row_;
-
-    auto& tree = *p_tree;
     // Process maximum 32 nodes at a time
     Driver<GPUExpandEntry> driver(param, 32);

-    monitor.Start("Reset");
     p_fmat = this->Reset(gpair_all, p_fmat);
-    monitor.Stop("Reset");
-
-    monitor.Start("InitRoot");
     driver.Push({this->InitRoot(p_fmat, p_tree)});
-    monitor.Stop("InitRoot");

     // The set of leaves that can be expanded asynchronously
     auto expand_set = driver.Pop();
@@ -740,20 +732,17 @@ struct GPUHistMakerDevice {
       }
       // Get the candidates we are allowed to expand further
       // e.g. We do not bother further processing nodes whose children are beyond max depth
-      std::vector<GPUExpandEntry> filtered_expand_set;
-      std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set),
-                   [&](const auto& e) { return driver.IsChildValid(e); });
+      std::vector<GPUExpandEntry> valid_candidates;
+      std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(valid_candidates),
+                   [&](auto const& e) { return driver.IsChildValid(e); });

+      // Allocate children nodes.
       auto new_candidates =
-          pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry{});
-      // Update all the nodes if working with external memory, this saves us from working
-      // with the finalize position call, which adds an additional iteration and requires
-      // special handling for row index.
-      this->UpdatePosition(p_fmat, is_single_block ? filtered_expand_set : expand_set, p_tree);
+          pinned.GetSpan<GPUExpandEntry>(valid_candidates.size() * 2, GPUExpandEntry());

-      this->BuildHistLeftRight(p_fmat, filtered_expand_set, tree);
+      this->PartitionAndBuildHist(p_fmat, expand_set, valid_candidates, p_tree);

-      this->EvaluateSplits(p_fmat, filtered_expand_set, *p_tree, new_candidates);
+      this->EvaluateSplits(p_fmat, valid_candidates, *p_tree, new_candidates);
       dh::DefaultStream().Sync();

       driver.Push(new_candidates.begin(), new_candidates.end());
@@ -764,10 +753,10 @@ struct GPUHistMakerDevice {
     // be splittable before evaluation but invalid after evaluation as we have more
     // restrictions like min loss change after evaluation. Therefore, the check condition
     // is greater than or equal to.
-    if (is_single_block) {
+    if (p_fmat->SingleColBlock()) {
       CHECK_GE(p_tree->NumNodes(), this->partitioners_.front()->GetNumNodes());
     }
-    this->FinalisePosition(p_fmat, p_tree, *task, n_samples, p_out_position);
+    this->FinalisePosition(p_fmat, p_tree, *task, p_out_position);
   }
 };

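A closing note on the new `ReduceHist`: it leans on the standard histogram subtraction trick, where a sibling's histogram is derived as parent minus built child, so only one child per split needs a fresh scan of the data. A toy illustration with made-up bin values:

```cpp
// Toy model of the subtraction trick: build the histogram for one child
// by scanning rows, then derive the sibling as parent - child.
#include <iostream>
#include <vector>

int main() {
  std::vector<double> parent = {10.0, 6.0, 8.0};  // per-bin gradient sums
  std::vector<double> left   = { 4.0, 1.0, 5.0};  // built by scanning rows
  std::vector<double> right(parent.size());
  for (std::size_t i = 0; i < parent.size(); ++i) {
    right[i] = parent[i] - left[i];               // no second data scan
  }
  for (double v : right) std::cout << v << ' ';   // prints: 6 5 3
  std::cout << '\n';
  return 0;
}
```

Only when a node cannot use subtraction (the `need_build` set) does the code fall back to rebuilding from the pages, which is the slow path noted in the diff.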