6
6
#include < vector> // for vector
7
7
8
8
#include " ../collective/allreduce.h" // for Allreduce
9
+ #include " ../common/cuda_context.cuh" // for CUDAContext
9
10
#include " ../common/cuda_rt_utils.h" // for AllVisibleGPUs
11
+ #include " ../common/cuda_rt_utils.h" // for xgboost_NVTX_FN_RANGE
10
12
#include " ../common/device_vector.cuh" // for XGBCachingDeviceAllocator
11
13
#include " ../common/hist_util.cuh" // for AdapterDeviceSketch
12
14
#include " ../common/quantile.cuh" // for SketchContainer
@@ -26,8 +28,10 @@ void MakeSketches(Context const* ctx,
26
28
DMatrixProxy* proxy, std::shared_ptr<DMatrix> ref, BatchParam const & p,
27
29
float missing, std::shared_ptr<common::HistogramCuts> cuts, MetaInfo const & info,
28
30
ExternalDataInfo* p_ext_info) {
29
- dh::XGBCachingDeviceAllocator<char > alloc;
30
- std::vector<common::SketchContainer> sketch_containers;
31
+ xgboost_NVTX_FN_RANGE ();
32
+
33
+ CUDAContext const * cuctx = ctx->CUDACtx ();
34
+ std::unique_ptr<common::SketchContainer> sketch;
31
35
auto & ext_info = *p_ext_info;
32
36
33
37
do {
@@ -44,12 +48,14 @@ void MakeSketches(Context const* ctx,
44
48
<< " Inconsistent number of columns." ;
45
49
}
46
50
if (!ref) {
47
- sketch_containers.emplace_back (proxy->Info ().feature_types , p.max_bin , ext_info.n_features ,
48
- data::BatchSamples (proxy), dh::GetDevice (ctx));
49
- auto * p_sketch = &sketch_containers.back ();
51
+ if (!sketch) {
52
+ sketch = std::make_unique<common::SketchContainer>(
53
+ proxy->Info ().feature_types , p.max_bin , ext_info.n_features , data::BatchSamples (proxy),
54
+ dh::GetDevice (ctx));
55
+ }
50
56
proxy->Info ().weights_ .SetDevice (dh::GetDevice (ctx));
51
57
cuda_impl::Dispatch (proxy, [&](auto const & value) {
52
- common::AdapterDeviceSketch (value, p.max_bin , proxy->Info (), missing, p_sketch );
58
+ common::AdapterDeviceSketch (value, p.max_bin , proxy->Info (), missing, sketch. get () );
53
59
});
54
60
}
55
61
auto batch_rows = data::BatchSamples (proxy);
@@ -60,7 +66,7 @@ void MakeSketches(Context const* ctx,
60
66
std::max (ext_info.row_stride , cuda_impl::Dispatch (proxy, [=](auto const & value) {
61
67
return GetRowCounts (value, row_counts_span, dh::GetDevice (ctx), missing);
62
68
}));
63
- ext_info.nnz += thrust::reduce (thrust::cuda::par (alloc ), row_counts.begin (), row_counts.end ());
69
+ ext_info.nnz += thrust::reduce (cuctx-> CTP ( ), row_counts.begin (), row_counts.end ());
64
70
ext_info.n_batches ++;
65
71
ext_info.base_rows .push_back (batch_rows);
66
72
} while (iter->Next ());
@@ -73,18 +79,7 @@ void MakeSketches(Context const* ctx,
73
79
// Get reference
74
80
dh::safe_cuda (cudaSetDevice (dh::GetDevice (ctx).ordinal ));
75
81
if (!ref) {
76
- HostDeviceVector<FeatureType> ft;
77
- common::SketchContainer final_sketch (
78
- sketch_containers.empty () ? ft : sketch_containers.front ().FeatureTypes (), p.max_bin ,
79
- ext_info.n_features , ext_info.accumulated_rows , dh::GetDevice (ctx));
80
- for (auto const & sketch : sketch_containers) {
81
- final_sketch.Merge (sketch.ColumnsPtr (), sketch.Data ());
82
- final_sketch.FixError ();
83
- }
84
- sketch_containers.clear ();
85
- sketch_containers.shrink_to_fit ();
86
-
87
- final_sketch.MakeCuts (ctx, cuts.get (), info.IsColumnSplit ());
82
+ sketch->MakeCuts (ctx, cuts.get (), info.IsColumnSplit ());
88
83
} else {
89
84
GetCutsFromRef (ctx, ref, ext_info.n_features , p, cuts.get ());
90
85
}
0 commit comments