Skip to content

Commit 1767f05

Browse files
authored
support parameter shape_override (#497)
1 parent 8e52e09 commit 1767f05

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

onnxmltools/convert/main.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ def _convert_tf_wrapper(frozen_graph_def,
209209
doc_string='',
210210
target_opset=None,
211211
channel_first_inputs=None,
212-
debug_mode=False, custom_op_conversions=None):
212+
debug_mode=False, custom_op_conversions=None,
213+
**kwargs):
213214
"""
214215
convert a tensorflow graph def into a ONNX model proto, just like how keras does.
215216
:param graph_def: the frozen tensorflow graph
@@ -220,6 +221,8 @@ def _convert_tf_wrapper(frozen_graph_def,
220221
:param target_opset: the targeted onnx model opset
221222
:param channel_first_inputs: A list of channel first input (not supported yet)
222223
:param debug_mode: will enable the log and try to convert as much as possible on conversion
224+
:param kwargs: additional parameters of function `processs_tf_graph
225+
<https://github.com/onnx/tensorflow-onnx#creating-custom-op-mappings-from-python>`_
223226
:return an ONNX ModelProto
224227
"""
225228
import tensorflow as tf
@@ -244,7 +247,8 @@ def _convert_tf_wrapper(frozen_graph_def,
244247
custom_op_handlers=custom_op_conversions,
245248
inputs_as_nchw=channel_first_inputs,
246249
output_names=output_names,
247-
input_names=input_names)
250+
input_names=input_names,
251+
**kwargs)
248252

249253
onnx_graph = tf2onnx.optimizer.optimize_graph(g)
250254
model_proto = onnx_graph.make_model(doc_string)
@@ -257,10 +261,12 @@ def convert_tensorflow(frozen_graph_def,
257261
doc_string='',
258262
target_opset=None,
259263
channel_first_inputs=None,
260-
debug_mode=False, custom_op_conversions=None):
264+
debug_mode=False, custom_op_conversions=None,
265+
**kwargs):
261266
import pkgutil
262267
if not pkgutil.find_loader('tf2onnx'):
263268
raise RuntimeError('tf2onnx is not installed, please install it before calling this function.')
264269

265270
return _convert_tf_wrapper(frozen_graph_def, name, input_names, output_names, doc_string,
266-
target_opset, channel_first_inputs, debug_mode, custom_op_conversions)
271+
target_opset, channel_first_inputs, debug_mode, custom_op_conversions,
272+
**kwargs)

0 commit comments

Comments
 (0)