Skip to content

Commit 88a8f91

Browse files
authored
Enable option zipmap for LGBM converter (fix issue #451) (#452)
* Enable option zipmap for LGBM converter * add one more unittest * support booster
1 parent 331df2e commit 88a8f91

File tree

6 files changed

+111
-28
lines changed

6 files changed

+111
-28
lines changed

onnxmltools/convert/lightgbm/_parse.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,13 @@ def _parse_lightgbm_simple_model(scope, model, inputs):
8787
return this_operator.outputs
8888

8989

90-
def _parse_sklearn_classifier(scope, model, inputs):
90+
def _parse_sklearn_classifier(scope, model, inputs, zipmap=True):
9191
probability_tensor = _parse_lightgbm_simple_model(
9292
scope, model, inputs)
9393
this_operator = scope.declare_local_operator('LgbmZipMap')
9494
this_operator.inputs = probability_tensor
95+
this_operator.zipmap = zipmap
96+
9597
classes = model.classes_
9698
label_type = Int64Type()
9799

@@ -116,33 +118,39 @@ def _parse_sklearn_classifier(scope, model, inputs):
116118
label_type = StringType()
117119

118120
output_label = scope.declare_local_variable('label', label_type)
119-
output_probability = scope.declare_local_variable(
120-
'probabilities',
121-
SequenceType(DictionaryType(label_type, FloatTensorType())))
121+
if zipmap:
122+
output_probability = scope.declare_local_variable(
123+
'probabilities',
124+
SequenceType(DictionaryType(label_type, FloatTensorType())))
125+
else:
126+
output_probability = scope.declare_local_variable(
127+
'probabilities', FloatTensorType())
122128
this_operator.outputs.append(output_label)
123129
this_operator.outputs.append(output_probability)
124130
return this_operator.outputs
125131

126132

127-
def _parse_lightgbm(scope, model, inputs):
133+
def _parse_lightgbm(scope, model, inputs, zipmap=True):
128134
'''
129135
This is a delegate function. It doesn't nothing but invoke the correct parsing function according to the input
130136
model's type.
131137
:param scope: Scope object
132138
:param model: A lightgbm object
133139
:param inputs: A list of variables
140+
:param zipmap: add operator ZipMap after operator TreeEnsembleClassifier
134141
:return: The output variables produced by the input model
135142
'''
136143
if isinstance(model, LGBMClassifier):
137-
return _parse_sklearn_classifier(scope, model, inputs)
144+
return _parse_sklearn_classifier(scope, model, inputs, zipmap=zipmap)
138145
if (isinstance(model, WrappedBooster) and
139146
model.operator_name == 'LgbmClassifier'):
140-
return _parse_sklearn_classifier(scope, model, inputs)
147+
return _parse_sklearn_classifier(scope, model, inputs, zipmap=zipmap)
141148
return _parse_lightgbm_simple_model(scope, model, inputs)
142149

143150

144151
def parse_lightgbm(model, initial_types=None, target_opset=None,
145-
custom_conversion_functions=None, custom_shape_calculators=None):
152+
custom_conversion_functions=None, custom_shape_calculators=None,
153+
zipmap=True):
146154
raw_model_container = LightGbmModelContainer(model)
147155
topology = Topology(raw_model_container, default_batch_size='None',
148156
initial_types=initial_types, target_opset=target_opset,
@@ -157,9 +165,9 @@ def parse_lightgbm(model, initial_types=None, target_opset=None,
157165
for variable in inputs:
158166
raw_model_container.add_input(variable)
159167

160-
outputs = _parse_lightgbm(scope, model, inputs)
168+
outputs = _parse_lightgbm(scope, model, inputs, zipmap=zipmap)
161169

162170
for variable in outputs:
163171
raw_model_container.add_output(variable)
164172

165-
return topology
173+
return topology

onnxmltools/convert/lightgbm/convert.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
def convert(model, name=None, initial_types=None, doc_string='', target_opset=None,
1717
targeted_onnx=onnx.__version__, custom_conversion_functions=None,
18-
custom_shape_calculators=None, without_onnx_ml=False):
18+
custom_shape_calculators=None, without_onnx_ml=False, zipmap=True):
1919
'''
2020
This function produces an equivalent ONNX model of the given lightgbm model.
2121
The supported lightgbm modules are listed below.
@@ -34,6 +34,7 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=No
3434
:param custom_conversion_functions: a dictionary for specifying the user customized conversion function
3535
:param custom_shape_calculators: a dictionary for specifying the user customized shape calculator
3636
:param without_onnx_ml: whether to generate a model composed by ONNX operators only, or to allow the converter
37+
:param zipmap: remove operator ZipMap from the ONNX graph
3738
to use ONNX-ML operators as well.
3839
:return: An ONNX model (type: ModelProto) which is equivalent to the input lightgbm model
3940
'''
@@ -50,7 +51,8 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=No
5051
name = str(uuid4().hex)
5152

5253
target_opset = target_opset if target_opset else get_maximum_opset_supported()
53-
topology = parse_lightgbm(model, initial_types, target_opset, custom_conversion_functions, custom_shape_calculators)
54+
topology = parse_lightgbm(model, initial_types, target_opset, custom_conversion_functions,
55+
custom_shape_calculators, zipmap=zipmap)
5456
topology.compile()
5557
onnx_ml_model = convert_topology(topology, name, doc_string, target_opset, targeted_onnx)
5658

onnxmltools/convert/lightgbm/operator_converters/LightGbm.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import numbers
55
import numpy as np
66
from collections import Counter
7-
from ...common._apply_operation import apply_div, apply_reshape, apply_sub, apply_cast, apply_identity
7+
from ...common._apply_operation import (
8+
apply_div, apply_reshape, apply_sub, apply_cast, apply_identity, apply_clip)
89
from ...common._registration import register_converter
910
from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs
1011
from ....proto import onnx_proto
@@ -453,23 +454,32 @@ def str2number(val):
453454

454455
def convert_lgbm_zipmap(scope, operator, container):
455456
zipmap_attrs = {'name': scope.get_unique_operator_name('ZipMap')}
456-
to_type = onnx_proto.TensorProto.INT64
457-
458457
if hasattr(operator, 'classlabels_int64s'):
459458
zipmap_attrs['classlabels_int64s'] = operator.classlabels_int64s
459+
to_type = onnx_proto.TensorProto.INT64
460460
elif hasattr(operator, 'classlabels_strings'):
461461
zipmap_attrs['classlabels_strings'] = operator.classlabels_strings
462462
to_type = onnx_proto.TensorProto.STRING
463-
463+
else:
464+
raise RuntimeError("Unknown class type.")
464465
if to_type == onnx_proto.TensorProto.STRING:
465466
apply_identity(scope, operator.inputs[0].full_name,
466467
operator.outputs[0].full_name, container)
467468
else:
468469
apply_cast(scope, operator.inputs[0].full_name,
469470
operator.outputs[0].full_name, container, to=to_type)
470-
container.add_node('ZipMap', operator.inputs[1].full_name,
471-
operator.outputs[1].full_name,
472-
op_domain='ai.onnx.ml', **zipmap_attrs)
471+
472+
if operator.zipmap:
473+
container.add_node('ZipMap', operator.inputs[1].full_name,
474+
operator.outputs[1].full_name,
475+
op_domain='ai.onnx.ml', **zipmap_attrs)
476+
else:
477+
# This should be apply_identity but optimization fails in
478+
# onnxconverter-common when trying to remove identity nodes.
479+
apply_clip(scope, operator.inputs[1].full_name,
480+
operator.outputs[1].full_name, container,
481+
min=np.array([0], dtype=np.float32),
482+
max=np.array([1], dtype=np.float32))
473483

474484

475485
register_converter('LgbmClassifier', convert_lightgbm)

onnxmltools/convert/main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,14 @@ def convert_catboost(model, name=None, initial_types=None, doc_string='', target
5252

5353
def convert_lightgbm(model, name=None, initial_types=None, doc_string='', target_opset=None,
5454
targeted_onnx=onnx.__version__, custom_conversion_functions=None,
55-
custom_shape_calculators=None, without_onnx_ml=False):
55+
custom_shape_calculators=None, without_onnx_ml=False, zipmap=True):
5656
if not utils.lightgbm_installed():
5757
raise RuntimeError('lightgbm is not installed. Please install lightgbm to use this feature.')
5858

5959
from .lightgbm.convert import convert
6060
return convert(model, name, initial_types, doc_string, target_opset, targeted_onnx,
61-
custom_conversion_functions, custom_shape_calculators, without_onnx_ml)
61+
custom_conversion_functions, custom_shape_calculators, without_onnx_ml,
62+
zipmap=zipmap)
6263

6364

6465
def convert_sklearn(model, name=None, initial_types=None, doc_string='', target_opset=None,

onnxmltools/utils/tests_helper.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def dump_data_and_model(data, model, onnx=None, basename="model", folder=None,
191191
return names
192192

193193

194-
def convert_model(model, name, input_types, without_onnx_ml=False):
194+
def convert_model(model, name, input_types, without_onnx_ml=False, **kwargs):
195195
"""
196196
Runs the appropriate conversion method.
197197
@@ -201,26 +201,26 @@ def convert_model(model, name, input_types, without_onnx_ml=False):
201201
from sklearn.base import BaseEstimator
202202
if model.__class__.__name__.startswith("LGBM"):
203203
from onnxmltools.convert import convert_lightgbm
204-
model, prefix = convert_lightgbm(model, name, input_types, without_onnx_ml=without_onnx_ml), "LightGbm"
204+
model, prefix = convert_lightgbm(model, name, input_types, without_onnx_ml=without_onnx_ml, **kwargs), "LightGbm"
205205
elif model.__class__.__name__.startswith("XGB"):
206206
from onnxmltools.convert import convert_xgboost
207-
model, prefix = convert_xgboost(model, name, input_types), "XGB"
207+
model, prefix = convert_xgboost(model, name, input_types, **kwargs), "XGB"
208208
elif model.__class__.__name__ == 'Booster':
209209
import lightgbm
210210
if isinstance(model, lightgbm.Booster):
211211
from onnxmltools.convert import convert_lightgbm
212-
model, prefix = convert_lightgbm(model, name, input_types, without_onnx_ml=without_onnx_ml), "LightGbm"
212+
model, prefix = convert_lightgbm(model, name, input_types, without_onnx_ml=without_onnx_ml, **kwargs), "LightGbm"
213213
else:
214214
raise RuntimeError("Unable to convert model of type '{0}'.".format(type(model)))
215215
elif model.__class__.__name__.startswith("CatBoost"):
216216
from onnxmltools.convert import convert_catboost
217-
model, prefix = convert_catboost(model, name, input_types), "CatBoost"
217+
model, prefix = convert_catboost(model, name, input_types, **kwargs), "CatBoost"
218218
elif isinstance(model, BaseEstimator):
219219
from onnxmltools.convert import convert_sklearn
220-
model, prefix = convert_sklearn(model, name, input_types), "Sklearn"
220+
model, prefix = convert_sklearn(model, name, input_types, **kwargs), "Sklearn"
221221
else:
222222
from onnxmltools.convert import convert_coreml
223-
model, prefix = convert_coreml(model, name, input_types), "Cml"
223+
model, prefix = convert_coreml(model, name, input_types, **kwargs), "Cml"
224224
if model is None:
225225
raise RuntimeError("Unable to convert model of type '{0}'.".format(type(model)))
226226
return model, prefix

tests/lightgbm/test_LightGbmTreeEnsembleConverters.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55

66
import lightgbm
77
import numpy
8+
from numpy.testing import assert_almost_equal
89
from lightgbm import LGBMClassifier, LGBMRegressor
910
import onnxruntime
1011
from onnxmltools.convert.common.utils import hummingbird_installed
1112
from onnxmltools.convert.common.data_types import FloatTensorType
13+
from onnxmltools.convert import convert_lightgbm
1214
from onnxmltools.utils import dump_data_and_model
1315
from onnxmltools.utils import dump_binary_classification, dump_multiple_classification
1416
from onnxmltools.utils import dump_single_regression
@@ -32,6 +34,50 @@ def test_lightgbm_classifier_zipmap(self):
3234
model, 'dummy', input_types=[('X', FloatTensorType([None, X.shape[1]]))])
3335
assert "zipmap" in str(onx).lower()
3436

37+
def test_lightgbm_classifier_nozipmap(self):
38+
X = [[0, 1], [1, 1], [2, 0], [1, 2], [1, 5], [6, 2]]
39+
X = numpy.array(X, dtype=numpy.float32)
40+
y = [0, 1, 0, 1, 1, 0]
41+
model = LGBMClassifier(n_estimators=3, min_child_samples=1, max_depth=2)
42+
model.fit(X, y)
43+
onx = convert_model(
44+
model, 'dummy', input_types=[('X', FloatTensorType([None, X.shape[1]]))],
45+
zipmap=False)
46+
assert "zipmap" not in str(onx).lower()
47+
onxs = onx[0].SerializeToString()
48+
try:
49+
sess = onnxruntime.InferenceSession(onxs)
50+
except Exception as e:
51+
raise AssertionError(
52+
"Model cannot be loaded by onnxruntime due to %r\n%s." % (
53+
e, onx[0]))
54+
exp = model.predict(X), model.predict_proba(X)
55+
got = sess.run(None, {'X': X})
56+
assert_almost_equal(exp[0], got[0])
57+
assert_almost_equal(exp[1], got[1])
58+
59+
def test_lightgbm_classifier_nozipmap2(self):
60+
X = [[0, 1], [1, 1], [2, 0], [1, 2], [1, 5], [6, 2]]
61+
X = numpy.array(X, dtype=numpy.float32)
62+
y = [0, 1, 0, 1, 1, 0]
63+
model = LGBMClassifier(n_estimators=3, min_child_samples=1, max_depth=2)
64+
model.fit(X, y)
65+
onx = convert_lightgbm(
66+
model, 'dummy', initial_types=[('X', FloatTensorType([None, X.shape[1]]))],
67+
zipmap=False)
68+
assert "zipmap" not in str(onx).lower()
69+
onxs = onx.SerializeToString()
70+
try:
71+
sess = onnxruntime.InferenceSession(onxs)
72+
except Exception as e:
73+
raise AssertionError(
74+
"Model cannot be loaded by onnxruntime due to %r\n%s." % (
75+
e, onx[0]))
76+
exp = model.predict(X), model.predict_proba(X)
77+
got = sess.run(None, {'X': X})
78+
assert_almost_equal(exp[0], got[0])
79+
assert_almost_equal(exp[1], got[1])
80+
3581
def test_lightgbm_regressor(self):
3682
model = LGBMRegressor(n_estimators=3, min_child_samples=1)
3783
dump_single_regression(model)
@@ -58,6 +104,22 @@ def test_lightgbm_booster_classifier(self):
58104
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
59105
basename=prefix + "BoosterBin" + model.__class__.__name__)
60106

107+
def test_lightgbm_booster_classifier_nozipmap(self):
108+
X = [[0, 1], [1, 1], [2, 0], [1, 2]]
109+
X = numpy.array(X, dtype=numpy.float32)
110+
y = [0, 1, 0, 1]
111+
data = lightgbm.Dataset(X, label=y)
112+
model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'binary',
113+
'n_estimators': 3, 'min_child_samples': 1},
114+
data)
115+
model_onnx, prefix = convert_model(model, 'tree-based classifier',
116+
[('input', FloatTensorType([None, 2]))],
117+
zipmap=False)
118+
assert "zipmap" not in str(model_onnx).lower()
119+
dump_data_and_model(X, model, model_onnx,
120+
allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
121+
basename=prefix + "BoosterBin" + model.__class__.__name__)
122+
61123
def test_lightgbm_booster_classifier_zipmap(self):
62124
X = [[0, 1], [1, 1], [2, 0], [1, 2]]
63125
X = numpy.array(X, dtype=numpy.float32)

0 commit comments

Comments
 (0)