Skip to content

Commit 6b98882

Browse files
trivialfishcho3
andauthored
[backport] Make xgboost.testing compatible with skl 1.7 (dmlc#11502) (dmlc#11584)
Co-authored-by: Philip Hyunsu Cho <[email protected]>
1 parent 57f42c8 commit 6b98882

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

tests/python/test_updaters.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import xgboost as xgb
1010
from xgboost import testing as tm
11+
from xgboost.core import _parse_version
1112
from xgboost.testing.params import (
1213
cat_parameter_strategy,
1314
exact_parameter_strategy,
@@ -375,16 +376,25 @@ def test_categorical_missing(self, rows: int, cols: int, cats: int) -> None:
375376

376377
def run_adaptive(self, tree_method, weighted) -> None:
377378
rng = np.random.RandomState(1994)
379+
from sklearn import __version__ as sklearn_version
378380
from sklearn.datasets import make_regression
379381
from sklearn.utils import stats
380382

381383
n_samples = 256
382384
X, y = make_regression(n_samples, 16, random_state=rng)
385+
(sk_major, sk_minor, _), _ = _parse_version(sklearn_version)
386+
if sk_major > 1 or sk_minor >= 7:
387+
kwargs = {"percentile_rank": 50}
388+
else:
389+
kwargs = {"percentile": 50}
390+
383391
if weighted:
384392
w = rng.normal(size=n_samples)
385393
w -= w.min()
386394
Xy = xgb.DMatrix(X, y, weight=w)
387-
base_score = stats._weighted_percentile(y, w, percentile=50)
395+
base_score = stats._weighted_percentile( # pylint: disable=protected-access
396+
y, w, **kwargs
397+
)
388398
else:
389399
Xy = xgb.DMatrix(X, y)
390400
base_score = np.median(y)

0 commit comments

Comments
 (0)