diff --git a/lib/lightgbm/booster.rb b/lib/lightgbm/booster.rb index 4cd09f8..d451fa6 100644 --- a/lib/lightgbm/booster.rb +++ b/lib/lightgbm/booster.rb @@ -141,6 +141,10 @@ def predict(input, start_iteration: nil, num_iteration: nil, **params) input = if daru?(input) input.map_rows(&:to_a) + elsif input.is_a?(Hash) # sort feature.values to match the order of model.feature_name + sorted_feature_values(input) + elsif input.is_a?(Array) && input.first.is_a?(Hash) # on multiple elems, if 1st is hash, assume they all are + input.map(&method(:sorted_feature_values)) else input.to_a end @@ -241,6 +245,11 @@ def read_int64(ptr) ptr.read_array_of_int64(1).first end + def sorted_feature_values(input_hash) + @cached_feature_names ||= feature_name + input_hash.transform_keys(&:to_s).fetch_values(*@cached_feature_names) + end + include Utils end end diff --git a/test/booster_test.rb b/test/booster_test.rb index b317587..e513163 100644 --- a/test/booster_test.rb +++ b/test/booster_test.rb @@ -30,6 +30,22 @@ def test_feature_importance_bad_importance_type assert_includes error.message, "Unknown importance type" end + def test_predict_with_hash_builds_sorted_input + pred = booster.predict({x0: 3.7, x1: 1.2, x2: 7.2, x3: 9.0}) + assert_in_delta 0.9823112229173586, pred + + pred = booster.predict({"x3" => 9.0, "x2" => 7.2, "x1" => 1.2, "x0" => 3.7}) + assert_in_delta 0.9823112229173586, pred + + pred = booster.predict( + [ + {"x3" => 9.0, "x2" => 7.2, "x1" => 1.2, "x0" => 3.7}, + {"x3" => 0.0, "x2" => 7.9, "x1" => 0.5, "x0" => 7.5}, + ] + ) + assert_elements_in_delta [0.9823112229173586, 0.9583143724610858], pred.first(2) + end + def test_model_to_string assert booster.model_to_string end