Skip to content

Commit 8009202

Browse files
committed
Added support for passing early_stopping_rounds to initializer for Scikit-Learn API
1 parent df8e539 commit 8009202

File tree

5 files changed

+8
-7
lines changed

5 files changed

+8
-7
lines changed

lib/xgboost/classifier.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def fit(x, y, eval_set: nil, early_stopping_rounds: nil, verbose: true)
1818

1919
@booster = XGBoost.train(params, dtrain,
2020
num_boost_round: @n_estimators,
21-
early_stopping_rounds: early_stopping_rounds,
21+
early_stopping_rounds: early_stopping_rounds || @early_stopping_rounds,
2222
verbose_eval: verbose,
2323
evals: evals
2424
)

lib/xgboost/model.rb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ module XGBoost
22
class Model
33
attr_reader :booster
44

5-
def initialize(n_estimators: 100, importance_type: "gain", **options)
5+
def initialize(n_estimators: 100, importance_type: "gain", early_stopping_rounds: nil, **options)
66
@params = options
77
@n_estimators = n_estimators
88
@importance_type = importance_type
9+
@early_stopping_rounds = early_stopping_rounds
910
end
1011

1112
def predict(data)

lib/xgboost/regressor.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def fit(x, y, eval_set: nil, early_stopping_rounds: nil, verbose: true)
1010

1111
@booster = XGBoost.train(@params, dtrain,
1212
num_boost_round: @n_estimators,
13-
early_stopping_rounds: early_stopping_rounds,
13+
early_stopping_rounds: early_stopping_rounds || @early_stopping_rounds,
1414
verbose_eval: verbose,
1515
evals: evals
1616
)

test/classifier_test.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def test_multiclass
5050
def test_early_stopping
5151
x_train, y_train, x_test, y_test = multiclass_data
5252

53-
model = XGBoost::Classifier.new
54-
model.fit(x_train, y_train, early_stopping_rounds: 5, eval_set: [[x_test, y_test]], verbose: false)
53+
model = XGBoost::Classifier.new(early_stopping_rounds: 5)
54+
model.fit(x_train, y_train, eval_set: [[x_test, y_test]], verbose: false)
5555
assert_equal 18, model.booster.best_iteration
5656
end
5757

test/regressor_test.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def test_works
2323
def test_early_stopping
2424
x_train, y_train, x_test, y_test = regression_data
2525

26-
model = XGBoost::Regressor.new
27-
model.fit(x_train, y_train, early_stopping_rounds: 5, eval_set: [[x_test, y_test]], verbose: false)
26+
model = XGBoost::Regressor.new(early_stopping_rounds: 5)
27+
model.fit(x_train, y_train, eval_set: [[x_test, y_test]], verbose: false)
2828
assert_equal 9, model.booster.best_iteration
2929
end
3030

0 commit comments

Comments
 (0)