@@ -792,19 +792,19 @@ def test_kwargs_grid_search():
792
792
from sklearn import datasets
793
793
from sklearn .model_selection import GridSearchCV
794
794
795
- params = {' tree_method' : ' hist' }
796
- clf = xgb .XGBClassifier (n_estimators = 1 , learning_rate = 1.0 , ** params )
797
- assert clf .get_params ()[' tree_method' ] == ' hist'
798
- # 'max_leaves ' is not a default argument of XGBClassifier
795
+ params = {" tree_method" : " hist" }
796
+ clf = xgb .XGBClassifier (n_estimators = 3 , ** params )
797
+ assert clf .get_params ()[" tree_method" ] == " hist"
798
+ # 'eta ' is not a default argument of XGBClassifier
799
799
# Check we can still do grid search over this parameter
800
- search_params = {'max_leaves' : range ( 2 , 5 ) }
800
+ search_params = {"eta" : [ 0 , 0. 2 , 0.4 ] }
801
801
grid_cv = GridSearchCV (clf , search_params , cv = 5 )
802
802
iris = datasets .load_iris ()
803
803
grid_cv .fit (iris .data , iris .target )
804
804
805
805
# Expect unique results for each parameter value
806
806
# This confirms sklearn is able to successfully update the parameter
807
- means = grid_cv .cv_results_ [' mean_test_score' ]
807
+ means = grid_cv .cv_results_ [" mean_test_score" ]
808
808
assert len (means ) == len (set (means ))
809
809
810
810
@@ -928,6 +928,25 @@ def save_load_model(model_path):
928
928
xgb_model = xgb .XGBModel ()
929
929
xgb_model .load_model (model_path )
930
930
931
+ clf = xgb .XGBClassifier (booster = "gblinear" , early_stopping_rounds = 1 )
932
+ clf .fit (X , y , eval_set = [(X , y )])
933
+ best_iteration = clf .best_iteration
934
+ best_score = clf .best_score
935
+ predt_0 = clf .predict (X )
936
+ clf .save_model (model_path )
937
+ clf .load_model (model_path )
938
+ predt_1 = clf .predict (X )
939
+ np .testing .assert_allclose (predt_0 , predt_1 )
940
+ assert clf .best_iteration == best_iteration
941
+ assert clf .best_score == best_score
942
+
943
+ clfpkl = pickle .dumps (clf )
944
+ clf = pickle .loads (clfpkl )
945
+ predt_2 = clf .predict (X )
946
+ np .testing .assert_allclose (predt_0 , predt_2 )
947
+ assert clf .best_iteration == best_iteration
948
+ assert clf .best_score == best_score
949
+
931
950
932
951
def test_save_load_model ():
933
952
with tempfile .TemporaryDirectory () as tempdir :
0 commit comments