Skip to content

Commit 588ec93

Browse files
committed
Improved handling of new and invalid categories
1 parent 33f0369 commit 588ec93

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

lib/lightgbm/inner_predictor.rb

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,17 @@ def cached_feature_name
141141
def apply_pandas_categorical(data, categorical_feature, pandas_categorical)
142142
(categorical_feature || []).each_with_index do |cf, i|
143143
cat_codes = pandas_categorical[i].map.with_index.to_h
144-
# TODO confirm column is categorical
145144
data.each do |r|
146-
# TODO decide how to handle missing values
147-
r[cf] = cat_codes.fetch(r[cf])
145+
cat = r[cf]
146+
unless cat.nil?
147+
r[cf] =
148+
cat_codes.fetch(cat) do
149+
unless cat.is_a?(String)
150+
raise ArgumentError, "expected categorical value"
151+
end
152+
nil
153+
end
154+
end
148155
end
149156
end
150157
end

test/booster_test.rb

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,24 @@ def test_predict_pandas_categorical_model_str
144144
assert_elements_in_delta [0.996415541144579, 1.0809369939979934], y_pred.first(2)
145145
end
146146

147+
def test_predict_pandas_categorical_missing_category
148+
booster = LightGBM::Booster.new(model_file: "test/support/categorical.txt")
149+
assert_in_delta 0.996415541144579, booster.predict([3.7, 1.2, 7.2, nil])
150+
end
151+
152+
def test_predict_pandas_categorical_new_category
153+
booster = LightGBM::Booster.new(model_file: "test/support/categorical.txt")
154+
assert_in_delta 0.996415541144579, booster.predict([3.7, 1.2, 7.2, "cat10"])
155+
end
156+
157+
def test_predict_pandas_categorical_invalid_category
158+
booster = LightGBM::Booster.new(model_file: "test/support/categorical.txt")
159+
error = assert_raises(ArgumentError) do
160+
booster.predict([7.5, 0.5, 7.9, true])
161+
end
162+
assert_equal "expected categorical value", error.message
163+
end
164+
147165
def test_model_to_string
148166
assert booster.model_to_string
149167
end

0 commit comments

Comments
 (0)