Skip to content

Commit 30d5fcf

Browse files
authored
Fixes #234, fix lightgbm converter when a root is a leave (#235)
* Fixes #234, fixes lightgbm when the root of a tree is a leave
1 parent ead6925 commit 30d5fcf

File tree

5 files changed

+46
-22
lines changed

5 files changed

+46
-22
lines changed

.azure-pipelines/win32-conda-CI.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,11 @@ jobs:
1919
Python35:
2020
python.version: '3.5'
2121
ONNX_PATH: onnx==1.2.3
22-
KERAS: keras==2.1.6
2322
COREML_PATH: https://github.com/apple/coremltools/archive/v2.0.zip
2423

2524
Python36:
2625
python.version: '3.6'
2726
ONNX_PATH: onnx==1.3.0
28-
KERAS: keras
2927
COREML_PATH: git+https://github.com/apple/coremltools
3028

3129
maxParallel: 3
@@ -49,7 +47,6 @@ jobs:
4947
pip install %COREML_PATH% %ONNX_PATH%
5048
pip install -r requirements-dev.txt
5149
echo Test onnxruntime installation... && python -c "import onnxruntime"
52-
pip install %KERAS%
5350
REM install libsvm from github
5451
git clone --recursive https://github.com/cjlin1/libsvm libsvm
5552
copy libsvm\windows\*.dll libsvm\python

onnxmltools/convert/lightgbm/operator_converters/LightGbm.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ def _parse_tree_structure(tree_id, class_id, learning_rate, tree_structure, attr
4242
node_id_pool = set()
4343

4444
node_id = _create_node_id(node_id_pool)
45+
46+
# The root node is a leaf node.
47+
if not 'left_child' in tree_structure or not 'right_child' in tree_structure:
48+
_parse_node(tree_id, class_id, node_id, node_id_pool, learning_rate, tree_structure, attrs)
49+
return
50+
4551
left_id = _create_node_id(node_id_pool)
4652
right_id = _create_node_id(node_id_pool)
4753

onnxmltools/utils/tests_helper.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def dump_data_and_model(data, model, onnx=None, basename="model", folder=None,
3434
:param backend: backend used to compare expected output and runtime output.
3535
Two options are currently supported: None for no test,
3636
`'onnxruntime'` to use module *onnxruntime*.
37-
:param context: used if the model contains a custom operator such
38-
as a custom Keras function...
37+
:param context: used if the model contains a custom operator
3938
:param allow_failure: None to raise an exception if comparison fails
4039
for the backends, otherwise a string which is then evaluated to check
4140
whether or not the test can fail, example:
@@ -108,15 +107,10 @@ def dump_data_and_model(data, model, onnx=None, basename="model", folder=None,
108107
with open(dest, "wb") as f:
109108
pickle.dump(data, f)
110109

111-
if hasattr(model, 'save'):
112-
dest = os.path.join(folder, basename + ".model.keras")
113-
names.append(dest)
114-
model.save(dest)
115-
else:
116-
dest = os.path.join(folder, basename + ".model.pkl")
117-
names.append(dest)
118-
with open(dest, "wb") as f:
119-
pickle.dump(model, f)
110+
dest = os.path.join(folder, basename + ".model.pkl")
111+
names.append(dest)
112+
with open(dest, "wb") as f:
113+
pickle.dump(model, f)
120114

121115
if onnx is None:
122116
array = numpy.array(data)
@@ -168,7 +162,7 @@ def convert_model(model, name, input_types):
168162
"""
169163
Runs the appropriate conversion method.
170164
171-
:param model: model, *scikit-learn*, *keras*, or *coremltools* object
165+
:param model: model
172166
:return: *onnx* model
173167
"""
174168
from sklearn.base import BaseEstimator
@@ -179,13 +173,8 @@ def convert_model(model, name, input_types):
179173
from onnxmltools.convert import convert_sklearn
180174
model, prefix = convert_sklearn(model, name, input_types), "Sklearn"
181175
else:
182-
from keras.models import Model
183-
if isinstance(model, Model):
184-
from onnxmltools.convert import convert_keras
185-
model, prefix = convert_keras(model, name, input_types), "Keras"
186-
else:
187-
from onnxmltools.convert import convert_coreml
188-
model, prefix = convert_coreml(model, name, input_types), "Cml"
176+
from onnxmltools.convert import convert_coreml
177+
model, prefix = convert_coreml(model, name, input_types), "Cml"
189178
if model is None:
190179
raise RuntimeError("Unable to convert model of type '{0}'.".format(type(model)))
191180
return model, prefix

tests/lightgbm/example.pkl

93.6 KB
Binary file not shown.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
import sys
7+
import unittest
8+
import numpy
9+
import pickle
10+
import os
11+
from onnxmltools import convert_lightgbm
12+
from onnxmltools.convert.common.data_types import FloatTensorType
13+
from onnxmltools.utils import dump_data_and_model
14+
15+
16+
class TestLightGbmTreeEnsembleModelsPkl(unittest.TestCase):
17+
18+
@unittest.skipIf(sys.version_info[0] == 2, reason="pickled with Python 3, cannot unpickle with 2")
19+
@unittest.skipIf(sys.platform.startswith('win'), reason="pickled on linux, may not work on windows")
20+
def test_root_leave(self):
21+
this = os.path.abspath(os.path.dirname(__file__))
22+
for name in ["example.pkl"]:
23+
with open(os.path.join(this, name), "rb") as f:
24+
model = pickle.load(f)
25+
X = [[0., 1.], [1., 1.], [2., 0.]]
26+
X = numpy.array(X, dtype=numpy.float32)
27+
model_onnx = convert_lightgbm(model.steps[1][1], 'pkl1', [('input', FloatTensorType([1, X.shape[1]]))])
28+
dump_data_and_model(X, model.steps[1][1], model_onnx, basename="LightGbmPkl1")
29+
30+
31+
if __name__ == "__main__":
32+
unittest.main()

0 commit comments

Comments
 (0)