11require_relative "test_helper"
22
33class BoosterTest < Minitest ::Test
4- def test_model_file
4+ def test_predict
55 x_test = [ [ 3.7 , 1.2 , 7.2 , 9.0 ] , [ 7.5 , 0.5 , 7.9 , 0.0 ] ]
66 booster = LightGBM ::Booster . new ( model_file : "test/support/model.txt" )
77 y_pred = booster . predict ( x_test )
@@ -23,21 +23,6 @@ def test_model_from_string
2323 assert_elements_in_delta [ 0.9823112229173586 , 0.9583143724610858 ] , y_pred . first ( 2 )
2424 end
2525
26- def test_feature_importance
27- assert_equal [ 280 , 285 , 335 , 148 ] , booster . feature_importance
28- end
29-
30- def test_feature_name
31- assert_equal [ "x0" , "x1" , "x2" , "x3" ] , booster . feature_name
32- end
33-
34- def test_feature_importance_bad_importance_type
35- error = assert_raises ( LightGBM ::Error ) do
36- booster . feature_importance ( importance_type : "bad" )
37- end
38- assert_includes error . message , "Unknown importance type"
39- end
40-
4126 def test_predict_hash
4227 pred = booster . predict ( { x0 : 3.7 , x1 : 1.2 , x2 : 7.2 , x3 : 9.0 } )
4328 assert_in_delta 0.9823112229173586 , pred
@@ -88,6 +73,68 @@ def test_predict_rover
8873 end
8974 end
9075
76+ def test_predict_type_leaf_index
77+ x_test = [ [ 3.7 , 1.2 , 7.2 , 9.0 ] , [ 7.5 , 0.5 , 7.9 , 0.0 ] ]
78+ leaf_indexes = booster . predict ( x_test , predict_type : LightGBM ::C_API_PREDICT_LEAF_INDEX )
79+ assert_equal 200 , leaf_indexes . count
80+ assert_equal 9.0 , leaf_indexes . first
81+ assert_equal 7.0 , leaf_indexes . last
82+
83+ x_test = [ 3.7 , 1.2 , 7.2 , 9.0 ]
84+ leaf_indexes = booster . predict ( x_test , predict_type : LightGBM ::C_API_PREDICT_LEAF_INDEX )
85+ assert_equal 100 , leaf_indexes . count
86+ assert_equal 9.0 , leaf_indexes . first
87+ assert_equal 10.0 , leaf_indexes . last
88+ end
89+
90+ def test_predict_type_contrib
91+ x_test = [ [ 3.7 , 1.2 , 7.2 , 9.0 ] , [ 7.5 , 0.5 , 7.9 , 0.0 ] ]
92+ results = booster . predict ( x_test , predict_type : LightGBM ::C_API_PREDICT_CONTRIB )
93+ assert_equal 10 , results . count
94+
95+ # split results on num_features + 1
96+ predictions = results . each_slice ( 5 ) . to_a
97+ shap_values_1 = predictions . first [ 0 ..-2 ]
98+ ypred_1 = predictions . first [ -1 ]
99+ assert_elements_in_delta [
100+ -0.0733949225678886 , -0.24289592050101766 , 0.24183795683166504 , 0.063430775771174
101+ ] , shap_values_1
102+ assert_in_delta ( 0.9933333333834246 ) , ypred_1
103+
104+ shap_values_2 = predictions . last [ 0 ..-2 ]
105+ ypred_2 = predictions . last [ -1 ]
106+ assert_elements_in_delta [
107+ 0.1094902954684793 , -0.2810485083947154 , 0.26691627597706397 , -0.13037702397316747
108+ ] , shap_values_2
109+ assert_in_delta ( 0.9933333333834246 ) , ypred_2
110+
111+ # single row
112+ x_test = [ 3.7 , 1.2 , 7.2 , 9.0 ]
113+ results = booster . predict ( x_test , predict_type : LightGBM ::C_API_PREDICT_CONTRIB )
114+ assert_equal 5 , results . count
115+ shap_values = results [ 0 ..-2 ]
116+ ypred = results [ -1 ]
117+ assert_elements_in_delta [
118+ -0.0733949225678886 , -0.24289592050101766 , 0.24183795683166504 , 0.063430775771174
119+ ] , shap_values
120+ assert_in_delta ( 0.9933333333834246 ) , ypred
121+ end
122+
123+ def test_feature_importance
124+ assert_equal [ 280 , 285 , 335 , 148 ] , booster . feature_importance
125+ end
126+
127+ def test_feature_name
128+ assert_equal [ "x0" , "x1" , "x2" , "x3" ] , booster . feature_name
129+ end
130+
131+ def test_feature_importance_bad_importance_type
132+ error = assert_raises ( LightGBM ::Error ) do
133+ booster . feature_importance ( importance_type : "bad" )
134+ end
135+ assert_includes error . message , "Unknown importance type"
136+ end
137+
91138 def test_model_to_string
92139 assert booster . model_to_string
93140 end
0 commit comments