Skip to content

Commit 2ecc85f

Browse files
authored
[EM] Support ExtMemQdm in the GPU predictor. (dmlc#10694)
1 parent 4370454 commit 2ecc85f

File tree

6 files changed

+124
-129
lines changed

6 files changed

+124
-129
lines changed

include/xgboost/c_api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
494494
* - missing: Which value to represent missing value
495495
* - nthread (optional): Number of threads used for initializing DMatrix.
496496
* - max_bin (optional): Maximum number of bins for building histogram.
497-
* \param out The created Device Quantile DMatrix
497+
* \param out The created Quantile DMatrix.
498498
*
499499
* \return 0 when success, -1 when failure happens
500500
*/

include/xgboost/data.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class MetaInfo {
7272
* if specified, xgboost will start from this init margin
7373
* can be used to specify initial prediction to boost from.
7474
*/
75-
linalg::Tensor<float, 2> base_margin_; // NOLINT
75+
linalg::Matrix<float> base_margin_; // NOLINT
7676
/*!
7777
* \brief lower bound of the label, to be used for survival analysis (censored regression)
7878
*/

include/xgboost/predictor.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2017-2023 by Contributors
2+
* Copyright 2017-2024, XGBoost Contributors
33
* \file predictor.h
44
* \brief Interface of predictor,
55
* performs predictions for a gradient booster.
@@ -15,7 +15,6 @@
1515
#include <functional> // for function
1616
#include <memory> // for shared_ptr
1717
#include <string>
18-
#include <utility> // for make_pair
1918
#include <vector>
2019

2120
// Forward declarations

src/data/ellpack_page.cuh

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,26 @@ struct EllpackDeviceAccessor {
6060
min_fvalue = cuts->min_vals_.ConstHostSpan();
6161
}
6262
}
63-
// Get a matrix element, uses binary search for look up Return NaN if missing
64-
// Given a row index and a feature index, returns the corresponding cut value
65-
[[nodiscard]] __device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
66-
ridx -= base_rowid;
63+
/**
64+
* @brief Given a row index and a feature index, returns the corresponding cut value.
65+
*
66+
* Uses binary search for look up. Returns NaN if missing.
67+
*
68+
* @tparam global_ridx Whether the row index is global to all ellpack batches or it's
69+
* local to the current batch.
70+
*/
71+
template <bool global_ridx = true>
72+
[[nodiscard]] __device__ bst_bin_t GetBinIndex(size_t ridx, size_t fidx) const {
73+
if (global_ridx) {
74+
ridx -= base_rowid;
75+
}
6776
auto row_begin = row_stride * ridx;
6877
auto row_end = row_begin + row_stride;
69-
auto gidx = -1;
78+
bst_bin_t gidx = -1;
7079
if (is_dense) {
7180
gidx = gidx_iter[row_begin + fidx];
7281
} else {
73-
gidx = common::BinarySearchBin(row_begin,
74-
row_end,
75-
gidx_iter,
76-
feature_segments[fidx],
82+
gidx = common::BinarySearchBin(row_begin, row_end, gidx_iter, feature_segments[fidx],
7783
feature_segments[fidx + 1]);
7884
}
7985
return gidx;

0 commit comments

Comments
 (0)