Skip to content

Commit 2f8bc98

Browse files
committed
Added importance_type option to dump_model, model_to_string, and save_model methods
1 parent 923528e commit 2f8bc98

File tree

2 files changed

+25
-22
lines changed

2 files changed

+25
-22
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
- Added support for different prediction types
44
- Added support for hashes and Rover data frames to `predict` method
55
- Added support for hashes to `Dataset`
6+
- Added `importance_type` option to `dump_model`, `model_to_string`, and `save_model` methods
67
- Changed `Dataset` to use column names for feature names with Rover and Daru
78
- Changed `predict` method to match feature names with Daru
89
- Dropped support for Ruby < 3.1

lib/lightgbm/booster.rb

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,17 @@ def current_iteration
3838
out.read_int
3939
end
4040

41-
def dump_model(num_iteration: nil, start_iteration: 0)
41+
def dump_model(num_iteration: nil, start_iteration: 0, importance_type: "split")
4242
num_iteration ||= best_iteration
43+
importance_type_int = feature_importance_type_mapper(importance_type)
4344
buffer_len = 1 << 20
4445
out_len = ::FFI::MemoryPointer.new(:int64)
4546
out_str = ::FFI::MemoryPointer.new(:char, buffer_len)
46-
feature_importance_type = 0 # TODO add option
47-
safe_call FFI.LGBM_BoosterDumpModel(@handle, start_iteration, num_iteration, feature_importance_type, buffer_len, out_len, out_str)
47+
safe_call FFI.LGBM_BoosterDumpModel(@handle, start_iteration, num_iteration, importance_type_int, buffer_len, out_len, out_str)
4848
actual_len = out_len.read_int64
4949
if actual_len > buffer_len
5050
out_str = ::FFI::MemoryPointer.new(:char, actual_len)
51-
safe_call FFI.LGBM_BoosterDumpModel(@handle, start_iteration, num_iteration, feature_importance_type, actual_len, out_len, out_str)
51+
safe_call FFI.LGBM_BoosterDumpModel(@handle, start_iteration, num_iteration, importance_type_int, actual_len, out_len, out_str)
5252
end
5353
out_str.read_string
5454
end
@@ -64,19 +64,10 @@ def eval_train
6464

6565
def feature_importance(iteration: nil, importance_type: "split")
6666
iteration ||= best_iteration
67-
importance_type =
68-
case importance_type
69-
when "split"
70-
FFI::C_API_FEATURE_IMPORTANCE_SPLIT
71-
when "gain"
72-
FFI::C_API_FEATURE_IMPORTANCE_GAIN
73-
else
74-
-1
75-
end
76-
67+
importance_type_int = feature_importance_type_mapper(importance_type)
7768
num_feature = self.num_feature
7869
out_result = ::FFI::MemoryPointer.new(:double, num_feature)
79-
safe_call FFI.LGBM_BoosterFeatureImportance(@handle, iteration, importance_type, out_result)
70+
safe_call FFI.LGBM_BoosterFeatureImportance(@handle, iteration, importance_type_int, out_result)
8071
out_result.read_array_of_double(num_feature).map(&:to_i)
8172
end
8273

@@ -109,17 +100,17 @@ def model_from_string(model_str)
109100
self
110101
end
111102

112-
def model_to_string(num_iteration: nil, start_iteration: 0)
103+
def model_to_string(num_iteration: nil, start_iteration: 0, importance_type: "split")
113104
num_iteration ||= best_iteration
105+
importance_type_int = feature_importance_type_mapper(importance_type)
114106
buffer_len = 1 << 20
115107
out_len = ::FFI::MemoryPointer.new(:int64)
116108
out_str = ::FFI::MemoryPointer.new(:char, buffer_len)
117-
feature_importance_type = 0 # TODO add option
118-
safe_call FFI.LGBM_BoosterSaveModelToString(@handle, start_iteration, num_iteration, feature_importance_type, buffer_len, out_len, out_str)
109+
safe_call FFI.LGBM_BoosterSaveModelToString(@handle, start_iteration, num_iteration, importance_type_int, buffer_len, out_len, out_str)
119110
actual_len = out_len.read_int64
120111
if actual_len > buffer_len
121112
out_str = ::FFI::MemoryPointer.new(:char, actual_len)
122-
safe_call FFI.LGBM_BoosterSaveModelToString(@handle, start_iteration, num_iteration, feature_importance_type, actual_len, out_len, out_str)
113+
safe_call FFI.LGBM_BoosterSaveModelToString(@handle, start_iteration, num_iteration, importance_type_int, actual_len, out_len, out_str)
123114
end
124115
out_str.read_string
125116
end
@@ -162,10 +153,10 @@ def predict(data, start_iteration: 0, num_iteration: nil, raw_score: false, pred
162153
)
163154
end
164155

165-
def save_model(filename, num_iteration: nil, start_iteration: 0)
156+
def save_model(filename, num_iteration: nil, start_iteration: 0, importance_type: "split")
166157
num_iteration ||= best_iteration
167-
feature_importance_type = 0 # TODO add
168-
safe_call FFI.LGBM_BoosterSaveModel(@handle, start_iteration, num_iteration, feature_importance_type, filename)
158+
importance_type_int = feature_importance_type_mapper(importance_type)
159+
safe_call FFI.LGBM_BoosterSaveModel(@handle, start_iteration, num_iteration, importance_type_int, filename)
169160
self # consistent with Python API
170161
end
171162

@@ -233,5 +224,16 @@ def num_class
233224
def cached_feature_name
234225
@cached_feature_name ||= feature_name
235226
end
227+
228+
def feature_importance_type_mapper(importance_type)
229+
case importance_type
230+
when "split"
231+
FFI::C_API_FEATURE_IMPORTANCE_SPLIT
232+
when "gain"
233+
FFI::C_API_FEATURE_IMPORTANCE_GAIN
234+
else
235+
-1
236+
end
237+
end
236238
end
237239
end

0 commit comments

Comments
 (0)