Commit 7510a87

[EM] Reuse the quantile container. (dmlc#10761)
Use the push method to merge the quantiles instead of creating multiple containers. This reduces memory usage because the single container is pruned consistently as each batch is merged.
1 parent 4fe67f1 commit 7510a87
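
The pattern in a nutshell: instead of building one sketch container per batch and merging them all after the data iterator is exhausted, every batch is pushed into a single container that prunes itself back to its size budget as it goes, so peak memory is bounded by one summary rather than by the number of batches. The example below is an illustrative stand-in only; BoundedSketch, Push, and Prune are made-up names with a naive evenly-spaced pruning rule, not XGBoost's weighted common::SketchContainer.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <random>
#include <vector>

// Hypothetical stand-in for common::SketchContainer: a quantile summary that
// prunes itself back to `limit_` entries whenever it grows past the budget,
// so its memory stays O(limit) no matter how many batches are pushed into it.
class BoundedSketch {
 public:
  explicit BoundedSketch(std::size_t limit) : limit_{limit} {}

  // Merge one batch into the running summary, then prune back to the budget.
  void Push(std::vector<float> const& batch) {
    values_.insert(values_.end(), batch.begin(), batch.end());
    std::sort(values_.begin(), values_.end());
    if (values_.size() > limit_) {
      Prune();
    }
  }

  float Quantile(double q) const {
    return values_[static_cast<std::size_t>(q * (values_.size() - 1))];
  }

 private:
  // Naive pruning: keep `limit_` evenly spaced order statistics.  The real
  // container keeps weighted, GK-style summaries with error bounds instead.
  void Prune() {
    std::vector<float> pruned(limit_);
    for (std::size_t i = 0; i < limit_; ++i) {
      pruned[i] = values_[i * (values_.size() - 1) / (limit_ - 1)];
    }
    values_ = std::move(pruned);
  }

  std::size_t limit_;
  std::vector<float> values_;
};

int main() {
  std::mt19937 rng{0};
  std::normal_distribution<float> dist;

  // One container reused across all batches (the pattern this commit adopts),
  // rather than one container per batch merged at the end.
  BoundedSketch sketch{256};
  for (int batch = 0; batch < 16; ++batch) {
    std::vector<float> data(10000);
    for (auto& v : data) {
      v = dist(rng);
    }
    sketch.Push(data);  // merge + prune: memory stays bounded per batch
  }
  std::cout << "approximate median: " << sketch.Quantile(0.5) << "\n";
  return 0;
}
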

File tree

3 files changed: +18 -21 lines

  src/common/algorithm.cuh
  src/common/quantile.cu
  src/data/quantile_dmatrix.cu

src/common/algorithm.cuh

Lines changed: 2 additions & 1 deletion

@@ -17,7 +17,8 @@
 
 #include "common.h"            // safe_cuda
 #include "cuda_context.cuh"    // CUDAContext
-#include "device_helpers.cuh"  // TemporaryArray,SegmentId,LaunchN,Iota,device_vector
+#include "device_helpers.cuh"  // TemporaryArray,SegmentId,LaunchN,Iota
+#include "device_vector.cuh"   // for device_vector
 #include "xgboost/base.h"      // XGBOOST_DEVICE
 #include "xgboost/context.h"   // Context
 #include "xgboost/logging.h"   // CHECK

src/common/quantile.cu

Lines changed: 2 additions & 1 deletion

@@ -182,7 +182,8 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
       merge_path.data(), [=] XGBOOST_DEVICE(Tuple const &t) -> Tuple {
         auto ind = get_ind(t);  // == 0 if element is from x
         // x_counter, y_counter
-        return thrust::tuple<std::uint64_t, std::uint64_t>{!ind, ind};
+        return thrust::make_tuple(static_cast<std::uint64_t>(!ind),
+                                  static_cast<std::uint64_t>(ind));
       });
 
   // Compute the index for both x and y (which of the element in a and b are used in each

src/data/quantile_dmatrix.cu

Lines changed: 14 additions & 19 deletions

@@ -6,7 +6,9 @@
 #include <vector>  // for vector
 
 #include "../collective/allreduce.h"    // for Allreduce
+#include "../common/cuda_context.cuh"   // for CUDAContext
 #include "../common/cuda_rt_utils.h"    // for AllVisibleGPUs
+#include "../common/cuda_rt_utils.h"    // for xgboost_NVTX_FN_RANGE
 #include "../common/device_vector.cuh"  // for XGBCachingDeviceAllocator
 #include "../common/hist_util.cuh"      // for AdapterDeviceSketch
 #include "../common/quantile.cuh"       // for SketchContainer
@@ -26,8 +28,10 @@ void MakeSketches(Context const* ctx,
                   DMatrixProxy* proxy, std::shared_ptr<DMatrix> ref, BatchParam const& p,
                   float missing, std::shared_ptr<common::HistogramCuts> cuts, MetaInfo const& info,
                   ExternalDataInfo* p_ext_info) {
-  dh::XGBCachingDeviceAllocator<char> alloc;
-  std::vector<common::SketchContainer> sketch_containers;
+  xgboost_NVTX_FN_RANGE();
+
+  CUDAContext const* cuctx = ctx->CUDACtx();
+  std::unique_ptr<common::SketchContainer> sketch;
   auto& ext_info = *p_ext_info;
 
   do {
@@ -44,12 +48,14 @@ void MakeSketches(Context const* ctx,
           << "Inconsistent number of columns.";
     }
     if (!ref) {
-      sketch_containers.emplace_back(proxy->Info().feature_types, p.max_bin, ext_info.n_features,
-                                     data::BatchSamples(proxy), dh::GetDevice(ctx));
-      auto* p_sketch = &sketch_containers.back();
+      if (!sketch) {
+        sketch = std::make_unique<common::SketchContainer>(
+            proxy->Info().feature_types, p.max_bin, ext_info.n_features, data::BatchSamples(proxy),
+            dh::GetDevice(ctx));
+      }
       proxy->Info().weights_.SetDevice(dh::GetDevice(ctx));
       cuda_impl::Dispatch(proxy, [&](auto const& value) {
-        common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, p_sketch);
+        common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, sketch.get());
       });
     }
     auto batch_rows = data::BatchSamples(proxy);
@@ -60,7 +66,7 @@ void MakeSketches(Context const* ctx,
         std::max(ext_info.row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) {
                    return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing);
                  }));
-    ext_info.nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), row_counts.end());
+    ext_info.nnz += thrust::reduce(cuctx->CTP(), row_counts.begin(), row_counts.end());
     ext_info.n_batches++;
     ext_info.base_rows.push_back(batch_rows);
   } while (iter->Next());
@@ -73,18 +79,7 @@ void MakeSketches(Context const* ctx,
   // Get reference
   dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal));
   if (!ref) {
-    HostDeviceVector<FeatureType> ft;
-    common::SketchContainer final_sketch(
-        sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(), p.max_bin,
-        ext_info.n_features, ext_info.accumulated_rows, dh::GetDevice(ctx));
-    for (auto const& sketch : sketch_containers) {
-      final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
-      final_sketch.FixError();
-    }
-    sketch_containers.clear();
-    sketch_containers.shrink_to_fit();
-
-    final_sketch.MakeCuts(ctx, cuts.get(), info.IsColumnSplit());
+    sketch->MakeCuts(ctx, cuts.get(), info.IsColumnSplit());
   } else {
     GetCutsFromRef(ctx, ref, ext_info.n_features, p, cuts.get());
   }
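
A side note on the nnz accumulation above: the reduction now runs on the execution policy owned by the CUDAContext (cuctx->CTP()) instead of a locally constructed caching allocator. A stand-alone equivalent with plain Thrust looks roughly like the sketch below; the stream, vector size, and fill value are invented for illustration, and thrust::cuda::par.on(stream) stands in for XGBoost's own policy.

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>

#include <cuda_runtime.h>

#include <cstdint>
#include <iostream>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Fake per-row non-zero counts; in MakeSketches these come from GetRowCounts.
  thrust::device_vector<std::uint64_t> row_counts(1024, 3);

  // Sum them using an execution policy bound to an explicit stream, analogous to
  //   ext_info.nnz += thrust::reduce(cuctx->CTP(), row_counts.begin(), row_counts.end());
  std::uint64_t nnz = thrust::reduce(thrust::cuda::par.on(stream), row_counts.begin(),
                                     row_counts.end(), std::uint64_t{0});

  cudaStreamSynchronize(stream);
  std::cout << "nnz = " << nnz << std::endl;

  cudaStreamDestroy(stream);
  return 0;
}
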
