Skip to content

Commit 5b7178a

Browse files
committed
Memoize params
1 parent 1506387 commit 5b7178a

File tree

3 files changed

+8
-3
lines changed

3 files changed

+8
-3
lines changed

lib/lightgbm/booster.rb

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module LightGBM
22
class Booster
33
include Utils
44

5-
attr_accessor :best_iteration, :train_data_name
5+
attr_accessor :best_iteration, :train_data_name, :params
66

77
def initialize(params: nil, train_set: nil, model_file: nil, model_str: nil)
88
if model_str
@@ -13,6 +13,10 @@ def initialize(params: nil, train_set: nil, model_file: nil, model_str: nil)
1313
safe_call FFI.LGBM_BoosterCreateFromModelfile(model_file, out_num_iterations, handle)
1414
end
1515
@pandas_categorical = load_pandas_categorical(file_name: model_file)
16+
if params
17+
warn "[xgboost] Ignoring params argument, using parameters from model file."
18+
end
19+
@params = loaded_param
1620
else
1721
params ||= {}
1822
set_verbosity(params)
@@ -98,6 +102,7 @@ def model_from_string(model_str)
98102
safe_call FFI.LGBM_BoosterLoadModelFromString(model_str, out_num_iterations, handle)
99103
end
100104
@pandas_categorical = load_pandas_categorical(model_str: model_str)
105+
@params = loaded_param
101106
@cached_feature_name = nil
102107
self
103108
end

lib/lightgbm/inner_predictor.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def predict(data, start_iteration: 0, num_iteration: -1, raw_score: false, pred_
5757
if @pandas_categorical&.any?
5858
apply_pandas_categorical(
5959
data,
60-
@booster.send(:loaded_param)["categorical_feature"],
60+
@booster.params["categorical_feature"],
6161
@pandas_categorical
6262
)
6363
end

lib/lightgbm/model.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def save_model(fname)
1616
end
1717

1818
def load_model(fname)
19-
@booster = Booster.new(params: @params, model_file: fname)
19+
@booster = Booster.new(model_file: fname)
2020
end
2121

2222
def best_iteration

0 commit comments

Comments
 (0)