Skip to content

Commit 34937fe

Browse files
authored
[EM] Python wrapper for the ExtMemQuantileDMatrix. (dmlc#10762)
Not exposed to the document yet. - Add C API. - Add Python API. - Basic CPU tests.
1 parent 7510a87 commit 34937fe

File tree

7 files changed

+208
-27
lines changed

7 files changed

+208
-27
lines changed

include/xgboost/c_api.h

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -472,37 +472,66 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
472472
* @example external_memory.c
473473
*/
474474

475-
/*!
476-
* \brief Create a Quantile DMatrix with data iterator.
475+
/**
476+
* @brief Create a Quantile DMatrix with data iterator.
477477
*
478478
* Short note for how to use the second set of callback for (GPU)Hist tree method:
479479
*
480480
* - Step 0: Define a data iterator with 2 methods `reset`, and `next`.
481-
* - Step 1: Create a DMatrix proxy by \ref XGProxyDMatrixCreate and hold the handle.
481+
* - Step 1: Create a DMatrix proxy by @ref XGProxyDMatrixCreate and hold the handle.
482482
* - Step 2: Pass the iterator handle, proxy handle and 2 methods into
483483
* `XGQuantileDMatrixCreateFromCallback`.
484484
* - Step 3: Call appropriate data setters in `next` functions.
485485
*
486486
* See test_iterative_dmatrix.cu or Python interface for examples.
487487
*
488-
* \param iter A handle to external data iterator.
489-
* \param proxy A DMatrix proxy handle created by \ref XGProxyDMatrixCreate.
490-
* \param ref Reference DMatrix for providing quantile information.
491-
* \param reset Callback function resetting the iterator state.
492-
* \param next Callback function yielding the next batch of data.
493-
* \param config JSON encoded parameters for DMatrix construction. Accepted fields are:
488+
* @param iter A handle to external data iterator.
489+
* @param proxy A DMatrix proxy handle created by @ref XGProxyDMatrixCreate.
490+
* @param ref Reference DMatrix for providing quantile information.
491+
* @param reset Callback function resetting the iterator state.
492+
* @param next Callback function yielding the next batch of data.
493+
* @param config JSON encoded parameters for DMatrix construction. Accepted fields are:
494494
* - missing: Which value to represent missing value
495495
* - nthread (optional): Number of threads used for initializing DMatrix.
496-
* - max_bin (optional): Maximum number of bins for building histogram.
497-
* \param out The created Quantile DMatrix.
496+
* - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with
497+
the corresponding booster training parameter.
498+
* @param out The created Quantile DMatrix.
498499
*
499-
* \return 0 when success, -1 when failure happens
500+
* @return 0 when success, -1 when failure happens
500501
*/
501502
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
502503
DataIterHandle ref, DataIterResetCallback *reset,
503504
XGDMatrixCallbackNext *next, char const *config,
504505
DMatrixHandle *out);
505506

507+
/**
508+
* @brief Create a Quantile DMatrix backed by external memory.
509+
*
510+
* @since 3.0.0
511+
*
512+
* @note This is still under development, not ready for test yet.
513+
*
514+
* @param iter A handle to external data iterator.
515+
* @param proxy A DMatrix proxy handle created by @ref XGProxyDMatrixCreate.
516+
* @param ref Reference DMatrix for providing quantile information.
517+
* @param reset Callback function resetting the iterator state.
518+
* @param next Callback function yielding the next batch of data.
519+
* @param config JSON encoded parameters for DMatrix construction. Accepted fields are:
520+
* - missing: Which value to represent missing value
521+
* - cache_prefix: The path of cache file, caller must initialize all the directories in this path.
522+
* - nthread (optional): Number of threads used for initializing DMatrix.
523+
* - max_bin (optional): Maximum number of bins for building histogram. Must be consistent with
524+
the corresponding booster training parameter.
525+
* @param out The created Quantile DMatrix.
526+
*
527+
* @return 0 when success, -1 when failure happens
528+
*/
529+
XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
530+
DataIterHandle ref,
531+
DataIterResetCallback *reset,
532+
XGDMatrixCallbackNext *next,
533+
char const *config, DMatrixHandle *out);
534+
506535
/*!
507536
* \brief Create a Device Quantile DMatrix with data iterator.
508537
* \deprecated since 1.7.0

python-package/xgboost/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,15 @@
55

66
from . import tracker # noqa
77
from . import collective, dask
8-
from .core import Booster, DataIter, DMatrix, QuantileDMatrix, _py_version, build_info
8+
from .core import (
9+
Booster,
10+
DataIter,
11+
DMatrix,
12+
ExtMemQuantileDMatrix,
13+
QuantileDMatrix,
14+
_py_version,
15+
build_info,
16+
)
917
from .tracker import RabitTracker # noqa
1018
from .training import cv, train
1119

@@ -31,6 +39,7 @@
3139
# core
3240
"DMatrix",
3341
"QuantileDMatrix",
42+
"ExtMemQuantileDMatrix",
3443
"Booster",
3544
"DataIter",
3645
"train",

python-package/xgboost/core.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,13 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes
526526
on_host :
527527
Whether the data should be cached on host memory instead of harddrive when using
528528
GPU with external memory. If set to true, then the "external memory" would
529-
simply be CPU (host) memory. This is still working in progress, not ready for
530-
test yet.
529+
simply be CPU (host) memory.
530+
531+
.. versionadded:: 3.0.0
532+
533+
.. warning::
534+
535+
This is still working in progress, not ready for test yet.
531536
532537
"""
533538

@@ -927,8 +932,7 @@ def __init__(
927932
if feature_types is not None:
928933
self.feature_types = feature_types
929934

930-
def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None:
931-
it = iterator
935+
def _init_from_iter(self, it: DataIter, enable_categorical: bool) -> None:
932936
args = make_jcargs(
933937
missing=self.missing,
934938
nthread=self.nthread,
@@ -1673,6 +1677,63 @@ def _init(
16731677
self.handle = handle
16741678

16751679

1680+
class ExtMemQuantileDMatrix(DMatrix):
1681+
"""The external memory version of the :py:class:`QuantileDMatrix`.
1682+
1683+
.. warning::
1684+
1685+
This is still working in progress, not ready for test yet.
1686+
1687+
.. versionadded:: 3.0.0
1688+
1689+
"""
1690+
1691+
@_deprecate_positional_args
1692+
def __init__( # pylint: disable=super-init-not-called
1693+
self,
1694+
data: DataIter,
1695+
missing: Optional[float] = None,
1696+
nthread: Optional[int] = None,
1697+
max_bin: Optional[int] = None,
1698+
ref: Optional[DMatrix] = None,
1699+
enable_categorical: bool = False,
1700+
) -> None:
1701+
self.max_bin = max_bin
1702+
self.missing = missing if missing is not None else np.nan
1703+
self.nthread = nthread if nthread is not None else -1
1704+
1705+
self._init(data, ref, enable_categorical)
1706+
assert self.handle is not None
1707+
1708+
def _init(
1709+
self, it: DataIter, ref: Optional[DMatrix], enable_categorical: bool
1710+
) -> None:
1711+
args = make_jcargs(
1712+
missing=self.missing,
1713+
nthread=self.nthread,
1714+
cache_prefix=it.cache_prefix if it.cache_prefix else "",
1715+
on_host=it.on_host,
1716+
)
1717+
handle = ctypes.c_void_p()
1718+
reset_callback, next_callback = it.get_callbacks(enable_categorical)
1719+
# We don't need the iter handle (hence None) in Python as reset,next callbacks
1720+
# are member functions, and ctypes can handle the `self` parameter
1721+
# automatically.
1722+
ret = _LIB.XGExtMemQuantileDMatrixCreateFromCallback(
1723+
None, # iter
1724+
it.proxy.handle, # proxy
1725+
ref.handle if ref is not None else ref, # ref
1726+
reset_callback, # reset
1727+
next_callback, # next
1728+
args, # config
1729+
ctypes.byref(handle), # out
1730+
)
1731+
it.reraise()
1732+
# delay check_call to throw intermediate exception first
1733+
_check_call(ret)
1734+
self.handle = handle
1735+
1736+
16761737
Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
16771738
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]
16781739

python-package/xgboost/testing/updater.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any, Dict, List
66

77
import numpy as np
8+
import pytest
89

910
import xgboost as xgb
1011
import xgboost.testing as tm
@@ -194,6 +195,43 @@ def check_quantile_loss_extmem(
194195
np.testing.assert_allclose(predt, predt_it)
195196

196197

198+
def check_extmem_qdm(
199+
n_samples_per_batch: int,
200+
n_features: int,
201+
n_batches: int,
202+
device: str,
203+
on_host: bool,
204+
) -> None:
205+
"""Basic test for the `ExtMemQuantileDMatrix`."""
206+
207+
it = tm.IteratorForTest(
208+
*tm.make_batches(
209+
n_samples_per_batch, n_features, n_batches, use_cupy=device != "cpu"
210+
),
211+
cache="cache",
212+
on_host=on_host,
213+
)
214+
Xy_it = xgb.ExtMemQuantileDMatrix(it)
215+
with pytest.raises(ValueError, match="Only the `hist`"):
216+
booster_it = xgb.train(
217+
{"device": device, "tree_method": "approx"}, Xy_it, num_boost_round=8
218+
)
219+
220+
booster_it = xgb.train({"device": device}, Xy_it, num_boost_round=8)
221+
X, y, w = it.as_arrays()
222+
Xy = xgb.QuantileDMatrix(X, y, weight=w)
223+
booster = xgb.train({"device": device}, Xy, num_boost_round=8)
224+
225+
cut_it = Xy_it.get_quantile_cut()
226+
cut = Xy.get_quantile_cut()
227+
np.testing.assert_allclose(cut_it[0], cut[0])
228+
np.testing.assert_allclose(cut_it[1], cut[1])
229+
230+
predt_it = booster_it.predict(Xy_it)
231+
predt = booster.predict(Xy)
232+
np.testing.assert_allclose(predt_it, predt)
233+
234+
197235
def check_cut(
198236
n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any
199237
) -> None:

src/c_api/c_api.cc

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,8 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
296296
auto jconfig = Json::Load(StringView{config});
297297
auto missing = GetMissing(jconfig);
298298
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
299-
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", 0);
300-
auto on_host = OptionalArg<Boolean, bool>(jconfig, "on_host", false);
299+
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
300+
auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
301301

302302
xgboost_CHECK_C_ARG_PTR(next);
303303
xgboost_CHECK_C_ARG_PTR(reset);
@@ -308,6 +308,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
308308
API_END();
309309
}
310310

311+
311312
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
312313
DataIterResetCallback *reset,
313314
XGDMatrixCallbackNext *next, float missing,
@@ -320,11 +321,8 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr
320321
API_END();
321322
}
322323

323-
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
324-
DataIterHandle ref, DataIterResetCallback *reset,
325-
XGDMatrixCallbackNext *next, char const *config,
326-
DMatrixHandle *out) {
327-
API_BEGIN();
324+
namespace {
325+
std::shared_ptr<DMatrix> GetRefDMatrix(DataIterHandle ref) {
328326
std::shared_ptr<DMatrix> _ref{nullptr};
329327
if (ref) {
330328
auto pp_ref = static_cast<std::shared_ptr<xgboost::DMatrix> *>(ref);
@@ -333,6 +331,16 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
333331
_ref = *pp_ref;
334332
CHECK(_ref) << err;
335333
}
334+
return _ref;
335+
}
336+
} // namespace
337+
338+
XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
339+
DataIterHandle ref, DataIterResetCallback *reset,
340+
XGDMatrixCallbackNext *next, char const *config,
341+
DMatrixHandle *out) {
342+
API_BEGIN();
343+
std::shared_ptr<DMatrix> p_ref{GetRefDMatrix(ref)};
336344

337345
xgboost_CHECK_C_ARG_PTR(config);
338346
auto jconfig = Json::Load(StringView{config});
@@ -345,7 +353,32 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
345353
xgboost_CHECK_C_ARG_PTR(out);
346354

347355
*out = new std::shared_ptr<xgboost::DMatrix>{
348-
xgboost::DMatrix::Create(iter, proxy, _ref, reset, next, missing, n_threads, max_bin)};
356+
xgboost::DMatrix::Create(iter, proxy, p_ref, reset, next, missing, n_threads, max_bin)};
357+
API_END();
358+
}
359+
360+
XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
361+
DataIterHandle ref,
362+
DataIterResetCallback *reset,
363+
XGDMatrixCallbackNext *next,
364+
char const *config, DMatrixHandle *out) {
365+
API_BEGIN();
366+
std::shared_ptr<DMatrix> p_ref{GetRefDMatrix(ref)};
367+
368+
xgboost_CHECK_C_ARG_PTR(config);
369+
auto jconfig = Json::Load(StringView{config});
370+
auto missing = GetMissing(jconfig);
371+
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
372+
auto max_bin = OptionalArg<Integer, std::int64_t>(jconfig, "max_bin", 256);
373+
auto on_host = OptionalArg<Boolean>(jconfig, "on_host", false);
374+
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
375+
376+
xgboost_CHECK_C_ARG_PTR(next);
377+
xgboost_CHECK_C_ARG_PTR(reset);
378+
xgboost_CHECK_C_ARG_PTR(out);
379+
380+
*out = new std::shared_ptr<xgboost::DMatrix>{xgboost::DMatrix::Create(
381+
iter, proxy, p_ref, reset, next, missing, n_threads, max_bin, cache, on_host)};
349382
API_END();
350383
}
351384

src/data/batch_utils.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
namespace xgboost::data::detail {
99
void CheckParam(BatchParam const& init, BatchParam const& param) {
1010
CHECK_EQ(param.max_bin, init.max_bin) << error::InconsistentMaxBin();
11-
CHECK(!param.regen && param.hess.empty()) << "Only `hist` tree method can use `QuantileDMatrix`.";
11+
CHECK(!param.regen && param.hess.empty())
12+
<< "Only the `hist` tree method can use the `QuantileDMatrix`.";
1213
}
1314
} // namespace xgboost::data::detail

tests/python/test_data_iterator.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from xgboost import testing as tm
1313
from xgboost.data import SingleBatchInternalIter as SingleBatch
1414
from xgboost.testing import IteratorForTest, make_batches, non_increasing
15-
from xgboost.testing.updater import check_quantile_loss_extmem
15+
from xgboost.testing.updater import check_extmem_qdm, check_quantile_loss_extmem
1616

1717
pytestmark = tm.timeout(30)
1818

@@ -304,3 +304,13 @@ def test_quantile_objective(
304304
"approx",
305305
"cpu",
306306
)
307+
308+
309+
@given(
310+
strategies.integers(1, 4096),
311+
strategies.integers(1, 8),
312+
strategies.integers(1, 4),
313+
)
314+
@settings(deadline=None, max_examples=10, print_blob=True)
315+
def test_extmem_qdm(n_samples_per_batch: int, n_features: int, n_batches: int) -> None:
316+
check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cpu", False)

0 commit comments

Comments
 (0)