Skip to content

Commit d6ebcfb

Browse files
authored
[EM] Support CPU quantile objective for external memory. (dmlc#10751)
1 parent 12c6b7c commit d6ebcfb

File tree

13 files changed

+163
-36
lines changed

13 files changed

+163
-36
lines changed

python-package/xgboost/testing/updater.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,37 @@ def check_quantile_loss(tree_method: str, weighted: bool) -> None:
163163
np.testing.assert_allclose(predts[:, i], predt_multi[:, i])
164164

165165

166+
def check_quantile_loss_extmem(
167+
n_samples_per_batch: int,
168+
n_features: int,
169+
n_batches: int,
170+
tree_method: str,
171+
device: str,
172+
) -> None:
173+
"""Check external memory with the quantile objective."""
174+
it = tm.IteratorForTest(
175+
*tm.make_batches(n_samples_per_batch, n_features, n_batches, device != "cpu"),
176+
cache="cache",
177+
on_host=False,
178+
)
179+
Xy_it = xgb.DMatrix(it)
180+
params = {
181+
"tree_method": tree_method,
182+
"objective": "reg:quantileerror",
183+
"device": device,
184+
"quantile_alpha": [0.2, 0.8],
185+
}
186+
booster_it = xgb.train(params, Xy_it)
187+
X, y, w = it.as_arrays()
188+
Xy = xgb.DMatrix(X, y, weight=w)
189+
booster = xgb.train(params, Xy)
190+
191+
predt_it = booster_it.predict(Xy_it)
192+
predt = booster.predict(Xy)
193+
194+
np.testing.assert_allclose(predt, predt_it)
195+
196+
166197
def check_cut(
167198
n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any
168199
) -> None:

src/common/error_msg.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ constexpr StringView InconsistentMaxBin() {
4141
"and consistent with the Booster being trained.";
4242
}
4343

44+
constexpr StringView InvalidMaxBin() { return "`max_bin` must be equal to or greater than 2."; }
45+
4446
constexpr StringView UnknownDevice() { return "Unknown device type."; }
4547

4648
inline void MaxFeatureSize(std::uint64_t n_features) {

src/common/partition_builder.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -367,23 +367,21 @@ class PartitionBuilder {
367367
// Copy row partitions into global cache for reuse in objective
368368
template <typename Invalidp>
369369
void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set,
370-
std::vector<bst_node_t>* p_position, Invalidp invalidp) const {
371-
auto& h_pos = *p_position;
372-
h_pos.resize(row_set.Data()->size(), std::numeric_limits<bst_node_t>::max());
373-
370+
Span<bst_node_t> position, Invalidp invalidp) const {
374371
auto p_begin = row_set.Data()->data();
375372
// For each node, walk through all the samples that fall in this node.
376-
ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) {
373+
auto p_pos = position.data();
374+
ParallelFor(row_set.Size(), ctx->Threads(), [&](auto i) {
377375
auto const& node = row_set[i];
378376
if (node.node_id < 0) {
379377
return;
380378
}
381379
CHECK(tree.IsLeaf(node.node_id));
382380
if (node.begin()) { // guard for empty node.
383-
size_t ptr_offset = node.end() - p_begin;
381+
std::size_t ptr_offset = node.end() - p_begin;
384382
CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
385383
for (auto idx = node.begin(); idx != node.end(); ++idx) {
386-
h_pos[*idx] = tree::SamplePosition::Encode(node.node_id, !invalidp(*idx));
384+
p_pos[*idx] = tree::SamplePosition::Encode(node.node_id, !invalidp(*idx));
387385
}
388386
}
389387
});

src/common/quantile.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <utility>
99

1010
#include "../collective/aggregator.h"
11+
#include "../common/error_msg.h" // for InvalidMaxBin
1112
#include "../data/adapter.h"
1213
#include "categorical.h"
1314
#include "hist_util.h"
@@ -16,15 +17,16 @@ namespace xgboost::common {
1617
template <typename WQSketch>
1718
SketchContainerImpl<WQSketch>::SketchContainerImpl(Context const *ctx,
1819
std::vector<bst_idx_t> columns_size,
19-
int32_t max_bins,
20+
bst_bin_t max_bin,
2021
Span<FeatureType const> feature_types,
2122
bool use_group)
2223
: feature_types_(feature_types.cbegin(), feature_types.cend()),
2324
columns_size_{std::move(columns_size)},
24-
max_bins_{max_bins},
25+
max_bins_{max_bin},
2526
use_group_ind_{use_group},
2627
n_threads_{ctx->Threads()} {
2728
monitor_.Init(__func__);
29+
CHECK_GE(max_bin, 2) << error::InvalidMaxBin();
2830
CHECK_NE(columns_size_.size(), 0);
2931
sketches_.resize(columns_size_.size());
3032
CHECK_GE(n_threads_, 1);

src/common/quantile.cuh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "categorical.h"
1010
#include "device_helpers.cuh"
11+
#include "error_msg.h" // for InvalidMaxBin
1112
#include "quantile.h"
1213
#include "timer.h"
1314
#include "xgboost/data.h"
@@ -96,7 +97,7 @@ class SketchContainer {
9697
* \param num_rows Total number of rows in known dataset (typically the rows in current worker).
9798
* \param device GPU ID.
9899
*/
99-
SketchContainer(HostDeviceVector<FeatureType> const& feature_types, int32_t max_bin,
100+
SketchContainer(HostDeviceVector<FeatureType> const& feature_types, bst_bin_t max_bin,
100101
bst_feature_t num_columns, bst_idx_t num_rows, DeviceOrd device)
101102
: num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
102103
CHECK(device.IsCUDA());
@@ -117,6 +118,7 @@ class SketchContainer {
117118
has_categorical_ =
118119
!d_feature_types.empty() &&
119120
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types), common::IsCatOp{});
121+
CHECK_GE(max_bin, 2) << error::InvalidMaxBin();
120122

121123
timer_.Init(__func__);
122124
}

src/common/quantile.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -802,10 +802,10 @@ class SketchContainerImpl {
802802
/* \brief Initialize necessary info.
803803
*
804804
* \param columns_size Size of each column.
805-
* \param max_bins maximum number of bins for each feature.
805+
* \param max_bin maximum number of bins for each feature.
806806
* \param use_group whether is assigned to group to data instance.
807807
*/
808-
SketchContainerImpl(Context const *ctx, std::vector<bst_idx_t> columns_size, bst_bin_t max_bins,
808+
SketchContainerImpl(Context const *ctx, std::vector<bst_idx_t> columns_size, bst_bin_t max_bin,
809809
common::Span<FeatureType const> feature_types, bool use_group);
810810

811811
static bool UseGroup(MetaInfo const &info) {

src/gbm/gbtree.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
218218
model_.learner_model_param->OutputLength());
219219
CHECK_NE(n_groups, 0);
220220

221-
if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf()) {
221+
if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf() && this->ctx_->IsCUDA()) {
222222
LOG(FATAL) << "Current objective doesn't support external memory.";
223223
}
224224

src/tree/common_row_partitioner.h

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -301,34 +301,37 @@ class CommonRowPartitioner {
301301
auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
302302

303303
void LeafPartition(Context const* ctx, RegTree const& tree, common::Span<float const> hess,
304-
std::vector<bst_node_t>* p_out_position) const {
305-
partition_builder_.LeafPartition(ctx, tree, this->Partitions(), p_out_position,
306-
[&](size_t idx) -> bool { return hess[idx] - .0f == .0f; });
304+
common::Span<bst_node_t> out_position) const {
305+
partition_builder_.LeafPartition(
306+
ctx, tree, this->Partitions(), out_position,
307+
[&](size_t idx) -> bool { return hess[idx - this->base_rowid] - .0f == .0f; });
307308
}
308309

309310
void LeafPartition(Context const* ctx, RegTree const& tree,
310311
linalg::TensorView<GradientPair const, 2> gpair,
311-
std::vector<bst_node_t>* p_out_position) const {
312+
common::Span<bst_node_t> out_position) const {
312313
if (gpair.Shape(1) > 1) {
313314
partition_builder_.LeafPartition(
314-
ctx, tree, this->Partitions(), p_out_position, [&](std::size_t idx) -> bool {
315-
auto sample = gpair.Slice(idx, linalg::All());
315+
ctx, tree, this->Partitions(), out_position, [&](std::size_t idx) -> bool {
316+
auto sample = gpair.Slice(idx - this->base_rowid, linalg::All());
316317
return std::all_of(linalg::cbegin(sample), linalg::cend(sample),
317318
[](GradientPair const& g) { return g.GetHess() - .0f == .0f; });
318319
});
319320
} else {
320321
auto s = gpair.Slice(linalg::All(), 0);
321-
partition_builder_.LeafPartition(
322-
ctx, tree, this->Partitions(), p_out_position,
323-
[&](std::size_t idx) -> bool { return s(idx).GetHess() - .0f == .0f; });
322+
partition_builder_.LeafPartition(ctx, tree, this->Partitions(), out_position,
323+
[&](std::size_t idx) -> bool {
324+
return s(idx - this->base_rowid).GetHess() - .0f == .0f;
325+
});
324326
}
325327
}
326328
void LeafPartition(Context const* ctx, RegTree const& tree,
327329
common::Span<GradientPair const> gpair,
328-
std::vector<bst_node_t>* p_out_position) const {
329-
partition_builder_.LeafPartition(
330-
ctx, tree, this->Partitions(), p_out_position,
331-
[&](std::size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
330+
common::Span<bst_node_t> out_position) const {
331+
partition_builder_.LeafPartition(ctx, tree, this->Partitions(), out_position,
332+
[&](std::size_t idx) -> bool {
333+
return gpair[idx - this->base_rowid].GetHess() - .0f == .0f;
334+
});
332335
}
333336

334337
private:

src/tree/updater_approx.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,10 @@ class GlobalApproxBuilder {
154154
if (!task_->UpdateTreeLeaf()) {
155155
return;
156156
}
157+
p_out_position->resize(hess.size());
157158
for (auto const &part : partitioner_) {
158-
part.LeafPartition(ctx_, tree, hess, p_out_position);
159+
part.LeafPartition(ctx_, tree, hess,
160+
common::Span{p_out_position->data(), p_out_position->size()});
159161
}
160162
monitor_->Stop(__func__);
161163
}

src/tree/updater_quantile_hist.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class MultiTargetHistBuilder {
126126
std::vector<CommonRowPartitioner> partitioner_;
127127
// Pointer to last updated tree, used for update prediction cache.
128128
RegTree const *p_last_tree_{nullptr};
129-
DMatrix const * p_last_fmat_{nullptr};
129+
DMatrix const *p_last_fmat_{nullptr};
130130

131131
ObjInfo const *task_{nullptr};
132132

@@ -254,8 +254,10 @@ class MultiTargetHistBuilder {
254254
monitor_->Stop(__func__);
255255
return;
256256
}
257+
p_out_position->resize(gpair.Shape(0));
257258
for (auto const &part : partitioner_) {
258-
part.LeafPartition(ctx_, tree, gpair, p_out_position);
259+
part.LeafPartition(ctx_, tree, gpair,
260+
common::Span{p_out_position->data(), p_out_position->size()});
259261
}
260262
monitor_->Stop(__func__);
261263
}
@@ -461,8 +463,10 @@ class HistUpdater {
461463
monitor_->Stop(__func__);
462464
return;
463465
}
466+
p_out_position->resize(gpair.Shape(0));
464467
for (auto const &part : partitioner_) {
465-
part.LeafPartition(ctx_, tree, gpair, p_out_position);
468+
part.LeafPartition(ctx_, tree, gpair,
469+
common::Span{p_out_position->data(), p_out_position->size()});
466470
}
467471
monitor_->Stop(__func__);
468472
}
@@ -521,7 +525,9 @@ class QuantileHistMaker : public TreeUpdater {
521525

522526
linalg::Matrix<GradientPair> sample_out;
523527
auto h_sample_out = h_gpair;
524-
auto need_copy = [&] { return trees.size() > 1 || n_targets > 1; };
528+
auto need_copy = [&] {
529+
return trees.size() > 1 || n_targets > 1;
530+
};
525531
if (need_copy()) {
526532
// allocate buffer
527533
sample_out = decltype(sample_out){h_gpair.Shape(), ctx_->Device(), linalg::Order::kF};

0 commit comments

Comments
 (0)