11# SPDX-License-Identifier: Apache-2.0
22
3+ import warnings
4+ from distutils .version import StrictVersion
35import onnx
46from .common import utils
5- import warnings
67
78
89def convert_coreml (model , name = None , initial_types = None , doc_string = '' , target_opset = None ,
9- targeted_onnx = onnx .__version__ , custom_conversion_functions = None , custom_shape_calculators = None ):
10+ targeted_onnx = None , custom_conversion_functions = None , custom_shape_calculators = None ):
11+ if targeted_onnx is not None :
12+ warnings .warn ("targeted_onnx is deprecated. Use target_opset." , DeprecationWarning )
1013 if not utils .coreml_installed ():
1114 raise RuntimeError ('coremltools is not installed. Please install coremltools to use this feature.' )
1215
@@ -15,22 +18,93 @@ def convert_coreml(model, name=None, initial_types=None, doc_string='', target_o
1518 custom_conversion_functions , custom_shape_calculators )
1619
1720
18- def convert_keras (model , name = None , initial_types = None , doc_string = '' ,
19- target_opset = None , targeted_onnx = onnx .__version__ ,
20- channel_first_inputs = None , custom_conversion_functions = None , custom_shape_calculators = None ,
21+ def convert_keras (model , name = None ,
22+ initial_types = None ,
23+ doc_string = '' ,
24+ target_opset = None ,
25+ targeted_onnx = None ,
26+ channel_first_inputs = None ,
27+ custom_conversion_functions = None ,
28+ custom_shape_calculators = None ,
2129 default_batch_size = 1 ):
22- if not utils .keras2onnx_installed ():
23- raise RuntimeError ('keras2onnx is not installed. Please install it to use this feature.' )
24-
25- if custom_conversion_functions :
26- warnings .warn ('custom_conversion_functions is not supported any more. Please set it to None.' )
27-
28- from keras2onnx import convert_keras as convert
29- return convert (model , name , doc_string , target_opset , channel_first_inputs )
30+ """
31+ .. versionchanged:: 1.9.0
32+ The conversion is now using *tf2onnx*.
33+ """
34+ if targeted_onnx is not None :
35+ warnings .warn ("targeted_onnx is deprecated and unused. Use target_opset." , DeprecationWarning )
36+ import tensorflow as tf
37+ if StrictVersion (tf .__version__ ) < StrictVersion ('2.0' ):
38+ # Former converter for tensorflow<2.0.
39+ from keras2onnx import convert_keras as convert
40+ return convert (model , name , doc_string , target_opset , channel_first_inputs )
41+ else :
42+ # For tensorflow>=2.0, new converter based on tf2onnx.
43+ import tf2onnx
44+
45+ if not utils .tf2onnx_installed ():
46+ raise RuntimeError ('tf2onnx is not installed. Please install it to use this feature.' )
47+
48+ if custom_conversion_functions is not None :
49+ warnings .warn ('custom_conversion_functions is not supported any more. Please set it to None.' )
50+ if custom_shape_calculators is not None :
51+ warnings .warn ('custom_shape_calculators is not supported any more. Please set it to None.' )
52+ if default_batch_size != 1 :
53+ warnings .warn ('default_batch_size is not supported any more. Please set it to 1.' )
54+ if default_batch_size != 1 :
55+ warnings .warn ('default_batch_size is not supported any more. Please set it to 1.' )
56+
57+ if initial_types is not None :
58+ from onnxconverter_common import (
59+ FloatTensorType , DoubleTensorType ,
60+ Int64TensorType , Int32TensorType ,
61+ StringTensorType , BooleanTensorType )
62+ spec = []
63+ for name , kind in initial_types :
64+ if isinstance (kind , FloatTensorType ):
65+ dtype = tf .float32
66+ elif isinstance (kind , Int64TensorType ):
67+ dtype = tf .int64
68+ elif isinstance (kind , Int32TensorType ):
69+ dtype = tf .int32
70+ elif isinstance (kind , DoubleTensorType ):
71+ dtype = tf .float64
72+ elif isinstance (kind , StringTensorType ):
73+ dtype = tf .string
74+ elif isinstance (kind , BooleanTensorType ):
75+ dtype = tf .bool
76+ else :
77+ raise TypeError (
78+ "Unexpected type %r, cannot infer tensorflow type." % type (kind ))
79+ spec .append (tf .TensorSpec (tuple (kind .shape ), dtype , name = name ))
80+ input_signature = tuple (spec )
81+ else :
82+ input_signature = None
83+
84+ model_proto , external_tensor_storage = tf2onnx .convert .from_keras (
85+ model ,
86+ input_signature = input_signature ,
87+ opset = target_opset ,
88+ custom_ops = None ,
89+ custom_op_handlers = None ,
90+ custom_rewriter = None ,
91+ inputs_as_nchw = channel_first_inputs ,
92+ extra_opset = None ,
93+ shape_override = None ,
94+ target = None ,
95+ large_model = False ,
96+ output_path = None )
97+ if external_tensor_storage is not None :
98+ warnings .warn ("The current API does not expose the second result 'external_tensor_storage'. "
99+ "Use tf2onnx directly to get it." )
100+ model_proto .doc_string = doc_string
101+ return model_proto
30102
31103
32104def convert_libsvm (model , name = None , initial_types = None , doc_string = '' , target_opset = None ,
33- targeted_onnx = onnx .__version__ , custom_conversion_functions = None , custom_shape_calculators = None ):
105+ targeted_onnx = None , custom_conversion_functions = None , custom_shape_calculators = None ):
106+ if targeted_onnx is not None :
107+ warnings .warn ("targeted_onnx is deprecated. Use target_opset." , DeprecationWarning )
34108 if not utils .libsvm_installed ():
35109 raise RuntimeError ('libsvm is not installed. Please install libsvm to use this feature.' )
36110
@@ -51,8 +125,10 @@ def convert_catboost(model, name=None, initial_types=None, doc_string='', target
51125
52126
53127def convert_lightgbm (model , name = None , initial_types = None , doc_string = '' , target_opset = None ,
54- targeted_onnx = onnx . __version__ , custom_conversion_functions = None ,
128+ targeted_onnx = None , custom_conversion_functions = None ,
55129 custom_shape_calculators = None , without_onnx_ml = False , zipmap = True ):
130+ if targeted_onnx is not None :
131+ warnings .warn ("targeted_onnx is deprecated. Use target_opset." , DeprecationWarning )
56132 if not utils .lightgbm_installed ():
57133 raise RuntimeError ('lightgbm is not installed. Please install lightgbm to use this feature.' )
58134
@@ -63,7 +139,9 @@ def convert_lightgbm(model, name=None, initial_types=None, doc_string='', target
63139
64140
65141def convert_sklearn (model , name = None , initial_types = None , doc_string = '' , target_opset = None ,
66- targeted_onnx = onnx .__version__ , custom_conversion_functions = None , custom_shape_calculators = None ):
142+ targeted_onnx = None , custom_conversion_functions = None , custom_shape_calculators = None ):
143+ if targeted_onnx is not None :
144+ warnings .warn ("targeted_onnx is deprecated. Use target_opset." , DeprecationWarning )
67145 if not utils .sklearn_installed ():
68146 raise RuntimeError ('scikit-learn is not installed. Please install scikit-learn to use this feature.' )
69147
@@ -76,8 +154,10 @@ def convert_sklearn(model, name=None, initial_types=None, doc_string='', target_
76154
77155
78156def convert_sparkml (model , name = None , initial_types = None , doc_string = '' , target_opset = None ,
79- targeted_onnx = onnx . __version__ , custom_conversion_functions = None ,
157+ targeted_onnx = None , custom_conversion_functions = None ,
80158 custom_shape_calculators = None , spark_session = None ):
159+ if targeted_onnx is not None :
160+ warnings .warn ("targeted_onnx is deprecated. Use target_opset." , DeprecationWarning )
81161 if not utils .sparkml_installed ():
82162 raise RuntimeError ('Spark is not installed. Please install Spark to use this feature.' )
83163
@@ -87,6 +167,8 @@ def convert_sparkml(model, name=None, initial_types=None, doc_string='', target_
87167
88168
89169def convert_xgboost (* args , ** kwargs ):
170+ if kwargs .get ('targeted_onnx' , None ) is not None :
171+ warnings .warn ("targeted_onnx is deprecated. Use target_opset." , DeprecationWarning )
90172 if not utils .xgboost_installed ():
91173 raise RuntimeError ('xgboost is not installed. Please install xgboost to use this feature.' )
92174
@@ -95,6 +177,8 @@ def convert_xgboost(*args, **kwargs):
95177
96178
97179def convert_h2o (* args , ** kwargs ):
180+ if kwargs .get ('targeted_onnx' , None ) is not None :
181+ warnings .warn ("targeted_onnx is deprecated. Use target_opset." , DeprecationWarning )
98182 if not utils .h2o_installed ():
99183 raise RuntimeError ('h2o is not installed. Please install h2o to use this feature.' )
100184
0 commit comments