77from ..proto import onnx
88from .common import utils
99import warnings
10+ import importlib
11+
1012
1113def convert_coreml (model , name = None , initial_types = None , doc_string = '' , target_opset = None ,
12- targeted_onnx = onnx .__version__ , custom_conversion_functions = None , custom_shape_calculators = None ):
14+ targeted_onnx = onnx .__version__ , custom_conversion_functions = None , custom_shape_calculators = None ):
1315 if not utils .coreml_installed ():
1416 raise RuntimeError ('coremltools is not installed. Please install coremltools to use this feature.' )
1517
@@ -33,7 +35,7 @@ def convert_keras(model, name=None, initial_types=None, doc_string='',
3335
3436
3537def convert_libsvm (model , name = None , initial_types = None , doc_string = '' , target_opset = None ,
36- targeted_onnx = onnx .__version__ , custom_conversion_functions = None , custom_shape_calculators = None ):
38+ targeted_onnx = onnx .__version__ , custom_conversion_functions = None , custom_shape_calculators = None ):
3739 if not utils .libsvm_installed ():
3840 raise RuntimeError ('libsvm is not installed. Please install libsvm to use this feature.' )
3941
@@ -62,7 +64,8 @@ def convert_sklearn(model, name=None, initial_types=None, doc_string='', target_
6264
6365 from skl2onnx .convert import convert_sklearn as convert_skl2onnx
6466 return convert_skl2onnx (model , name , initial_types , doc_string , target_opset ,
65- custom_conversion_functions , custom_shape_calculators )
67+ custom_conversion_functions , custom_shape_calculators )
68+
6669
6770def convert_sparkml (model , name = None , initial_types = None , doc_string = '' , target_opset = None ,
6871 targeted_onnx = onnx .__version__ , custom_conversion_functions = None ,
@@ -74,18 +77,6 @@ def convert_sparkml(model, name=None, initial_types=None, doc_string='', target_
7477 return convert (model , name , initial_types , doc_string , target_opset , targeted_onnx ,
7578 custom_conversion_functions , custom_shape_calculators , spark_session )
7679
77- def convert_tensorflow (frozen_graph_def ,
78- name = None , input_names = None , output_names = None ,
79- doc_string = '' ,
80- target_opset = None ,
81- channel_first_inputs = None ,
82- debug_mode = False , custom_op_conversions = None ):
83- if not utils .keras2onnx_installed ():
84- raise RuntimeError ('keras2onnx is not installed. Please install it to use this feature.' )
85-
86- from keras2onnx import convert_tensorflow as convert
87- return convert (frozen_graph_def , name , input_names , output_names , doc_string ,
88- target_opset , channel_first_inputs , debug_mode , custom_op_conversions )
8980
9081def convert_xgboost (* args , ** kwargs ):
9182 if not utils .xgboost_installed ():
@@ -94,9 +85,91 @@ def convert_xgboost(*args, **kwargs):
9485 from .xgboost .convert import convert
9586 return convert (* args , ** kwargs )
9687
88+
9789def convert_h2o (* args , ** kwargs ):
9890 if not utils .h2o_installed ():
9991 raise RuntimeError ('h2o is not installed. Please install h2o to use this feature.' )
10092
10193 from .h2o .convert import convert
10294 return convert (* args , ** kwargs )
95+
96+
97+ def _collect_input_nodes (graph , outputs ):
98+ nodes_to_keep = set ()
99+ input_nodes = set ()
100+ node_inputs = [graph .get_tensor_by_name (ts_ ).op for ts_ in outputs ]
101+ while node_inputs :
102+ nd_ = node_inputs [0 ]
103+ del node_inputs [0 ]
104+ if nd_ .type in ['Placeholder' , "PlaceholderV2" , 'PlaceholderWithDefault' ]:
105+ input_nodes .add (nd_ )
106+ if nd_ in nodes_to_keep :
107+ continue
108+
109+ nodes_to_keep .add (nd_ )
110+ node_inputs .extend (in_ .op for in_ in nd_ .inputs )
111+
112+ return input_nodes , nodes_to_keep
113+
114+
115+ def _convert_tf_wrapper (frozen_graph_def ,
116+ name = None , input_names = None , output_names = None ,
117+ doc_string = '' ,
118+ target_opset = None ,
119+ channel_first_inputs = None ,
120+ debug_mode = False , custom_op_conversions = None ):
121+ """
122+ convert a tensorflow graph def into a ONNX model proto, just like how keras does.
123+ :param graph_def: the frozen tensorflow graph
124+ :param name: the converted onnx model internal name
125+ :param input_names: the inputs name list of the model
126+ :param output_names: the output name list of the model
127+ :param doc_string: doc string
128+ :param target_opset: the targeted onnx model opset
129+ :param channel_first_inputs: A list of channel first input (not supported yet)
130+ :param debug_mode: will enable the log and try to convert as much as possible on conversion
131+ :return an ONNX ModelProto
132+ """
133+ import tensorflow as tf
134+ import tf2onnx
135+
136+ if target_opset is None :
137+ target_opset = onnx .defs .onnx_opset_version ()
138+
139+ if not doc_string :
140+ doc_string = "converted from {}" .format (name )
141+
142+ tf_graph_def = tf2onnx .tfonnx .tf_optimize (input_names , output_names , frozen_graph_def , True )
143+ with tf .Graph ().as_default () as tf_graph :
144+ tf .import_graph_def (tf_graph_def , name = '' )
145+
146+ if not input_names :
147+ input_nodes = list (_collect_input_nodes (tf_graph , output_names )[0 ])
148+ input_names = [nd_ .outputs [0 ].name for nd_ in input_nodes ]
149+ g = tf2onnx .tfonnx .process_tf_graph (tf_graph ,
150+ continue_on_error = debug_mode ,
151+ opset = target_opset ,
152+ custom_op_handlers = custom_op_conversions ,
153+ inputs_as_nchw = channel_first_inputs ,
154+ output_names = output_names ,
155+ input_names = input_names )
156+
157+ onnx_graph = tf2onnx .optimizer .optimize_graph (g )
158+ model_proto = onnx_graph .make_model (doc_string )
159+
160+ return model_proto
161+
162+
163+ def convert_tensorflow (frozen_graph_def ,
164+ name = None , input_names = None , output_names = None ,
165+ doc_string = '' ,
166+ target_opset = None ,
167+ channel_first_inputs = None ,
168+ debug_mode = False , custom_op_conversions = None ):
169+ try :
170+ importlib .import_module ('tf2onnx' )
171+ except (ImportError , ModuleNotFoundError ) as e :
172+ raise RuntimeError ('tf2onnx is not installed, please install it before calling this function.' )
173+
174+ return _convert_tf_wrapper (frozen_graph_def , name , input_names , output_names , doc_string ,
175+ target_opset , channel_first_inputs , debug_mode , custom_op_conversions )
0 commit comments