Skip to content

Commit 1ef6902

Browse files
committed
Add support for different prediction types
1 parent fca59ef commit 1ef6902

File tree

5 files changed

+126
-20
lines changed

5 files changed

+126
-20
lines changed

lib/lightgbm.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
# modules
55
require_relative "lightgbm/utils"
6+
require_relative "lightgbm/macros"
67
require_relative "lightgbm/booster"
78
require_relative "lightgbm/dataset"
89
require_relative "lightgbm/version"

lib/lightgbm/booster.rb

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,21 @@ def num_trees
141141
out.read_int
142142
end
143143

144-
# TODO support different prediction types
145-
def predict(input, start_iteration: nil, num_iteration: nil, **params)
144+
145+
# Make prediction for a new dataset.
146+
# C-API: LGBM_BoosterPredictForMat.
147+
#
148+
# @param input [Array, Array<Array,Hash>, Hash{String => Numeric, String}, Daru::DataFrame, Rover::DataFrame] Input data
149+
# @param start_iteration [Integer] Start index of the iteration to predict
150+
# @param num_iteration [Integer] Number of iteration for prediction, <= 0 means no limit
151+
# @param predict_type [Integer] What should be predicted
152+
# - C_API_PREDICT_NORMAL: normal prediction, with transform (if needed);
153+
# - C_API_PREDICT_RAW_SCORE: raw score;
154+
# - C_API_PREDICT_LEAF_INDEX: leaf index;
155+
# - C_API_PREDICT_CONTRIB: feature contributions (SHAP values)
156+
# @param **params [Hash] Other parameters for prediction, e.g. early stopping for prediction
157+
# @return [Float, Array<Float>] Prediction results
158+
def predict(input, start_iteration: nil, num_iteration: nil, predict_type: C_API_PREDICT_NORMAL, **params)
146159
input =
147160
if daru?(input)
148161
input[*cached_feature_name].map_rows(&:to_a)
@@ -170,14 +183,51 @@ def predict(input, start_iteration: nil, num_iteration: nil, **params)
170183
data.write_array_of_double(flat_input)
171184

172185
out_len = ::FFI::MemoryPointer.new(:int64)
173-
out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count)
174-
check_result FFI.LGBM_BoosterPredictForMat(handle_pointer, data, 1, input.count, input.first.count, 1, 0, start_iteration, num_iteration, params_str(params), out_len, out_result)
186+
case predict_type
187+
when C_API_PREDICT_NORMAL, C_API_PREDICT_RAW_SCORE
188+
out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count)
189+
when C_API_PREDICT_LEAF_INDEX
190+
num_predict = num_preds(start_iteration:, num_iteration:, nrow: input.count, predict_type:)
191+
out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count * num_predict)
192+
singular = false
193+
when C_API_PREDICT_CONTRIB
194+
out_result = ::FFI::MemoryPointer.new(:double, num_class * input.count * (num_feature + 1))
195+
singular = false
196+
end
197+
198+
check_result FFI.LGBM_BoosterPredictForMat(
199+
handle_pointer,
200+
data,
201+
1,
202+
input.count,
203+
input.first.count,
204+
1,
205+
predict_type,
206+
start_iteration,
207+
num_iteration,
208+
params_str(params),
209+
out_len,
210+
out_result
211+
)
175212
out = out_result.read_array_of_double(out_len.read_int64)
176213
out = out.each_slice(num_class).to_a if num_class > 1
177214

178215
singular ? out.first : out
179216
end
180217

218+
def num_preds(start_iteration: 0, num_iteration: best_iteration, nrow: nil, predict_type: C_API_PREDICT_NORMAL)
219+
out_len = ::FFI::MemoryPointer.new(:int64)
220+
check_result FFI.LGBM_BoosterCalcNumPredict(
221+
handle_pointer,
222+
nrow,
223+
predict_type,
224+
start_iteration,
225+
num_iteration,
226+
out_len
227+
)
228+
out_len.read_int64
229+
end
230+
181231
def save_model(filename, num_iteration: nil, start_iteration: 0)
182232
num_iteration ||= best_iteration
183233
feature_importance_type = 0 # TODO add

lib/lightgbm/ffi.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ module FFI
3838
attach_function :LGBM_BoosterLoadModelFromString, %i[string pointer pointer], :int
3939
attach_function :LGBM_BoosterFree, %i[pointer], :int
4040
attach_function :LGBM_BoosterAddValidData, %i[pointer pointer], :int
41+
attach_function :LGBM_BoosterCalcNumPredict, %i[pointer int int int int pointer], :int
4142
attach_function :LGBM_BoosterGetNumClasses, %i[pointer pointer], :int
4243
attach_function :LGBM_BoosterUpdateOneIter, %i[pointer pointer], :int
4344
attach_function :LGBM_BoosterGetCurrentIteration, %i[pointer pointer], :int

lib/lightgbm/macros.rb

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
module LightGBM
2+
# Macro definition of prediction type in C API of LightGBM
3+
C_API_PREDICT_NORMAL = 0
4+
C_API_PREDICT_RAW_SCORE = 1
5+
C_API_PREDICT_LEAF_INDEX = 2
6+
C_API_PREDICT_CONTRIB = 3
7+
end

test/booster_test.rb

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
require_relative "test_helper"
22

33
class BoosterTest < Minitest::Test
4-
def test_model_file
4+
def test_predict
55
x_test = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]]
66
booster = LightGBM::Booster.new(model_file: "test/support/model.txt")
77
y_pred = booster.predict(x_test)
@@ -23,21 +23,6 @@ def test_model_from_string
2323
assert_elements_in_delta [0.9823112229173586, 0.9583143724610858], y_pred.first(2)
2424
end
2525

26-
def test_feature_importance
27-
assert_equal [280, 285, 335, 148], booster.feature_importance
28-
end
29-
30-
def test_feature_name
31-
assert_equal ["x0", "x1", "x2", "x3"], booster.feature_name
32-
end
33-
34-
def test_feature_importance_bad_importance_type
35-
error = assert_raises(LightGBM::Error) do
36-
booster.feature_importance(importance_type: "bad")
37-
end
38-
assert_includes error.message, "Unknown importance type"
39-
end
40-
4126
def test_predict_hash
4227
pred = booster.predict({x0: 3.7, x1: 1.2, x2: 7.2, x3: 9.0})
4328
assert_in_delta 0.9823112229173586, pred
@@ -88,6 +73,68 @@ def test_predict_rover
8873
end
8974
end
9075

76+
def test_predict_type_leaf_index
77+
x_test = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]]
78+
leaf_indexes = booster.predict(x_test, predict_type: LightGBM::C_API_PREDICT_LEAF_INDEX)
79+
assert_equal 200, leaf_indexes.count
80+
assert_equal 9.0, leaf_indexes.first
81+
assert_equal 7.0, leaf_indexes.last
82+
83+
x_test = [3.7, 1.2, 7.2, 9.0]
84+
leaf_indexes = booster.predict(x_test, predict_type: LightGBM::C_API_PREDICT_LEAF_INDEX)
85+
assert_equal 100, leaf_indexes.count
86+
assert_equal 9.0, leaf_indexes.first
87+
assert_equal 10.0, leaf_indexes.last
88+
end
89+
90+
def test_predict_type_contrib
91+
x_test = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]]
92+
results = booster.predict(x_test, predict_type: LightGBM::C_API_PREDICT_CONTRIB)
93+
assert_equal 10, results.count
94+
95+
# split results on num_features + 1
96+
predictions = results.each_slice(5).to_a
97+
shap_values_1 = predictions.first[0..-2]
98+
ypred_1 = predictions.first[-1]
99+
assert_elements_in_delta [
100+
-0.0733949225678886, -0.24289592050101766, 0.24183795683166504, 0.063430775771174
101+
], shap_values_1
102+
assert_in_delta (0.9933333333834246), ypred_1
103+
104+
shap_values_2 = predictions.last[0..-2]
105+
ypred_2 = predictions.last[-1]
106+
assert_elements_in_delta [
107+
0.1094902954684793, -0.2810485083947154, 0.26691627597706397, -0.13037702397316747
108+
], shap_values_2
109+
assert_in_delta (0.9933333333834246), ypred_2
110+
111+
# single row
112+
x_test = [3.7, 1.2, 7.2, 9.0]
113+
results = booster.predict(x_test, predict_type: LightGBM::C_API_PREDICT_CONTRIB)
114+
assert_equal 5, results.count
115+
shap_values = results[0..-2]
116+
ypred = results[-1]
117+
assert_elements_in_delta [
118+
-0.0733949225678886, -0.24289592050101766, 0.24183795683166504, 0.063430775771174
119+
], shap_values
120+
assert_in_delta (0.9933333333834246), ypred
121+
end
122+
123+
def test_feature_importance
124+
assert_equal [280, 285, 335, 148], booster.feature_importance
125+
end
126+
127+
def test_feature_name
128+
assert_equal ["x0", "x1", "x2", "x3"], booster.feature_name
129+
end
130+
131+
def test_feature_importance_bad_importance_type
132+
error = assert_raises(LightGBM::Error) do
133+
booster.feature_importance(importance_type: "bad")
134+
end
135+
assert_includes error.message, "Unknown importance type"
136+
end
137+
91138
def test_model_to_string
92139
assert booster.model_to_string
93140
end

0 commit comments

Comments
 (0)