Skip to content

Commit 54029a5

Browse files
authored
Bound the size of the histogram cache. (dmlc#9440)
- A new histogram collection with a limit in size.
- Unify histogram building logic between hist, multi-hist, and approx.
1 parent 5bd163a commit 54029a5

27 files changed

+989
-560
lines changed

R-package/src/Makevars.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ OBJECTS= \
6969
$(PKGROOT)/src/tree/updater_refresh.o \
7070
$(PKGROOT)/src/tree/updater_sync.o \
7171
$(PKGROOT)/src/tree/hist/param.o \
72+
$(PKGROOT)/src/tree/hist/histogram.o \
7273
$(PKGROOT)/src/linear/linear_updater.o \
7374
$(PKGROOT)/src/linear/updater_coordinate.o \
7475
$(PKGROOT)/src/linear/updater_shotgun.o \

R-package/src/Makevars.win

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ OBJECTS= \
6969
$(PKGROOT)/src/tree/updater_refresh.o \
7070
$(PKGROOT)/src/tree/updater_sync.o \
7171
$(PKGROOT)/src/tree/hist/param.o \
72+
$(PKGROOT)/src/tree/hist/histogram.o \
7273
$(PKGROOT)/src/linear/linear_updater.o \
7374
$(PKGROOT)/src/linear/updater_coordinate.o \
7475
$(PKGROOT)/src/linear/updater_shotgun.o \

include/xgboost/base.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,6 @@ namespace xgboost {
9191

9292
/*! \brief unsigned integer type used for feature index. */
9393
using bst_uint = uint32_t; // NOLINT
94-
/*! \brief integer type. */
95-
using bst_int = int32_t; // NOLINT
9694
/*! \brief unsigned long integers */
9795
using bst_ulong = uint64_t; // NOLINT
9896
/*! \brief float type, used for storing statistics */
@@ -138,9 +136,9 @@ namespace detail {
138136
template <typename T>
139137
class GradientPairInternal {
140138
/*! \brief gradient statistics */
141-
T grad_;
139+
T grad_{0};
142140
/*! \brief second order gradient statistics */
143-
T hess_;
141+
T hess_{0};
144142

145143
XGBOOST_DEVICE void SetGrad(T g) { grad_ = g; }
146144
XGBOOST_DEVICE void SetHess(T h) { hess_ = h; }
@@ -157,7 +155,7 @@ class GradientPairInternal {
157155
a += b;
158156
}
159157

160-
XGBOOST_DEVICE GradientPairInternal() : grad_(0), hess_(0) {}
158+
GradientPairInternal() = default;
161159

162160
XGBOOST_DEVICE GradientPairInternal(T grad, T hess) {
163161
SetGrad(grad);
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Tests related to the `DataIter` interface."""
2+
import numpy as np
3+
4+
import xgboost
5+
from xgboost import testing as tm
6+
7+
8+
def run_mixed_sparsity(device: str) -> None:
9+
"""Check QDM with mixed batches."""
10+
X_0, y_0, _ = tm.make_regression(128, 16, False)
11+
if device.startswith("cuda"):
12+
X_1, y_1 = tm.make_sparse_regression(256, 16, 0.1, True)
13+
else:
14+
X_1, y_1 = tm.make_sparse_regression(256, 16, 0.1, False)
15+
X_2, y_2 = tm.make_sparse_regression(512, 16, 0.9, True)
16+
X = [X_0, X_1, X_2]
17+
y = [y_0, y_1, y_2]
18+
19+
if device.startswith("cuda"):
20+
import cupy as cp # pylint: disable=import-error
21+
22+
X = [cp.array(batch) for batch in X]
23+
24+
it = tm.IteratorForTest(X, y, None, None)
25+
Xy_0 = xgboost.QuantileDMatrix(it)
26+
27+
X_1, y_1 = tm.make_sparse_regression(256, 16, 0.1, True)
28+
X = [X_0, X_1, X_2]
29+
y = [y_0, y_1, y_2]
30+
X_arr = np.concatenate(X, axis=0)
31+
y_arr = np.concatenate(y, axis=0)
32+
Xy_1 = xgboost.QuantileDMatrix(X_arr, y_arr)
33+
34+
assert tm.predictor_equal(Xy_0, Xy_1)

python-package/xgboost/testing/params.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@
4141
and (cast(int, x["max_depth"]) > 0 or x["grow_policy"] == "lossguide")
4242
)
4343

44+
hist_cache_strategy = strategies.fixed_dictionaries(
45+
{"internal_max_cached_hist_node": strategies.sampled_from([1, 4, 1024, 2**31])}
46+
)
47+
4448
hist_multi_parameter_strategy = strategies.fixed_dictionaries(
4549
{
4650
"max_depth": strategies.integers(1, 11),

src/common/hist_util.cc

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,6 @@ HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins
6767
return out;
6868
}
6969

70-
/*!
71-
* \brief fill a histogram by zeros in range [begin, end)
72-
*/
73-
void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) {
74-
#if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
75-
std::fill(hist.begin() + begin, hist.begin() + end, xgboost::GradientPairPrecise());
76-
#else // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
77-
memset(hist.data() + begin, '\0', (end - begin) * sizeof(xgboost::GradientPairPrecise));
78-
#endif // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
79-
}
80-
8170
/*!
8271
* \brief Increment hist as dst += add in range [begin, end)
8372
*/

src/common/hist_util.h

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -364,11 +364,6 @@ bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(std::size_t begin, std::size_t
364364
using GHistRow = Span<xgboost::GradientPairPrecise>;
365365
using ConstGHistRow = Span<xgboost::GradientPairPrecise const>;
366366

367-
/*!
368-
* \brief fill a histogram by zeros
369-
*/
370-
void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end);
371-
372367
/*!
373368
* \brief Increment hist as dst += add in range [begin, end)
374369
*/
@@ -395,12 +390,7 @@ class HistCollection {
395390
constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
396391
const size_t id = row_ptr_.at(nid);
397392
CHECK_NE(id, kMax);
398-
GradientPairPrecise* ptr = nullptr;
399-
if (contiguous_allocation_) {
400-
ptr = const_cast<GradientPairPrecise*>(data_[0].data() + nbins_*id);
401-
} else {
402-
ptr = const_cast<GradientPairPrecise*>(data_[id].data());
403-
}
393+
GradientPairPrecise* ptr = const_cast<GradientPairPrecise*>(data_[id].data());
404394
return {ptr, nbins_};
405395
}
406396

@@ -445,24 +435,12 @@ class HistCollection {
445435
data_[row_ptr_[nid]].resize(nbins_, {0, 0});
446436
}
447437
}
448-
// allocate common buffer contiguously for all nodes, need for single Allreduce call
449-
void AllocateAllData() {
450-
const size_t new_size = nbins_*data_.size();
451-
contiguous_allocation_ = true;
452-
if (data_[0].size() != new_size) {
453-
data_[0].resize(new_size);
454-
}
455-
}
456-
[[nodiscard]] bool IsContiguous() const { return contiguous_allocation_; }
457438

458439
private:
459440
/*! \brief number of all bins over all features */
460441
uint32_t nbins_ = 0;
461442
/*! \brief amount of active nodes in hist collection */
462443
uint32_t n_nodes_added_ = 0;
463-
/*! \brief flag to identify contiguous memory allocation */
464-
bool contiguous_allocation_ = false;
465-
466444
std::vector<std::vector<GradientPairPrecise>> data_;
467445

468446
/*! \brief row_ptr_[nid] locates bin for histogram of node nid */
@@ -518,7 +496,7 @@ class ParallelGHistBuilder {
518496
GHistRow hist = idx == -1 ? targeted_hists_[nid] : hist_buffer_[idx];
519497

520498
if (!hist_was_used_[tid * nodes_ + nid]) {
521-
InitilizeHistByZeroes(hist, 0, hist.size());
499+
std::fill_n(hist.data(), hist.size(), GradientPairPrecise{});
522500
hist_was_used_[tid * nodes_ + nid] = static_cast<int>(true);
523501
}
524502

@@ -548,7 +526,7 @@ class ParallelGHistBuilder {
548526
if (!is_updated) {
549527
// In distributed mode, some tree nodes can be empty on local machines,
550528
// so we just need to set the local hist to zeros in this case.
551-
InitilizeHistByZeroes(dst, begin, end);
529+
std::fill(dst.data() + begin, dst.data() + end, GradientPairPrecise{});
552530
}
553531
}
554532

src/common/threading_utils.h

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
#include <dmlc/common.h>
88
#include <dmlc/omp.h>
99

10-
#include <algorithm>
11-
#include <cstdint> // for int32_t
12-
#include <cstdlib> // for malloc, free
13-
#include <limits>
10+
#include <algorithm> // for min
11+
#include <cstddef> // for size_t
12+
#include <cstdint> // for int32_t
13+
#include <cstdlib> // for malloc, free
14+
#include <functional> // for function
1415
#include <new> // for bad_alloc
15-
#include <type_traits> // for is_signed
16-
#include <vector>
16+
#include <type_traits> // for is_signed, conditional_t
17+
#include <vector> // for vector
1718

1819
#include "xgboost/logging.h"
1920

@@ -25,6 +26,8 @@ inline int32_t omp_get_thread_limit() __GOMP_NOTHROW { return 1; } // NOLINT
2526

2627
// MSVC doesn't implement the thread limit.
2728
#if defined(_OPENMP) && defined(_MSC_VER)
29+
#include <limits>
30+
2831
extern "C" {
2932
inline int32_t omp_get_thread_limit() { return std::numeric_limits<int32_t>::max(); } // NOLINT
3033
}
@@ -84,8 +87,8 @@ class BlockedSpace2d {
8487
// dim1 - size of the first dimension in the space
8588
// getter_size_dim2 - functor to get the second dimensions for each 'row' by row-index
8689
// grain_size - max size of produced blocks
87-
template <typename Func>
88-
BlockedSpace2d(std::size_t dim1, Func getter_size_dim2, std::size_t grain_size) {
90+
BlockedSpace2d(std::size_t dim1, std::function<std::size_t(std::size_t)> getter_size_dim2,
91+
std::size_t grain_size) {
8992
for (std::size_t i = 0; i < dim1; ++i) {
9093
std::size_t size = getter_size_dim2(i);
9194
// Each row (second dim) is divided into n_blocks
@@ -104,13 +107,13 @@ class BlockedSpace2d {
104107
}
105108

106109
// get index of the first dimension of i-th block(task)
107-
[[nodiscard]] std::size_t GetFirstDimension(size_t i) const {
110+
[[nodiscard]] std::size_t GetFirstDimension(std::size_t i) const {
108111
CHECK_LT(i, first_dimension_.size());
109112
return first_dimension_[i];
110113
}
111114

112115
// get a range of indexes for the second dimension of i-th block(task)
113-
[[nodiscard]] Range1d GetRange(size_t i) const {
116+
[[nodiscard]] Range1d GetRange(std::size_t i) const {
114117
CHECK_LT(i, ranges_.size());
115118
return ranges_[i];
116119
}
@@ -129,22 +132,22 @@ class BlockedSpace2d {
129132
}
130133

131134
std::vector<Range1d> ranges_;
132-
std::vector<size_t> first_dimension_;
135+
std::vector<std::size_t> first_dimension_;
133136
};
134137

135138

136139
// Wrapper to implement nested parallelism with simple omp parallel for
137-
template <typename Func>
138-
void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) {
140+
inline void ParallelFor2d(BlockedSpace2d const& space, std::int32_t n_threads,
141+
std::function<void(std::size_t, Range1d)> func) {
139142
std::size_t n_blocks_in_space = space.Size();
140-
CHECK_GE(nthreads, 1);
143+
CHECK_GE(n_threads, 1);
141144

142145
dmlc::OMPException exc;
143-
#pragma omp parallel num_threads(nthreads)
146+
#pragma omp parallel num_threads(n_threads)
144147
{
145148
exc.Run([&]() {
146-
size_t tid = omp_get_thread_num();
147-
size_t chunck_size = n_blocks_in_space / nthreads + !!(n_blocks_in_space % nthreads);
149+
std::size_t tid = omp_get_thread_num();
150+
std::size_t chunck_size = n_blocks_in_space / n_threads + !!(n_blocks_in_space % n_threads);
148151

149152
std::size_t begin = chunck_size * tid;
150153
std::size_t end = std::min(begin + chunck_size, n_blocks_in_space);

src/data/adapter.h

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,6 @@ class CSCArrayAdapterBatch : public detail::NoMetaInfo {
477477
ArrayInterface<1> indptr_;
478478
ArrayInterface<1> indices_;
479479
ArrayInterface<1> values_;
480-
bst_row_t n_rows_;
481480

482481
class Line {
483482
std::size_t column_idx_;
@@ -503,11 +502,8 @@ class CSCArrayAdapterBatch : public detail::NoMetaInfo {
503502
static constexpr bool kIsRowMajor = false;
504503

505504
CSCArrayAdapterBatch(ArrayInterface<1> indptr, ArrayInterface<1> indices,
506-
ArrayInterface<1> values, bst_row_t n_rows)
507-
: indptr_{std::move(indptr)},
508-
indices_{std::move(indices)},
509-
values_{std::move(values)},
510-
n_rows_{n_rows} {}
505+
ArrayInterface<1> values)
506+
: indptr_{std::move(indptr)}, indices_{std::move(indices)}, values_{std::move(values)} {}
511507

512508
std::size_t Size() const { return indptr_.n - 1; }
513509
Line GetLine(std::size_t idx) const {
@@ -542,8 +538,7 @@ class CSCArrayAdapter : public detail::SingleBatchDataIter<CSCArrayAdapterBatch>
542538
indices_{indices},
543539
values_{values},
544540
num_rows_{num_rows},
545-
batch_{
546-
CSCArrayAdapterBatch{indptr_, indices_, values_, static_cast<bst_row_t>(num_rows_)}} {}
541+
batch_{CSCArrayAdapterBatch{indptr_, indices_, values_}} {}
547542

548543
// JVM package sends 0 as unknown
549544
size_t NumRows() const { return num_rows_ == 0 ? kAdapterUnknownSize : num_rows_; }

src/tree/hist/evaluate_splits.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
#ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
55
#define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
66

7-
#include <algorithm> // for copy
8-
#include <cstddef> // for size_t
9-
#include <limits> // for numeric_limits
10-
#include <memory> // for shared_ptr
11-
#include <numeric> // for accumulate
12-
#include <utility> // for move
13-
#include <vector> // for vector
7+
#include <algorithm> // for copy
8+
#include <cstddef> // for size_t
9+
#include <limits> // for numeric_limits
10+
#include <memory> // for shared_ptr
11+
#include <numeric> // for accumulate
12+
#include <utility> // for move
13+
#include <vector> // for vector
1414

1515
#include "../../common/categorical.h" // for CatBitField
1616
#include "../../common/hist_util.h" // for GHistRow, HistogramCuts
@@ -20,6 +20,7 @@
2020
#include "../param.h" // for TrainParam
2121
#include "../split_evaluator.h" // for TreeEvaluator
2222
#include "expand_entry.h" // for MultiExpandEntry
23+
#include "hist_cache.h" // for BoundedHistCollection
2324
#include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_feature_t
2425
#include "xgboost/context.h" // for Context
2526
#include "xgboost/linalg.h" // for Constants, Vector
@@ -317,7 +318,7 @@ class HistEvaluator {
317318
}
318319

319320
public:
320-
void EvaluateSplits(const common::HistCollection &hist, common::HistogramCuts const &cut,
321+
void EvaluateSplits(const BoundedHistCollection &hist, common::HistogramCuts const &cut,
321322
common::Span<FeatureType const> feature_types, const RegTree &tree,
322323
std::vector<CPUExpandEntry> *p_entries) {
323324
auto n_threads = ctx_->Threads();
@@ -623,7 +624,7 @@ class HistMultiEvaluator {
623624
}
624625

625626
public:
626-
void EvaluateSplits(RegTree const &tree, common::Span<const common::HistCollection *> hist,
627+
void EvaluateSplits(RegTree const &tree, common::Span<const BoundedHistCollection *> hist,
627628
common::HistogramCuts const &cut, std::vector<MultiExpandEntry> *p_entries) {
628629
auto &entries = *p_entries;
629630
std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(entries.size());

0 commit comments

Comments
 (0)