Skip to content

Commit cf660f0

Browse files
authored
support missing values in lgbm regressor correctly (#488)
Signed-off-by: Jan-Benedikt Jagusch <[email protected]>
1 parent b5dc2f6 commit cf660f0

File tree

2 files changed

+75
-2
lines changed

2 files changed

+75
-2
lines changed

onnxmltools/convert/lightgbm/operator_converters/LightGbm.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,10 @@ def _parse_tree_structure(tree_id, class_id, learning_rate,
101101
attrs['nodes_truenodeids'].append(left_id)
102102
attrs['nodes_falsenodeids'].append(right_id)
103103
if tree_structure['default_left']:
104-
attrs['nodes_missing_value_tracks_true'].append(1)
104+
if tree_structure["missing_type"] == 'None' and float(tree_structure['threshold']) < 0.0:
105+
attrs['nodes_missing_value_tracks_true'].append(0)
106+
else:
107+
attrs['nodes_missing_value_tracks_true'].append(1)
105108
else:
106109
attrs['nodes_missing_value_tracks_true'].append(0)
107110
attrs['nodes_hitrates'].append(1.)
@@ -166,7 +169,10 @@ def _parse_node(tree_id, class_id, node_id, node_id_pool, node_pyid_pool,
166169
attrs['nodes_truenodeids'].append(left_id)
167170
attrs['nodes_falsenodeids'].append(right_id)
168171
if node['default_left']:
169-
attrs['nodes_missing_value_tracks_true'].append(1)
172+
if node['missing_type'] == 'None' and float(node['threshold']) < 0.0:
173+
attrs['nodes_missing_value_tracks_true'].append(0)
174+
else:
175+
attrs['nodes_missing_value_tracks_true'].append(1)
170176
else:
171177
attrs['nodes_missing_value_tracks_true'].append(0)
172178
attrs['nodes_hitrates'].append(1.)
Lines changed: 67 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,67 @@
1+
import unittest
2+
3+
import numpy as np
4+
from onnx import ModelProto
5+
from onnxconverter_common.data_types import FloatTensorType
6+
from onnxmltools import convert_lightgbm
7+
from onnxruntime import InferenceSession
8+
9+
from lightgbm import LGBMRegressor
10+
11+
_N_DECIMALS=5
12+
_FRAC=0.9999
13+
14+
_y = np.array([0, 0, 1, 1, 1])
15+
_X_train = np.array([[1.0, 0.0], [1.0, -1.0], [1.0, -1.0], [2.0, -1.0], [2.0, -1.0]], dtype=np.float32)
16+
_X_test = np.array([[1.0, np.nan]], dtype=np.float32)
17+
18+
_INITIAL_TYPES = [("input", FloatTensorType([None, _X_train.shape[1]]))]
19+
20+
21+
class TestMissingValues(unittest.TestCase):
22+
23+
@staticmethod
24+
def _predict_with_onnx(model: ModelProto, X: np.array) -> np.array:
25+
session = InferenceSession(model.SerializeToString())
26+
output_names = [s_output.name for s_output in session.get_outputs()]
27+
input_names = [s_input.name for s_input in session.get_inputs()]
28+
if len(input_names) > 1:
29+
raise RuntimeError(f"Test expects one input. Found multiple inputs: {input_names}.")
30+
input_name = input_names[0]
31+
return session.run(output_names, {input_name: X})[0][:, 0]
32+
33+
@staticmethod
34+
def _assert_almost_equal(actual: np.array, desired: np.array, decimal: int=7, frac: float=1.0):
35+
"""
36+
Assert that almost all rows in actual and desired are almost equal to each other.
37+
Similar to np.testing.assert_almost_equal but allows to define a fraction of rows to be almost
38+
equal instead of expecting all rows to be almost equal.
39+
"""
40+
assert 0 <= frac <= 1, "frac must be in range(0, 1)."
41+
success_abs = (abs(actual - desired) <= (10 ** -decimal)).sum()
42+
success_rel = success_abs / len(actual)
43+
assert success_rel >= frac, f"Only {success_abs} out of {len(actual)} rows are almost equal to {decimal} decimals."
44+
45+
46+
def test_missing_values(self):
47+
"""
48+
Test that an ONNX model for a LGBM regressor that was trained without having seen missing values
49+
correctly predicts rows that contain missing values.
50+
"""
51+
regressor = LGBMRegressor(
52+
objective="regression",
53+
min_data_in_bin=1,
54+
min_data_in_leaf=1,
55+
n_estimators=1,
56+
learning_rate=1,
57+
)
58+
regressor.fit(_X_train, _y)
59+
regressor_onnx: ModelProto = convert_lightgbm(regressor, initial_types=_INITIAL_TYPES)
60+
y_pred = regressor.predict(_X_test)
61+
y_pred_onnx = self._predict_with_onnx(regressor_onnx, _X_test)
62+
self._assert_almost_equal(
63+
y_pred,
64+
y_pred_onnx,
65+
decimal=_N_DECIMALS,
66+
frac=_FRAC,
67+
)

0 commit comments

Comments
 (0)