Skip to content

Commit eef63ee

Browse files
xadupreWenbing Li
authored andcommitted
Fixes #18, add converters for xgboost (#192)
* remove unnecessary print, add quote around filenames in some places * replaces as_matrix by values (pandas warnings) * changes variable name to avoid getting warnings about invalid names * better consistency for converted, allows targetted onnx version to be None * Revert "better consistency for converted, allows targetted onnx version to be None" This reverts commit e257ca1. * handle the comparison of ONNX versions in only one place * fix bug with OneHotEncoder and scikit-learn 0.20 * release the constraint on scikit-learn (0.20.0 allowed) * fix one type issue for Python 2.7 * add documentation to compare_strict_version * Fixes #151, BernouilliNB converter * Removes unused nodes in graph * Adresses issue #143, enables build with keras 2.1.2 * Revert modifications due to a wrong merge * update keras version * Disable test on keras/mobilenet as it does not work * add unit test for xception (failing) * remove duplicate install * skip unit test if not installed (tensorflow still not available on python 3.7) * Fix when keras is not available * Fix missing import * Update test_single_operator_with_cntk_backend.py * Set up CI with Azure Pipelines * Update azure pipeline * Skip a unit test if tensorflow is not installed * merge * missing import * Revert "Merge branch 'master' of https://github.com/onnx/onnxmltools" This reverts commit 178e763, reversing changes made to 1a617ef. * revert changes * Revert changes * \r * \r * first step in the migration of xgboost code * XGBoost regression works * Finalize xgboost converter * Update README.md * Add function has_tensorflow * Update test_single_operator_with_cntk_backend.py * better desgin for a unit test * update xgboost classifier * Delete test_keras_xception.py * Delete requirements-deep.txt * Delete test_keras_modebilenetv2.py * less spaces * lower precision for xgboost comparison tests * disable xgboost testing on python 2
1 parent 30d5fcf commit eef63ee

24 files changed

+610
-11
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ This package relies on ONNX, NumPy, and ProtoBuf. If you are converting a model
3232
2. CoreMLTools
3333
3. Keras (version 2.0.8 or higher) with the corresponding Tensorflow version
3434
4. LightGBM (scikit-learn interface)
35+
5. XGBoost (scikit-learn interface)
36+
6. libsvm
3537

3638
# Examples
3739
If you want the converted ONNX model to be compatible with a certain ONNX version, please specify the target_opset parameter upon invoking the convert function. The following Keras model conversion example demonstrates this below. You can identify the mapping from ONNX Operator Sets (referred to as opsets) to ONNX releases in the [versioning documentation](https://github.com/onnx/onnx/blob/master/docs/Versioning.md#released-versions).

onnxmltools/convert/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,5 @@
99
from .main import convert_libsvm
1010
from .main import convert_lightgbm
1111
from .main import convert_sklearn
12+
from .main import convert_xgboost
13+

onnxmltools/convert/common/_container.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ class LightGbmModelContainer(CommonSklearnModelContainer):
9292
pass
9393

9494

95+
class XGBoostModelContainer(CommonSklearnModelContainer):
96+
pass
97+
98+
9599
class KerasModelContainer(RawModelContainer):
96100

97101
def __init__(self, keras_model):

onnxmltools/convert/common/data_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ def __init__(self, shape=None, doc_string=''):
1515

1616
def to_onnx_type(self):
1717
raise NotImplementedError()
18+
19+
def __repr__(self):
20+
name = self.__class__.__name__
21+
return "{}({}, '{}')".format(name, self.shape, self.doc_string)
1822

1923

2024
class Int64Type(DataType):

onnxmltools/convert/common/interface.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import abc
1111
import six
1212

13+
1314
@six.add_metaclass(abc.ABCMeta)
1415
class ModelContainer:
1516
__metaclass = abc.ABCMeta
@@ -41,6 +42,7 @@ def add_node(self, op_type, inputs, outputs, op_domain='', op_version=1, **attrs
4142
"""
4243
return
4344

45+
4446
@six.add_metaclass(abc.ABCMeta)
4547
class OperatorBase:
4648
__metaclass__ = abc.ABCMeta
@@ -77,6 +79,7 @@ def original_operator(self):
7779
"""
7880
pass
7981

82+
8083
@six.add_metaclass(abc.ABCMeta)
8184
class ScopeBase:
8285
__metaclass__ = abc.ABCMeta

onnxmltools/convert/common/optimizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ def build_from_onnx(onnx_nodes, nchw_inputs, inputs, outputs):
117117
ln = LinkedNode(o_)
118118
view.append(ln)
119119
for var_ in o_.output:
120-
assert var_map.get(var_) is None
120+
if var_map.get(var_) is not None:
121+
raise RuntimeError("Duplicated output name (accross all nodes) '{0}'".format(var_))
121122
var_map[var_] = ln
122123

123124
additional_nodes = []

onnxmltools/convert/common/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def xgboost_installed():
120120
try:
121121
_LIB.XGBoosterDumpModelEx
122122
except AttributeError:
123-
# The version is now recent enough even though it is version 0.6.
123+
# The version is not recent enough even though it is version 0.6.
124124
# You need to install xgboost from github and not from pypi.
125125
return False
126126
from xgboost import __version__
@@ -290,7 +290,7 @@ def check_input_and_output_numbers(operator, input_count_range=None, output_coun
290290
if max_output_count is not None and len(operator.outputs) > max_output_count:
291291
raise RuntimeError(
292292
'For operator %s (type: %s), at most %s outputs(s) is(are) supported but we got %s output(s) which are %s' \
293-
% (operator.full_name, operator.type, max_output_count, len(operator.outputs), operator.outputs_full_names))
293+
% (operator.full_name, operator.type, max_output_count, len(operator.outputs), operator.output_full_names))
294294

295295

296296
def check_input_and_output_types(operator, good_input_types=None, good_output_types=None):

onnxmltools/convert/lightgbm/_parse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def _get_lightgbm_operator_name(model_type):
2121
'''
2222
Get operator name of the input argument
2323
24-
:param model_type: A scikit-learn object (e.g., SGDClassifier and Binarizer)
24+
:param model_type: A lightgbm object.
2525
:return: A string which stands for the type of the input model in our conversion framework
2626
'''
2727
if model_type not in lightgbm_operator_name_map:
@@ -60,7 +60,7 @@ def _parse_lightgbm(scope, model, inputs):
6060
This is a delegate function. It doesn't nothing but invoke the correct parsing function according to the input
6161
model's type.
6262
:param scope: Scope object
63-
:param model: A scikit-learn object (e.g., OneHotEncoder and LogisticRegression)
63+
:param model: A lightgbm object
6464
:param inputs: A list of variables
6565
:return: The output variables produced by the input model
6666
'''

onnxmltools/convert/main.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,13 @@ def convert_sklearn(model, name=None, initial_types=None, doc_string='', target_
6363
from skl2onnx.convert import convert_sklearn as convert_skl2onnx
6464
return convert_skl2onnx(model, name, initial_types, doc_string, target_opset,
6565
custom_conversion_functions, custom_shape_calculators)
66+
67+
68+
def convert_xgboost(*args, **kwargs):
69+
if not utils.xgboost_installed():
70+
raise RuntimeError('xgboost is not installed. Please install xgboost to use this feature.')
71+
72+
from .xgboost.convert import convert
73+
return convert(*args, **kwargs)
74+
75+
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
7+
from .convert import convert

0 commit comments

Comments
 (0)