Commit 899e4c8

[backport] Do not return internal value for get_params. (dmlc#8634) (dmlc#8642)
1 parent: a2085bf
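In effect, `get_params()` on a fitted estimator now keeps unset parameters as `None` instead of back-filling them from the booster's internal configuration; the resolved values remain readable via `Booster.save_config()`. A minimal sketch of this behavior, mirroring the new assertions in tests/python/test_with_sklearn.py below (assumes scikit-learn is installed):

import json

import xgboost as xgb
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)
clf = xgb.XGBClassifier(n_estimators=2)
clf.fit(X, y)

# Unset sklearn parameters stay None; get_params() no longer pulls them
# from the booster's internal config.
assert clf.get_params()["tree_method"] is None

# The internally resolved value is still visible through save_config().
config = json.loads(clf.get_booster().save_config())
print(config["learner"]["gradient_booster"]["gbtree_train_param"]["tree_method"])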

File tree: 3 files changed (+73 −45 lines)

python-package/xgboost/sklearn.py

Lines changed: 7 additions & 33 deletions
@@ -674,7 +674,7 @@ def set_params(self, **params: Any) -> "XGBModel":
                     self.kwargs = {}
                 self.kwargs[key] = value
 
-        if hasattr(self, "_Booster"):
+        if self.__sklearn_is_fitted__():
             parameters = self.get_xgb_params()
             self.get_booster().set_param(parameters)
 
@@ -701,39 +701,12 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
                 np.iinfo(np.int32).max
             )
 
-        def parse_parameter(value: Any) -> Optional[Union[int, float, str]]:
-            for t in (int, float, str):
-                try:
-                    ret = t(value)
-                    return ret
-                except ValueError:
-                    continue
-            return None
-
-        # Get internal parameter values
-        try:
-            config = json.loads(self.get_booster().save_config())
-            stack = [config]
-            internal = {}
-            while stack:
-                obj = stack.pop()
-                for k, v in obj.items():
-                    if k.endswith("_param"):
-                        for p_k, p_v in v.items():
-                            internal[p_k] = p_v
-                    elif isinstance(v, dict):
-                        stack.append(v)
-
-            for k, v in internal.items():
-                if k in params and params[k] is None:
-                    params[k] = parse_parameter(v)
-        except ValueError:
-            pass
         return params
 
     def get_xgb_params(self) -> Dict[str, Any]:
         """Get xgboost specific parameters."""
-        params = self.get_params()
+        params: Dict[str, Any] = self.get_params()
+
         # Parameters that should not go into native learner.
         wrapper_specific = {
             "importance_type",
@@ -750,6 +723,7 @@ def get_xgb_params(self) -> Dict[str, Any]:
         for k, v in params.items():
             if k not in wrapper_specific and not callable(v):
                 filtered[k] = v
+
         return filtered
 
     def get_num_boosting_rounds(self) -> int:
@@ -1070,7 +1044,7 @@ def _can_use_inplace_predict(self) -> bool:
         # error with incompatible data type.
         # Inplace predict doesn't handle as many data types as DMatrix, but it's
         # sufficient for dask interface where input is simpiler.
-        predictor = self.get_params().get("predictor", None)
+        predictor = self.get_xgb_params().get("predictor", None)
         if predictor in ("auto", None) and self.booster != "gblinear":
             return True
         return False
@@ -1336,7 +1310,7 @@ def coef_(self) -> np.ndarray:
         -------
         coef_ : array of shape ``[n_features]`` or ``[n_classes, n_features]``
         """
-        if self.get_params()["booster"] != "gblinear":
+        if self.get_xgb_params()["booster"] != "gblinear":
             raise AttributeError(
                 f"Coefficients are not defined for Booster type {self.booster}"
             )
@@ -1366,7 +1340,7 @@ def intercept_(self) -> np.ndarray:
         -------
         intercept_ : array of shape ``(1,)`` or ``[n_classes]``
         """
-        if self.get_params()["booster"] != "gblinear":
+        if self.get_xgb_params()["booster"] != "gblinear":
             raise AttributeError(
                 f"Intercept (bias) is not defined for Booster type {self.booster}"
             )
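Because `get_params()` can now legitimately return `None` for unset values, the internal call sites above (`_can_use_inplace_predict`, `coef_`, `intercept_`) switch to `get_xgb_params()`, which returns the same user-facing values minus wrapper-only keys and callables. A small illustrative sketch, not part of the diff itself:

import xgboost as xgb

clf = xgb.XGBClassifier(booster="gblinear")

params = clf.get_params()          # full sklearn-style parameter dict
xgb_params = clf.get_xgb_params()  # filtered dict handed to the native learner

assert params["booster"] == xgb_params["booster"] == "gblinear"
# "importance_type" is in the wrapper_specific set above, so it is filtered out:
assert "importance_type" in params
assert "importance_type" not in xgb_params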

tests/python/test_with_pandas.py

Lines changed: 0 additions & 1 deletion
@@ -112,7 +112,6 @@ def test_pandas(self):
 
         # test Index as columns
         df = pd.DataFrame([[1, 1.1], [2, 2.2]], columns=pd.Index([1, 2]))
-        print(df.columns, isinstance(df.columns, pd.Index))
         Xy = xgb.DMatrix(df)
         np.testing.assert_equal(np.array(Xy.feature_names), np.array(["1", "2"]))
 
tests/python/test_with_sklearn.py

Lines changed: 66 additions & 11 deletions
@@ -2,6 +2,7 @@
 import importlib.util
 import json
 import os
+import pickle
 import random
 import tempfile
 from typing import Callable, Optional
@@ -636,26 +637,74 @@ def test_sklearn_n_jobs():
 
 def test_parameters_access():
     from sklearn import datasets
-    params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1}
+
+    params = {"updater": "grow_gpu_hist", "subsample": 0.5, "n_jobs": -1}
     clf = xgb.XGBClassifier(n_estimators=1000, **params)
-    assert clf.get_params()['updater'] == 'grow_gpu_hist'
-    assert clf.get_params()['subsample'] == .5
-    assert clf.get_params()['n_estimators'] == 1000
+    assert clf.get_params()["updater"] == "grow_gpu_hist"
+    assert clf.get_params()["subsample"] == 0.5
+    assert clf.get_params()["n_estimators"] == 1000
 
     clf = xgb.XGBClassifier(n_estimators=1, nthread=4)
     X, y = datasets.load_iris(return_X_y=True)
     clf.fit(X, y)
 
     config = json.loads(clf.get_booster().save_config())
-    assert int(config['learner']['generic_param']['nthread']) == 4
+    assert int(config["learner"]["generic_param"]["nthread"]) == 4
 
     clf.set_params(nthread=16)
     config = json.loads(clf.get_booster().save_config())
-    assert int(config['learner']['generic_param']['nthread']) == 16
+    assert int(config["learner"]["generic_param"]["nthread"]) == 16
 
     clf.predict(X)
     config = json.loads(clf.get_booster().save_config())
-    assert int(config['learner']['generic_param']['nthread']) == 16
+    assert int(config["learner"]["generic_param"]["nthread"]) == 16
+
+    clf = xgb.XGBClassifier(n_estimators=2)
+    assert clf.tree_method is None
+    assert clf.get_params()["tree_method"] is None
+    clf.fit(X, y)
+    assert clf.get_params()["tree_method"] is None
+
+    def save_load(clf: xgb.XGBClassifier) -> xgb.XGBClassifier:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = os.path.join(tmpdir, "model.json")
+            clf.save_model(path)
+            clf = xgb.XGBClassifier()
+            clf.load_model(path)
+        return clf
+
+    def get_tm(clf: xgb.XGBClassifier) -> str:
+        tm = json.loads(clf.get_booster().save_config())["learner"]["gradient_booster"][
+            "gbtree_train_param"
+        ]["tree_method"]
+        return tm
+
+    assert get_tm(clf) == "exact"
+
+    clf = pickle.loads(pickle.dumps(clf))
+
+    assert clf.tree_method is None
+    assert clf.n_estimators == 2
+    assert clf.get_params()["tree_method"] is None
+    assert clf.get_params()["n_estimators"] == 2
+    assert get_tm(clf) == "exact"  # preserved for pickle
+
+    clf = save_load(clf)
+
+    assert clf.tree_method is None
+    assert clf.n_estimators == 2
+    assert clf.get_params()["tree_method"] is None
+    assert clf.get_params()["n_estimators"] == 2
+    assert get_tm(clf) == "auto"  # discarded for save/load_model
+
+    clf.set_params(tree_method="hist")
+    assert clf.get_params()["tree_method"] == "hist"
+    clf = pickle.loads(pickle.dumps(clf))
+    assert clf.get_params()["tree_method"] == "hist"
+    clf = save_load(clf)
+    # FIXME(jiamingy): We should remove this behavior once we remove parameters
+    # serialization for skl save/load_model.
+    assert clf.get_params()["tree_method"] == "hist"
 
 
 def test_kwargs_error():
@@ -695,13 +744,19 @@ def test_sklearn_clone():
 
 def test_sklearn_get_default_params():
     from sklearn.datasets import load_digits
+
    digits_2class = load_digits(n_class=2)
-    X = digits_2class['data']
-    y = digits_2class['target']
+    X = digits_2class["data"]
+    y = digits_2class["target"]
     cls = xgb.XGBClassifier()
-    assert cls.get_params()['base_score'] is None
+    assert cls.get_params()["base_score"] is None
     cls.fit(X[:4, ...], y[:4, ...])
-    assert cls.get_params()['base_score'] is not None
+    base_score = float(
+        json.loads(cls.get_booster().save_config())["learner"]["learner_model_param"][
+            "base_score"
+        ]
+    )
+    np.testing.assert_equal(base_score, 0.5)
 
 
 def run_validation_weights(model):
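Read together, the new assertions pin down the persistence asymmetry: pickling round-trips both the Python attributes and the live booster, so the resolved tree method ("exact") survives, whereas save_model/load_model re-creates the estimator from the JSON model and an unset tree_method falls back to "auto" after reloading. The trailing FIXME notes that an explicitly set tree_method surviving save_model is legacy behavior slated for removal.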
