Skip to content

Commit 1506387

Browse files
ankanenunosilva800
andcommitted
Added support for pandas_categorical to predict method - resolves #8
Co-authored-by: Nuno Silva <[email protected]>
1 parent 2f8bc98 commit 1506387

File tree

8 files changed

+343
-0
lines changed

8 files changed

+343
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## 0.4.0 (unreleased)
22

33
- Added support for different prediction types
4+
- Added support for `pandas_categorical` to `predict` method
45
- Added support for hashes and Rover data frames to `predict` method
56
- Added support for hashes to `Dataset`
67
- Added `importance_type` option to `dump_model`, `model_to_string`, and `save_model` methods

lib/lightgbm.rb

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# dependencies
22
require "ffi"
33

4+
# stdlib
5+
require "json"
6+
47
# modules
58
require_relative "lightgbm/utils"
69
require_relative "lightgbm/booster"

lib/lightgbm/booster.rb

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def initialize(params: nil, train_set: nil, model_file: nil, model_str: nil)
1212
create_handle do |handle|
1313
safe_call FFI.LGBM_BoosterCreateFromModelfile(model_file, out_num_iterations, handle)
1414
end
15+
@pandas_categorical = load_pandas_categorical(file_name: model_file)
1516
else
1617
params ||= {}
1718
set_verbosity(params)
@@ -96,6 +97,7 @@ def model_from_string(model_str)
9697
create_handle do |handle|
9798
safe_call FFI.LGBM_BoosterLoadModelFromString(model_str, out_num_iterations, handle)
9899
end
100+
@pandas_categorical = load_pandas_categorical(model_str: model_str)
99101
@cached_feature_name = nil
100102
self
101103
end
@@ -235,5 +237,48 @@ def feature_importance_type_mapper(importance_type)
235237
-1
236238
end
237239
end
240+
241+
def load_pandas_categorical(file_name: nil, model_str: nil)
242+
pandas_key = "pandas_categorical:"
243+
offset = -pandas_key.length
244+
if !file_name.nil?
245+
max_offset = -File.size(file_name)
246+
lines = []
247+
File.open(file_name, "rb") do |f|
248+
loop do
249+
offset = [offset, max_offset].max
250+
f.seek(offset, IO::SEEK_END)
251+
lines = f.readlines
252+
if lines.length >= 2
253+
break
254+
end
255+
offset *= 2
256+
end
257+
end
258+
last_line = lines[-1].strip
259+
if !last_line.start_with?(pandas_key)
260+
last_line = lines[-2].strip
261+
end
262+
elsif !model_str.nil?
263+
idx = model_str[..offset].rindex("\n")
264+
last_line = model_str[idx..].strip
265+
end
266+
if last_line.start_with?(pandas_key)
267+
JSON.parse(last_line[pandas_key.length..])
268+
end
269+
end
270+
271+
def loaded_param
272+
buffer_len = 1 << 20
273+
out_len = ::FFI::MemoryPointer.new(:int64)
274+
out_str = ::FFI::MemoryPointer.new(:char, buffer_len)
275+
safe_call FFI.LGBM_BoosterGetLoadedParam(@handle, buffer_len, out_len, out_str)
276+
actual_len = out_len.read_int64
277+
if actual_len > buffer_len
278+
out_str = ::FFI::MemoryPointer.new(:char, actual_len)
279+
safe_call FFI.LGBM_BoosterGetLoadedParam(@handle, actual_len, out_len, out_str)
280+
end
281+
JSON.parse(out_str.read_string)
282+
end
238283
end
239284
end

lib/lightgbm/ffi.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ module FFI
4949
attach_function :LGBM_BoosterCreate, %i[pointer string pointer], :int
5050
attach_function :LGBM_BoosterCreateFromModelfile, %i[string pointer pointer], :int
5151
attach_function :LGBM_BoosterLoadModelFromString, %i[string pointer pointer], :int
52+
attach_function :LGBM_BoosterGetLoadedParam, %i[pointer int64 pointer pointer], :int
5253
attach_function :LGBM_BoosterFree, %i[pointer], :int
5354
attach_function :LGBM_BoosterAddValidData, %i[pointer pointer], :int
5455
attach_function :LGBM_BoosterGetNumClasses, %i[pointer pointer], :int

lib/lightgbm/inner_predictor.rb

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ class InnerPredictor
66

77
def initialize(booster, pred_parameter)
88
@handle = booster.instance_variable_get(:@handle)
9+
@pandas_categorical = booster.instance_variable_get(:@pandas_categorical)
910
@pred_parameter = params_str(pred_parameter)
1011

1112
# keep booster for cached_feature_name
@@ -50,6 +51,15 @@ def predict(data, start_iteration: 0, num_iteration: -1, raw_score: false, pred_
5051
singular = !data.first.is_a?(Array)
5152
data = [data] if singular
5253
check_2d_array(data)
54+
data = data.map(&:dup) if @pandas_categorical&.any?
55+
end
56+
57+
if @pandas_categorical&.any?
58+
apply_pandas_categorical(
59+
data,
60+
@booster.send(:loaded_param)["categorical_feature"],
61+
@pandas_categorical
62+
)
5363
end
5464

5565
preds, nrow =
@@ -127,5 +137,16 @@ def sorted_feature_values(input_hash)
127137
def cached_feature_name
128138
@booster.send(:cached_feature_name)
129139
end
140+
141+
def apply_pandas_categorical(data, categorical_feature, pandas_categorical)
142+
(categorical_feature || []).each_with_index do |cf, i|
143+
cat_codes = pandas_categorical[i].map.with_index.to_h
144+
# TODO confirm column is categorical
145+
data.each do |r|
146+
# TODO decide how to handle missing values
147+
r[cf] = cat_codes.fetch(r[cf])
148+
end
149+
end
150+
end
130151
end
131152
end

test/booster_test.rb

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,20 @@ def test_predict_pred_contrib
130130
assert_elements_in_delta expected[0], y_pred
131131
end
132132

133+
def test_predict_pandas_categorical_model_file
134+
x_test = [[3.7, 1.2, 7.2, "cat9"], [7.5, 0.5, 7.9, "cat0"]]
135+
booster = LightGBM::Booster.new(model_file: "test/support/categorical.txt")
136+
y_pred = booster.predict(x_test)
137+
assert_elements_in_delta [0.996415541144579, 1.0809369939979934], y_pred.first(2)
138+
end
139+
140+
def test_predict_pandas_categorical_model_str
141+
x_test = [[3.7, 1.2, 7.2, "cat9"], [7.5, 0.5, 7.9, "cat0"]]
142+
booster = LightGBM::Booster.new(model_str: File.read("test/support/categorical.txt"))
143+
y_pred = booster.predict(x_test)
144+
assert_elements_in_delta [0.996415541144579, 1.0809369939979934], y_pred.first(2)
145+
end
146+
133147
def test_model_to_string
134148
assert booster.model_to_string
135149
end

test/support/categorical.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import lightgbm as lgb
2+
import pandas as pd
3+
4+
df = pd.read_csv('test/support/data.csv')
5+
df['x3'] = ('cat' + df['x3'].astype(str)).astype('category')
6+
7+
X = df.drop(columns=['y'])
8+
y = df['y']
9+
10+
X_train = X[:300]
11+
y_train = y[:300]
12+
X_test = X[300:]
13+
y_test = y[300:]
14+
15+
train_data = lgb.Dataset(X_train, label=y_train)
16+
bst = lgb.train({}, train_data, num_boost_round=5)
17+
bst.save_model('test/support/categorical.txt')
18+
19+
bst = lgb.Booster(model_file='test/support/categorical.txt')
20+
print('x', X_train[:2].to_numpy().tolist())
21+
print('predict', bst.predict(X_train)[:2].tolist())

0 commit comments

Comments
 (0)