Skip to content

Commit 801116c

Browse files
authored
Test scikit-learn model IO with gblinear. (dmlc#9459)
1 parent bb56183 commit 801116c

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

tests/python/test_with_sklearn.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -792,19 +792,19 @@ def test_kwargs_grid_search():
792792
from sklearn import datasets
793793
from sklearn.model_selection import GridSearchCV
794794

795-
params = {'tree_method': 'hist'}
796-
clf = xgb.XGBClassifier(n_estimators=1, learning_rate=1.0, **params)
797-
assert clf.get_params()['tree_method'] == 'hist'
798-
# 'max_leaves' is not a default argument of XGBClassifier
795+
params = {"tree_method": "hist"}
796+
clf = xgb.XGBClassifier(n_estimators=3, **params)
797+
assert clf.get_params()["tree_method"] == "hist"
798+
# 'eta' is not a default argument of XGBClassifier
799799
# Check we can still do grid search over this parameter
800-
search_params = {'max_leaves': range(2, 5)}
800+
search_params = {"eta": [0, 0.2, 0.4]}
801801
grid_cv = GridSearchCV(clf, search_params, cv=5)
802802
iris = datasets.load_iris()
803803
grid_cv.fit(iris.data, iris.target)
804804

805805
# Expect unique results for each parameter value
806806
# This confirms sklearn is able to successfully update the parameter
807-
means = grid_cv.cv_results_['mean_test_score']
807+
means = grid_cv.cv_results_["mean_test_score"]
808808
assert len(means) == len(set(means))
809809

810810

@@ -928,6 +928,25 @@ def save_load_model(model_path):
928928
xgb_model = xgb.XGBModel()
929929
xgb_model.load_model(model_path)
930930

931+
clf = xgb.XGBClassifier(booster="gblinear", early_stopping_rounds=1)
932+
clf.fit(X, y, eval_set=[(X, y)])
933+
best_iteration = clf.best_iteration
934+
best_score = clf.best_score
935+
predt_0 = clf.predict(X)
936+
clf.save_model(model_path)
937+
clf.load_model(model_path)
938+
predt_1 = clf.predict(X)
939+
np.testing.assert_allclose(predt_0, predt_1)
940+
assert clf.best_iteration == best_iteration
941+
assert clf.best_score == best_score
942+
943+
clfpkl = pickle.dumps(clf)
944+
clf = pickle.loads(clfpkl)
945+
predt_2 = clf.predict(X)
946+
np.testing.assert_allclose(predt_0, predt_2)
947+
assert clf.best_iteration == best_iteration
948+
assert clf.best_score == best_score
949+
931950

932951
def test_save_load_model():
933952
with tempfile.TemporaryDirectory() as tempdir:

0 commit comments

Comments
 (0)