Skip to content

Commit d981dac

Browse files
authored
Switch to tf2onnx for tensorflow>=2.0 instead of keras2onnx (#492)
* switch to tf2onnx for tensorflow>=2.0 instead of keras2onnx * fix input_signature
1 parent cf660f0 commit d981dac

File tree

6 files changed

+130
-35
lines changed

6 files changed

+130
-35
lines changed

README.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,16 @@ If you choose to install `onnxmltools` from its source code, you must set the en
3535
## Dependencies
3636
This package relies on ONNX, NumPy, and ProtoBuf. If you are converting a model from scikit-learn, Core ML, Keras, LightGBM, SparkML, XGBoost, H2O, CatBoost or LibSVM, you will need an environment with the respective package installed from the list below:
3737
1. scikit-learn
38-
2. CoreMLTools
38+
2. CoreMLTools (version 3.1 or lower)
3939
3. Keras (version 2.0.8 or higher) with the corresponding Tensorflow version
40-
4. LightGBM (scikit-learn interface)
40+
4. LightGBM
4141
5. SparkML
42-
6. XGBoost (scikit-learn interface)
42+
6. XGBoost
4343
7. libsvm
4444
8. H2O
4545
9. CatBoost
4646

47-
ONNXMLTools has been tested with Python **3.5**, **3.6**, and **3.7**.
48-
Version 1.6.1 is the latest version supporting Python 2.7.
47+
ONNXMLTools is tested with Python **3.7+**.
4948

5049
# Examples
5150
If you want the converted ONNX model to be compatible with a certain ONNX version, please specify the target_opset parameter upon invoking the convert function. The following Keras model conversion example demonstrates this below. You can identify the mapping from ONNX Operator Sets (referred to as opsets) to ONNX releases in the [versioning documentation](https://github.com/onnx/onnx/blob/master/docs/Versioning.md#released-versions).

docs/api_summary.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,20 @@ in *onnxmltools*.
1414
Converters
1515
==========
1616

17+
.. autofunction:: onnxmltools.convert.h2o.catboost
18+
1719
.. autofunction:: onnxmltools.convert.coreml.convert
1820

21+
.. autofunction:: onnxmltools.convert.h2o.convert
22+
1923
.. autofunction:: onnxmltools.convert.keras.convert
2024

2125
.. autofunction:: onnxmltools.convert.lightgbm.convert
2226

2327
.. autofunction:: onnxmltools.convert.sklearn.convert
2428

29+
.. autofunction:: onnxmltools.convert.xgboost.convert
30+
2531
Utils
2632
=====
2733

docs/examples/plot_convert_keras.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
optimizer='sgd',
4949
metrics=['accuracy'])
5050
model.fit(X_train, y_train, epochs=5, batch_size=16)
51+
print("keras prediction")
52+
print(model.predict(X_test.astype(numpy.float32)))
5153

5254
###########################
5355
# Convert a model into ONNX
@@ -62,9 +64,10 @@
6264

6365
sess = rt.InferenceSession(onx.SerializeToString())
6466
input_name = sess.get_inputs()[0].name
65-
label_name = sess.get_outputs()[0].name
67+
output_name = sess.get_outputs()[0].name
6668
pred_onx = sess.run(
67-
[label_name], {input_name: X_test.astype(numpy.float32)})[0]
69+
[output_name], {input_name: X_test.astype(numpy.float32)})[0]
70+
print("ONNX prediction")
6871
print(pred_onx)
6972

7073
##################################

docs/index.rst

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,14 @@
44
onnxmltools: Convert your model into ONNX
55
=========================================
66

7-
.. list-table:
8-
:header-rows: 1
9-
:widths: 5 5
10-
* - Linux
11-
- Windows
12-
* - .. image:: https://travis-ci.org/onnx/onnxmltools.svg?branch=master
13-
:target: https://travis-ci.org/onnx/onnxmltools
14-
- .. image:: https://ci.appveyor.com/api/projects/status/d1xav3amubypje4n?svg=true
15-
:target: https://ci.appveyor.com/project/xadupre/onnxmltools
16-
177
ONNXMLTools enables you to convert models from different machine learning
188
toolkits into `ONNX <https://onnx.ai>`_.
199
Currently the following toolkits are supported:
2010

2111
* `Apple Core ML <https://developer.apple.com/documentation/coreml>`_,
2212
(`onnx-coreml <https://github.com/onnx/onnx-coreml>`_ does the reverse
23-
conversion from *onnx* to *Apple Core ML*),
13+
conversion from *onnx* to *Apple Core ML*) (up to version 3.1)
14+
* `catboost <https://catboost.ai/>`_
2415
* `h2o <http://docs.h2o.ai/h2o/latest-stable/h2o-py/docs/intro.html>`_
2516
(a subset only)
2617
* `Keras <https://keras.io/>`_

onnxmltools/convert/common/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,16 @@ def hummingbird_installed():
1414
except ImportError:
1515
return False
1616

17+
18+
def tf2onnx_installed():
19+
"""
20+
Checks that *tf2onnx* is available.
21+
"""
22+
try:
23+
import tf2onnx # noqa F401
24+
return True
25+
except ImportError:
26+
return False
27+
28+
1729
from onnxconverter_common.utils import * # noqa

onnxmltools/convert/main.py

Lines changed: 101 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
import warnings
4+
from distutils.version import StrictVersion
35
import onnx
46
from .common import utils
5-
import warnings
67

78

89
def convert_coreml(model, name=None, initial_types=None, doc_string='', target_opset=None,
9-
targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
10+
targeted_onnx=None, custom_conversion_functions=None, custom_shape_calculators=None):
11+
if targeted_onnx is not None:
12+
warnings.warn("targeted_onnx is deprecated. Use target_opset.", DeprecationWarning)
1013
if not utils.coreml_installed():
1114
raise RuntimeError('coremltools is not installed. Please install coremltools to use this feature.')
1215

@@ -15,22 +18,93 @@ def convert_coreml(model, name=None, initial_types=None, doc_string='', target_o
1518
custom_conversion_functions, custom_shape_calculators)
1619

1720

18-
def convert_keras(model, name=None, initial_types=None, doc_string='',
19-
target_opset=None, targeted_onnx=onnx.__version__,
20-
channel_first_inputs=None, custom_conversion_functions=None, custom_shape_calculators=None,
21+
def convert_keras(model, name=None,
22+
initial_types=None,
23+
doc_string='',
24+
target_opset=None,
25+
targeted_onnx=None,
26+
channel_first_inputs=None,
27+
custom_conversion_functions=None,
28+
custom_shape_calculators=None,
2129
default_batch_size=1):
22-
if not utils.keras2onnx_installed():
23-
raise RuntimeError('keras2onnx is not installed. Please install it to use this feature.')
24-
25-
if custom_conversion_functions:
26-
warnings.warn('custom_conversion_functions is not supported any more. Please set it to None.')
27-
28-
from keras2onnx import convert_keras as convert
29-
return convert(model, name, doc_string, target_opset, channel_first_inputs)
30+
"""
31+
.. versionchanged:: 1.9.0
32+
The conversion is now using *tf2onnx*.
33+
"""
34+
if targeted_onnx is not None:
35+
warnings.warn("targeted_onnx is deprecated and unused. Use target_opset.", DeprecationWarning)
36+
import tensorflow as tf
37+
if StrictVersion(tf.__version__) < StrictVersion('2.0'):
38+
# Former converter for tensorflow<2.0.
39+
from keras2onnx import convert_keras as convert
40+
return convert(model, name, doc_string, target_opset, channel_first_inputs)
41+
else:
42+
# For tensorflow>=2.0, new converter based on tf2onnx.
43+
import tf2onnx
44+
45+
if not utils.tf2onnx_installed():
46+
raise RuntimeError('tf2onnx is not installed. Please install it to use this feature.')
47+
48+
if custom_conversion_functions is not None:
49+
warnings.warn('custom_conversion_functions is not supported any more. Please set it to None.')
50+
if custom_shape_calculators is not None:
51+
warnings.warn('custom_shape_calculators is not supported any more. Please set it to None.')
52+
if default_batch_size != 1:
53+
warnings.warn('default_batch_size is not supported any more. Please set it to 1.')
54+
if default_batch_size != 1:
55+
warnings.warn('default_batch_size is not supported any more. Please set it to 1.')
56+
57+
if initial_types is not None:
58+
from onnxconverter_common import (
59+
FloatTensorType, DoubleTensorType,
60+
Int64TensorType, Int32TensorType,
61+
StringTensorType, BooleanTensorType)
62+
spec = []
63+
for name, kind in initial_types:
64+
if isinstance(kind, FloatTensorType):
65+
dtype = tf.float32
66+
elif isinstance(kind, Int64TensorType):
67+
dtype = tf.int64
68+
elif isinstance(kind, Int32TensorType):
69+
dtype = tf.int32
70+
elif isinstance(kind, DoubleTensorType):
71+
dtype = tf.float64
72+
elif isinstance(kind, StringTensorType):
73+
dtype = tf.string
74+
elif isinstance(kind, BooleanTensorType):
75+
dtype = tf.bool
76+
else:
77+
raise TypeError(
78+
"Unexpected type %r, cannot infer tensorflow type." % type(kind))
79+
spec.append(tf.TensorSpec(tuple(kind.shape), dtype, name=name))
80+
input_signature = tuple(spec)
81+
else:
82+
input_signature = None
83+
84+
model_proto, external_tensor_storage = tf2onnx.convert.from_keras(
85+
model,
86+
input_signature=input_signature,
87+
opset=target_opset,
88+
custom_ops=None,
89+
custom_op_handlers=None,
90+
custom_rewriter=None,
91+
inputs_as_nchw=channel_first_inputs,
92+
extra_opset=None,
93+
shape_override=None,
94+
target=None,
95+
large_model=False,
96+
output_path=None)
97+
if external_tensor_storage is not None:
98+
warnings.warn("The current API does not expose the second result 'external_tensor_storage'. "
99+
"Use tf2onnx directly to get it.")
100+
model_proto.doc_string = doc_string
101+
return model_proto
30102

31103

32104
def convert_libsvm(model, name=None, initial_types=None, doc_string='', target_opset=None,
33-
targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
105+
targeted_onnx=None, custom_conversion_functions=None, custom_shape_calculators=None):
106+
if targeted_onnx is not None:
107+
warnings.warn("targeted_onnx is deprecated. Use target_opset.", DeprecationWarning)
34108
if not utils.libsvm_installed():
35109
raise RuntimeError('libsvm is not installed. Please install libsvm to use this feature.')
36110

@@ -51,8 +125,10 @@ def convert_catboost(model, name=None, initial_types=None, doc_string='', target
51125

52126

53127
def convert_lightgbm(model, name=None, initial_types=None, doc_string='', target_opset=None,
54-
targeted_onnx=onnx.__version__, custom_conversion_functions=None,
128+
targeted_onnx=None, custom_conversion_functions=None,
55129
custom_shape_calculators=None, without_onnx_ml=False, zipmap=True):
130+
if targeted_onnx is not None:
131+
warnings.warn("targeted_onnx is deprecated. Use target_opset.", DeprecationWarning)
56132
if not utils.lightgbm_installed():
57133
raise RuntimeError('lightgbm is not installed. Please install lightgbm to use this feature.')
58134

@@ -63,7 +139,9 @@ def convert_lightgbm(model, name=None, initial_types=None, doc_string='', target
63139

64140

65141
def convert_sklearn(model, name=None, initial_types=None, doc_string='', target_opset=None,
66-
targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
142+
targeted_onnx=None, custom_conversion_functions=None, custom_shape_calculators=None):
143+
if targeted_onnx is not None:
144+
warnings.warn("targeted_onnx is deprecated. Use target_opset.", DeprecationWarning)
67145
if not utils.sklearn_installed():
68146
raise RuntimeError('scikit-learn is not installed. Please install scikit-learn to use this feature.')
69147

@@ -76,8 +154,10 @@ def convert_sklearn(model, name=None, initial_types=None, doc_string='', target_
76154

77155

78156
def convert_sparkml(model, name=None, initial_types=None, doc_string='', target_opset=None,
79-
targeted_onnx=onnx.__version__, custom_conversion_functions=None,
157+
targeted_onnx=None, custom_conversion_functions=None,
80158
custom_shape_calculators=None, spark_session=None):
159+
if targeted_onnx is not None:
160+
warnings.warn("targeted_onnx is deprecated. Use target_opset.", DeprecationWarning)
81161
if not utils.sparkml_installed():
82162
raise RuntimeError('Spark is not installed. Please install Spark to use this feature.')
83163

@@ -87,6 +167,8 @@ def convert_sparkml(model, name=None, initial_types=None, doc_string='', target_
87167

88168

89169
def convert_xgboost(*args, **kwargs):
170+
if kwargs.get('targeted_onnx', None) is not None:
171+
warnings.warn("targeted_onnx is deprecated. Use target_opset.", DeprecationWarning)
90172
if not utils.xgboost_installed():
91173
raise RuntimeError('xgboost is not installed. Please install xgboost to use this feature.')
92174

@@ -95,6 +177,8 @@ def convert_xgboost(*args, **kwargs):
95177

96178

97179
def convert_h2o(*args, **kwargs):
180+
if kwargs.get('targeted_onnx', None) is not None:
181+
warnings.warn("targeted_onnx is deprecated. Use target_opset.", DeprecationWarning)
98182
if not utils.h2o_installed():
99183
raise RuntimeError('h2o is not installed. Please install h2o to use this feature.')
100184

0 commit comments

Comments
 (0)