Skip to content

Commit 2258bc8

Browse files
authored
Add more tests and doc for QDM. (dmlc#10692)
1 parent 582ea10 commit 2258bc8

File tree

7 files changed

+61
-4
lines changed

7 files changed

+61
-4
lines changed

python-package/xgboost/core.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,6 +1522,20 @@ class QuantileDMatrix(DMatrix):
15221522
15231523
.. versionadded:: 1.7.0
15241524
1525+
Examples
1526+
--------
1527+
1528+
.. code-block::
1529+
1530+
from sklearn.datasets import make_regression
1531+
from sklearn.model_selection import train_test_split
1532+
1533+
X, y = make_regression()
1534+
X_train, X_test, y_train, y_test = train_test_split(X, y)
1535+
Xy_train = xgb.QuantileDMatrix(X_train, y_train)
1536+
# It's necessary to have the training DMatrix as a reference for valid quantiles.
1537+
Xy_test = xgb.QuantileDMatrix(X_test, y_test, ref=Xy_train)
1538+
15251539
Parameters
15261540
----------
15271541
max_bin :
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""QuantileDMatrix related tests."""
2+
3+
import numpy as np
4+
from sklearn.model_selection import train_test_split
5+
6+
import xgboost as xgb
7+
8+
from .data import make_batches
9+
10+
11+
def check_ref_quantile_cut(device: str) -> None:
12+
"""Check obtaining the same cut values given a reference."""
13+
X, y, _ = (
14+
data[0]
15+
for data in make_batches(
16+
n_samples_per_batch=8192,
17+
n_features=16,
18+
n_batches=1,
19+
use_cupy=device.startswith("cuda"),
20+
)
21+
)
22+
23+
X_train, X_valid, y_train, y_valid = train_test_split(X, y)
24+
Xy_train = xgb.QuantileDMatrix(X_train, y_train)
25+
Xy_valid = xgb.QuantileDMatrix(X_valid, y_valid, ref=Xy_train)
26+
27+
cut_train = Xy_train.get_quantile_cut()
28+
cut_valid = Xy_valid.get_quantile_cut()
29+
30+
np.testing.assert_allclose(cut_train[0], cut_valid[0])
31+
np.testing.assert_allclose(cut_train[1], cut_valid[1])
32+
33+
Xy_valid = xgb.QuantileDMatrix(X_valid, y_valid)
34+
cut_valid = Xy_valid.get_quantile_cut()
35+
assert not np.allclose(cut_train[1], cut_valid[1])

python-package/xgboost/testing/updater.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,10 @@ def check_get_quantile_cut_device(tree_method: str, use_cupy: bool) -> None:
250250
check_cut(n_entries, indptr, data, X.dtypes)
251251

252252

253-
def check_get_quantile_cut(tree_method: str) -> None:
253+
def check_get_quantile_cut(tree_method: str, device: str) -> None:
254254
"""Check the quantile cut getter."""
255255

256-
use_cupy = tree_method == "gpu_hist"
256+
use_cupy = device.startswith("cuda")
257257
check_get_quantile_cut_device(tree_method, False)
258258
if use_cupy:
259259
check_get_quantile_cut_device(tree_method, True)

tests/python-gpu/test_device_quantile_dmatrix.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from xgboost import testing as tm
99
from xgboost.testing.data import check_inf
1010
from xgboost.testing.data_iter import run_mixed_sparsity
11+
from xgboost.testing.quantile_dmatrix import check_ref_quantile_cut
1112

1213
sys.path.append("tests/python")
1314
import test_quantile_dmatrix as tqd
@@ -142,6 +143,9 @@ def test_interoperability(self, tree_method: str, max_bin: int) -> None:
142143
{"tree_method": "approx", "max_bin": max_bin}, Xy, num_boost_round=4
143144
)
144145

146+
def test_ref_quantile_cut(self) -> None:
147+
check_ref_quantile_cut("cuda")
148+
145149
@pytest.mark.skipif(**tm.no_cupy())
146150
def test_metainfo(self) -> None:
147151
import cupy as cp

tests/python-gpu/test_gpu_updaters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,4 +321,4 @@ def test_issue8824(self):
321321

322322
@pytest.mark.skipif(**tm.no_cudf())
323323
def test_get_quantile_cut(self) -> None:
324-
check_get_quantile_cut("gpu_hist")
324+
check_get_quantile_cut("hist", "cuda")

tests/python/test_quantile_dmatrix.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from xgboost.testing.data import check_inf, np_dtypes
1919
from xgboost.testing.data_iter import run_mixed_sparsity
20+
from xgboost.testing.quantile_dmatrix import check_ref_quantile_cut
2021

2122

2223
class TestQuantileDMatrix:
@@ -266,6 +267,9 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:
266267
dm_results["dvalid"]["rmse"], qdm_results["valid"]["rmse"]
267268
)
268269

270+
def test_ref_quantile_cut(self) -> None:
271+
check_ref_quantile_cut("cpu")
272+
269273
def test_ref_dmatrix(self) -> None:
270274
rng = np.random.RandomState(1994)
271275
self.run_ref_dmatrix(rng, "hist", True)

tests/python/test_updaters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,4 +412,4 @@ def test_quantile_loss(self, weighted: bool) -> None:
412412
@pytest.mark.skipif(**tm.no_pandas())
413413
@pytest.mark.parametrize("tree_method", ["hist"])
414414
def test_get_quantile_cut(self, tree_method: str) -> None:
415-
check_get_quantile_cut(tree_method)
415+
check_get_quantile_cut(tree_method, "cpu")

0 commit comments

Comments
 (0)