Skip to content

Commit 9bc2b1b

Browse files
committed
Use column names for predict with Daru [skip ci]
1 parent 5ff5be3 commit 9bc2b1b

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

lib/lightgbm/booster.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def num_trees
141141
def predict(input, start_iteration: nil, num_iteration: nil, **params)
142142
input =
143143
if daru?(input)
144-
input.map_rows(&:to_a)
144+
input[*cached_feature_name].map_rows(&:to_a)
145145
elsif input.is_a?(Hash) # sort feature.values to match the order of model.feature_name
146146
sorted_feature_values(input)
147147
elsif input.is_a?(Array) && input.first.is_a?(Hash) # on multiple elems, if 1st is hash, assume they all are

test/booster_test.rb

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,16 @@ def test_predict_hash
5050
end
5151
end
5252

53+
def test_predict_daru
54+
x_test =
55+
Daru::DataFrame.new([
56+
{"x3" => 9.0, "x2" => 7.2, "x1" => 1.2, "x0" => 3.7},
57+
{"x3" => 0.0, "x2" => 7.9, "x1" => 0.5, "x0" => 7.5},
58+
])
59+
pred = booster.predict(x_test)
60+
assert_elements_in_delta [0.9823112229173586, 0.9583143724610858], pred.first(2)
61+
end
62+
5363
def test_predict_rover
5464
require "rover"
5565
x_test =

0 commit comments

Comments
 (0)