Skip to content

Commit b5dc2f6

Browse files
authored
Support gamma objective in LGBMRegressor (#484)
* support objective=gamma in lgbm regressor * add unit test for objective functions * add test case for objective='regression' * set fraction of rows that need to be almost equal to 0.9999 * skip objective test if onnxruntime version is lower than 1.3 Signed-off-by: Jan-Benedikt Jagusch <[email protected]>
1 parent 0795ad0 commit b5dc2f6

File tree

3 files changed

+91
-3
lines changed

3 files changed

+91
-3
lines changed

onnxmltools/convert/lightgbm/_parse.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ def __init__(self, booster):
2727
if (_model_dict['objective'].startswith('binary') or
2828
_model_dict['objective'].startswith('multiclass')):
2929
self.operator_name = 'LgbmClassifier'
30-
elif (_model_dict['objective'].startswith('regression') or
31-
_model_dict['objective'].startswith('poisson')):
30+
elif _model_dict['objective'].startswith(('regression', 'poisson', 'gamma')):
3231
self.operator_name = 'LgbmRegressor'
3332
else:
3433
# Other objectives are not supported.

onnxmltools/convert/lightgbm/operator_converters/LightGbm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def convert_lightgbm(scope, operator, container):
234234
n_classes = 1 # Regressor has only one output variable
235235
attrs['post_transform'] = 'NONE'
236236
attrs['n_targets'] = n_classes
237-
elif gbm_text['objective'].startswith('poisson'):
237+
elif gbm_text['objective'].startswith(('poisson', 'gamma')):
238238
n_classes = 1 # Regressor has only one output variable
239239
attrs['n_targets'] = n_classes
240240
# 'Exp' is not a supported post_transform value in the ONNX spec yet,
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import unittest
2+
from typing import Dict, List, Tuple
3+
4+
import numpy as np
5+
import onnxruntime
6+
import pandas as pd
7+
from onnx import ModelProto
8+
from onnxconverter_common.data_types import DoubleTensorType, TensorType
9+
from onnxmltools import convert_lightgbm
10+
from onnxruntime import InferenceSession
11+
from pandas.core.frame import DataFrame
12+
13+
from lightgbm import LGBMRegressor
14+
15+
# Shape of the synthetic regression problem shared by all tests in this module.
_N_ROWS = 10_000
_N_COLS = 10
# ONNX and LightGBM predictions are compared to this many decimals ...
_N_DECIMALS = 5
# ... and at least this fraction of rows has to agree at that precision.
_FRAC = 0.9999

# Draw the synthetic data from a dedicated, seeded generator so that every run
# trains on the same data and a failure can be reproduced. Using a private
# RandomState also leaves NumPy's global random state untouched.
_RANDOM_STATE = np.random.RandomState(0)
_X = pd.DataFrame(_RANDOM_STATE.random_sample(size=(_N_ROWS, _N_COLS)))
_Y = pd.Series(_RANDOM_STATE.random_sample(size=_N_ROWS))

# Maps the string form of a pandas dtype to the tensor type that is called
# with the input shape to declare the ONNX graph input.
_DTYPE_MAP: Dict[str, TensorType] = {
    "float64": DoubleTensorType,
}
26+
27+
28+
def _leading_version(version: str, n_parts: int = 2) -> Tuple[int, ...]:
    """Return the first *n_parts* numeric components of a dotted version string.

    Only the leading digits of each component are used, so pre-release or
    development suffixes (e.g. ``"1.8.0rc1"`` or ``"1.3dev"``) do not raise.
    A component without leading digits counts as ``0``.
    """
    numbers = []
    for piece in version.split(".")[:n_parts]:
        digits = ""
        for char in piece:
            if char not in "0123456789":
                break
            digits += char
        numbers.append(int(digits) if digits else 0)
    return tuple(numbers)


class ObjectiveTest(unittest.TestCase):
    """Check that LGBMRegressor models with different objectives convert to ONNX."""

    # Objectives that are fitted, converted and compared in ``test_objective``.
    _objectives: Tuple[str, ...] = (
        "regression",
        "poisson",
        "gamma",
    )

    @staticmethod
    def _calc_initial_types(X: DataFrame) -> List[Tuple[str, TensorType]]:
        """
        Build the ``initial_types`` argument for ``convert_lightgbm`` from a data frame.

        All columns of X must share one dtype that has an entry in ``_DTYPE_MAP``.
        Raises RuntimeError if X has no columns or mixes several dtypes.
        """
        dtypes = {str(dtype) for dtype in X.dtypes}
        if not dtypes:
            raise RuntimeError("Test expects an input matrix with at least one column.")
        if len(dtypes) > 1:
            raise RuntimeError(f"Test expects homogenous input matrix. Found multiple dtypes: {dtypes}.")
        dtype = dtypes.pop()
        tensor_type = _DTYPE_MAP[dtype]
        return [("input", tensor_type(X.shape))]

    @staticmethod
    def _predict_with_onnx(model: ModelProto, X: DataFrame) -> np.ndarray:
        """
        Run the ONNX model on X and return the first column of its first output.

        Raises RuntimeError unless the graph declares exactly one input.
        """
        session = InferenceSession(model.SerializeToString())
        output_names = [s_output.name for s_output in session.get_outputs()]
        input_names = [s_input.name for s_input in session.get_inputs()]
        if len(input_names) != 1:
            raise RuntimeError(f"Test expects exactly one input. Found inputs: {input_names}.")
        input_name = input_names[0]
        return session.run(output_names, {input_name: X.values})[0][:, 0]

    @staticmethod
    def _assert_almost_equal(actual: np.ndarray, desired: np.ndarray, decimal: int = 7, frac: float = 1.0):
        """
        Assert that almost all rows in actual and desired are almost equal to each other.

        Similar to np.testing.assert_almost_equal but allows to define a fraction of rows to be almost
        equal instead of expecting all rows to be almost equal.

        Two rows count as almost equal when their absolute difference is at most ``10 ** -decimal``.
        Raises ValueError if frac is outside [0, 1] and AssertionError if the arrays differ in shape,
        are empty, or fewer than ``frac`` of the rows are almost equal.
        """
        # Explicit raises instead of ``assert`` statements: asserts are removed when
        # Python runs with -O, which would let this check pass without checking anything.
        if not 0 <= frac <= 1:
            raise ValueError("frac must be in the closed interval [0, 1].")
        actual = np.asarray(actual)
        desired = np.asarray(desired)
        if actual.shape != desired.shape:
            raise AssertionError(f"Shape mismatch: actual {actual.shape} vs. desired {desired.shape}.")
        n_rows = len(actual)
        if n_rows == 0:
            raise AssertionError("Cannot compare empty arrays.")
        success_abs = int((np.abs(actual - desired) <= (10 ** -decimal)).sum())
        if success_abs / n_rows < frac:
            raise AssertionError(f"Only {success_abs} out of {n_rows} rows are almost equal to {decimal} decimals.")

    @unittest.skipIf(_leading_version(onnxruntime.__version__) < (1, 3), "not supported in this library version")
    def test_objective(self):
        """
        Test if a LGBMRegressor with a certain objective (e.g. 'poisson') can be converted to ONNX
        and whether the ONNX graph and the original model produce almost equal predictions.

        Note that this test is a bit flaky because of precision differences between ONNX and LightGBM
        and can therefore occasionally fail.
        """
        for objective in self._objectives:
            with self.subTest(X=_X, objective=objective):
                regressor = LGBMRegressor(objective=objective)
                regressor.fit(_X, _Y)
                regressor_onnx: ModelProto = convert_lightgbm(regressor, initial_types=self._calc_initial_types(_X))
                y_pred = regressor.predict(_X)
                y_pred_onnx = self._predict_with_onnx(regressor_onnx, _X)
                # The LightGBM prediction is the reference the ONNX output is checked against.
                self._assert_almost_equal(
                    actual=y_pred_onnx,
                    desired=y_pred,
                    decimal=_N_DECIMALS,
                    frac=_FRAC,
                )

0 commit comments

Comments
 (0)