
Commit 8d7fe26

[EM] Enable access to the number of batches. (dmlc#10691)

- Expose `NumBatches` in `DMatrix`.
- Small cleanup for removing legacy CUDA stream and ~force CUDA context initialization~.
- Purge old external memory data generation code.

1 parent 033a666 commit 8d7fe26

26 files changed: 168 additions & 351 deletions

include/xgboost/data.h

Lines changed: 5 additions & 2 deletions

@@ -541,9 +541,12 @@ class DMatrix {
   [[nodiscard]] bool PageExists() const;

   /**
-   * @return Whether the data columns single column block.
+   * @return Whether the contains a single batch.
+   *
+   * The naming is legacy.
    */
-  [[nodiscard]] virtual bool SingleColBlock() const = 0;
+  [[nodiscard]] bool SingleColBlock() const { return this->NumBatches() == 1; }
+  [[nodiscard]] virtual std::int32_t NumBatches() const { return 1; }

   virtual ~DMatrix();
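
To illustrate the new interface outside the library: SingleColBlock() is now a concrete wrapper over the virtual NumBatches(), so a subclass only has to report its batch count. The sketch below uses stand-in class names (FakeDMatrix, FakeExtMemDMatrix); it is not xgboost code, just a minimal model of the relationship.

#include <cstdint>
#include <iostream>

// Minimal stand-ins mirroring the new interface: SingleColBlock() is a thin
// wrapper over the virtual NumBatches(), which defaults to 1.
class FakeDMatrix {
 public:
  [[nodiscard]] bool SingleColBlock() const { return this->NumBatches() == 1; }
  [[nodiscard]] virtual std::int32_t NumBatches() const { return 1; }
  virtual ~FakeDMatrix() = default;
};

// An external-memory-style matrix only needs to report its batch count; the
// legacy SingleColBlock() query keeps working without an extra override.
class FakeExtMemDMatrix : public FakeDMatrix {
  std::int32_t n_batches_;

 public:
  explicit FakeExtMemDMatrix(std::int32_t n_batches) : n_batches_{n_batches} {}
  [[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; }
};

int main() {
  FakeExtMemDMatrix m{4};
  std::cout << m.NumBatches() << " batches, single block: " << m.SingleColBlock() << "\n";
  return 0;  // prints: 4 batches, single block: 0
}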

src/common/device_helpers.cuh

Lines changed: 8 additions & 18 deletions

@@ -486,24 +486,20 @@ class TypedDiscard : public thrust::discard_iterator<T> {
 }  // namespace detail

 template <typename T>
-using TypedDiscard =
-    std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
-                       detail::TypedDiscard<T>>;
+using TypedDiscard = std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
+                                        detail::TypedDiscard<T>>;

 template <typename VectorT, typename T = typename VectorT::value_type,
-          typename IndexT = typename xgboost::common::Span<T>::index_type>
-xgboost::common::Span<T> ToSpan(
-    VectorT &vec,
-    IndexT offset = 0,
-    IndexT size = std::numeric_limits<size_t>::max()) {
+          typename IndexT = typename xgboost::common::Span<T>::index_type>
+xgboost::common::Span<T> ToSpan(VectorT &vec, IndexT offset = 0,
+                                IndexT size = std::numeric_limits<size_t>::max()) {
   size = size == std::numeric_limits<size_t>::max() ? vec.size() : size;
   CHECK_LE(offset + size, vec.size());
-  return {vec.data().get() + offset, size};
+  return {thrust::raw_pointer_cast(vec.data()) + offset, size};
 }

 template <typename T>
-xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
-                                size_t offset, size_t size) {
+xgboost::common::Span<T> ToSpan(thrust::device_vector<T> &vec, size_t offset, size_t size) {
   return ToSpan(vec, offset, size);
 }

@@ -874,13 +870,7 @@ inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT

 // Changing this has effect on prediction return, where we need to pass the pointer to
 // third-party libraries like cuPy
-inline CUDAStreamView DefaultStream() {
-#ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
-  return CUDAStreamView{cudaStreamPerThread};
-#else
-  return CUDAStreamView{cudaStreamLegacy};
-#endif
-}
+inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamPerThread}; }

 class CUDAStream {
   cudaStream_t stream_;
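
A note on the ToSpan change: thrust::raw_pointer_cast(vec.data()) extracts the raw device pointer without assuming the exact pointer type that data() returns, whereas .get() only works on a plain thrust::device_ptr. A minimal stand-alone sketch of the same pattern follows; SimpleSpan and ToSpanSketch are illustrative stand-ins, not the library's types.

// sketch.cu — compile with nvcc; demonstrates the raw_pointer_cast pattern.
#include <thrust/device_vector.h>
#include <thrust/sequence.h>

#include <cstddef>
#include <cstdio>

template <typename T>
struct SimpleSpan {  // stand-in for xgboost::common::Span
  T* data;
  std::size_t size;
};

template <typename VectorT, typename T = typename VectorT::value_type>
SimpleSpan<T> ToSpanSketch(VectorT& vec, std::size_t offset, std::size_t size) {
  // raw_pointer_cast yields T* for any device_vector, regardless of allocator.
  return {thrust::raw_pointer_cast(vec.data()) + offset, size};
}

int main() {
  thrust::device_vector<float> vec(8);
  thrust::sequence(vec.begin(), vec.end());  // 0, 1, 2, ...
  auto span = ToSpanSketch(vec, 2, 4);       // view over elements 2..5
  std::printf("span of %zu elements starting at %p\n", span.size,
              static_cast<void*>(span.data));
  return 0;
}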

src/data/extmem_quantile_dmatrix.cc

Lines changed: 2 additions & 0 deletions

@@ -74,6 +74,8 @@ void ExtMemQuantileDMatrix::InitFromCPU(
   cpu_impl::GetDataShape(ctx, proxy, *iter, missing, &ext_info);
   ext_info.SetInfo(ctx, &this->info_);

+  this->n_batches_ = ext_info.n_batches;
+
   /**
    * Generate quantiles
    */

src/data/extmem_quantile_dmatrix.h

Lines changed: 2 additions & 1 deletion

@@ -33,7 +33,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
                          std::string cache, bst_bin_t max_bin, bool on_host);
   ~ExtMemQuantileDMatrix() override;

-  [[nodiscard]] bool SingleColBlock() const override { return false; }
+  [[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; }

  private:
   void InitFromCPU(
@@ -63,6 +63,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
   std::string cache_prefix_;
   bool on_host_;
   BatchParam batch_;
+  bst_idx_t n_batches_{0};

   using EllpackDiskPtr = std::shared_ptr<ExtEllpackPageSource>;
   using EllpackHostPtr = std::shared_ptr<ExtEllpackPageHostSource>;

src/data/iterative_dmatrix.h

Lines changed: 0 additions & 2 deletions

@@ -57,8 +57,6 @@ class IterativeDMatrix : public QuantileDMatrix {

   BatchSet<EllpackPage> GetEllpackBatches(Context const *ctx, const BatchParam &param) override;
   BatchSet<ExtSparsePage> GetExtBatches(Context const *ctx, BatchParam const &param) override;
-
-  bool SingleColBlock() const override { return true; }
 };
 }  // namespace data
 }  // namespace xgboost

src/data/proxy_dmatrix.h

Lines changed: 0 additions & 1 deletion

@@ -94,7 +94,6 @@ class DMatrixProxy : public DMatrix {
   MetaInfo const& Info() const override { return info_; }
   Context const* Ctx() const override { return &ctx_; }

-  bool SingleColBlock() const override { return false; }
   bool EllpackExists() const override { return false; }
   bool GHistIndexExists() const override { return false; }
   bool SparsePageExists() const override { return false; }

src/data/simple_dmatrix.h

Lines changed: 0 additions & 1 deletion

@@ -33,7 +33,6 @@ class SimpleDMatrix : public DMatrix {
   const MetaInfo& Info() const override;
   Context const* Ctx() const override { return &fmat_ctx_; }

-  bool SingleColBlock() const override { return true; }
   DMatrix* Slice(common::Span<int32_t const> ridxs) override;
   DMatrix* SliceCol(int num_slices, int slice_id) override;


src/data/sparse_page_dmatrix.h

Lines changed: 1 addition & 2 deletions

@@ -90,8 +90,7 @@ class SparsePageDMatrix : public DMatrix {
   [[nodiscard]] MetaInfo &Info() override;
   [[nodiscard]] const MetaInfo &Info() const override;
   [[nodiscard]] Context const *Ctx() const override { return &fmat_ctx_; }
-  // The only DMatrix implementation that returns false.
-  [[nodiscard]] bool SingleColBlock() const override { return false; }
+  [[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; }
   DMatrix *Slice(common::Span<std::int32_t const>) override {
     LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
     return nullptr;

src/data/sparse_page_source.cc

Lines changed: 6 additions & 2 deletions

@@ -3,10 +3,10 @@
  */
 #include "sparse_page_source.h"

-#include <filesystem>  // for exists
-#include <string>      // for string
 #include <cstdio>      // for remove
+#include <filesystem>  // for exists
 #include <numeric>     // for partial_sum
+#include <string>      // for string

 namespace xgboost::data {
 void Cache::Commit() {
@@ -27,4 +27,8 @@ void TryDeleteCacheFile(const std::string& file) {
                << "; you may want to remove it manually";
   }
 }
+
+#if !defined(XGBOOST_USE_CUDA)
+void InitNewThread::operator()() const { *GlobalConfigThreadLocalStore::Get() = config; }
+#endif
 }  // namespace xgboost::data
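
The CPU-only definition above copies the captured global configuration into the thread-local store of whichever thread invokes the functor, so worker threads see the settings chosen on the launching thread. A self-contained sketch of that pattern, using hypothetical stand-ins (GlobalConfig, ThreadLocalConfig, InitNewThreadSketch) rather than xgboost's types:

#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

struct GlobalConfig {
  std::int32_t verbosity{1};
};

// Stand-in for GlobalConfigThreadLocalStore: one instance per thread.
GlobalConfig* ThreadLocalConfig() {
  thread_local GlobalConfig cfg;
  return &cfg;
}

// Mirrors the shape of InitNewThread: it captures a copy of the config and,
// when invoked at thread start, publishes it into that thread's storage.
struct InitNewThreadSketch {
  GlobalConfig config;  // copied on the launching thread
  void operator()() const { *ThreadLocalConfig() = config; }
};

int main() {
  ThreadLocalConfig()->verbosity = 3;              // configured on the main thread
  InitNewThreadSketch init{*ThreadLocalConfig()};  // snapshot for workers

  std::vector<std::thread> workers;
  for (int i = 0; i < 2; ++i) {
    workers.emplace_back([init] {
      init();  // what a thread pool would call first in each new thread
      std::cout << "worker verbosity: " << ThreadLocalConfig()->verbosity << "\n";
    });
  }
  for (auto& t : workers) { t.join(); }
  return 0;
}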

src/data/sparse_page_source.cu

Lines changed: 10 additions & 0 deletions

@@ -18,4 +18,14 @@ void DevicePush(DMatrixProxy *proxy, float missing, SparsePage *page) {
   cuda_impl::Dispatch(proxy,
                       [&](auto const &value) { CopyToSparsePage(value, device, missing, page); });
 }
+
+void InitNewThread::operator()() const {
+  *GlobalConfigThreadLocalStore::Get() = config;
+  // For CUDA 12.2, we need to force initialize the CUDA context by synchronizing the
+  // stream when creating a new thread in the thread pool. While for CUDA 11.8, this
+  // action might cause an insufficient driver version error for some reason. Lastly, it
+  // should work with CUDA 12.5 without any action being taken.
+
+  // dh::DefaultStream().Sync();
+}
 }  // namespace xgboost::data
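
Note that the CUDA build keeps the stream synchronization commented out, so this commit does not force-initialize any context. If such initialization ever became necessary, one hypothetical way to gate it on the runtime version is sketched below; MaybeInitCudaContext is illustrative only and not part of the commit.

// sketch.cu — a hypothetical, version-gated context initialization helper.
#include <cuda_runtime_api.h>

#include <cstdio>
#include <thread>

void MaybeInitCudaContext() {
  int runtime_version = 0;
  if (cudaRuntimeGetVersion(&runtime_version) != cudaSuccess) {
    return;  // no usable runtime, nothing to initialize
  }
  // 12020 == CUDA 12.2; per the comment in the diff, 11.8 and 12.5 reportedly
  // do not need (or do not tolerate) the explicit synchronization.
  if (runtime_version == 12020) {
    cudaStreamSynchronize(cudaStreamPerThread);  // touching a stream creates the context
  }
}

int main() {
  std::thread worker([] {
    MaybeInitCudaContext();
    std::printf("worker thread has a CUDA context (if a device is present)\n");
  });
  worker.join();
  return 0;
}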
