Skip to content

Commit fa16f33

Browse files
Added support for PlaceholderWithDefault ops with computed defaults (#1268)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 90b1cd6 commit fa16f33

File tree

6 files changed

+67
-10
lines changed

6 files changed

+67
-10
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@ TensorFlow model's input/output names, which can be found with [summarize graph
172172

173173
By default we preserve the image format of inputs (`nchw` or `nhwc`) as given in the TensorFlow model. If your hosts (for example windows) native format nchw and the model is written for nhwc, ```--inputs-as-nchw``` tensorflow-onnx will transpose the input. Doing so is convenient for the application and the converter in many cases can optimize the transpose away. For example ```--inputs input0:0,input1:0 --inputs-as-nchw input0:0``` assumes that images are passed into ```input0:0``` as nchw while the TensorFlow model given uses nhwc.
174174

175+
#### --ignore_default, --use_default
176+
177+
ONNX requires default values for graph inputs to be constant, while Tensorflow's PlaceholderWithDefault op accepts computed defaults. To convert such models, pass a comma-separated list of node names to the ignore_default and/or use_default flags. PlaceholderWithDefault nodes with matching names will be replaced with Placeholder or Identity ops, respectively.
178+
175179
#### --opset
176180

177181
By default we use the opset 8 to generate the graph. By specifying ```--opset``` the user can override the default to generate a graph with the desired opset. For example ```--opset 5``` would create a onnx graph that uses only ops available in opset 5. Because older opsets have in most cases fewer ops, some models might not convert on a older opset.
@@ -289,6 +293,7 @@ tf2onnx.tfonnx.process_tf_graph(tf_graph,
289293
custom_rewriter=None, extra_opset=None,
290294
shape_override=None, inputs_as_nchw=None,
291295
input_names=None, output_names=None,
296+
ignore_default=None, use_default=None,
292297
const_node_values=None):
293298
"""Convert tensorflow graph to onnx graph.
294299
Args:
@@ -304,6 +309,8 @@ tf2onnx.tfonnx.process_tf_graph(tf_graph,
304309
inputs_as_nchw: transpose inputs in list from nchw to nchw
305310
input_names: list of input node names in graph, input name format as node_name:port_id
306311
output_names: list of output node names in graph, output name format as node_name:port_id
312+
ignore_default: list of node names of PlaceholderWithDefault ops to change into Placeholder ops
313+
use_default: list of node names of PlaceholderWithDefault ops to change into Identity ops
307314
const_node_values: an optional dict mapping node names to tensor values
308315
Return:
309316
onnx graph

tests/test_backend.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
2424
from tf2onnx import constants, utils
2525
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
26-
from tf2onnx.tf_loader import is_tf2, tf_placeholder_with_default
26+
from tf2onnx.tf_loader import is_tf2, tf_placeholder_with_default, tf_placeholder
2727
from tf2onnx.onnx_opset.signal import make_dft_constant
2828

2929
# pylint: disable=missing-docstring,invalid-name,unused-argument,function-redefined,cell-var-from-loop
@@ -728,6 +728,33 @@ def func():
728728
x_feed_val = np.array([11.0, 22.0, -33.0, -44.0], dtype=np.float32).reshape((2, 2))
729729
self._run_test_case(func, [_OUTPUT], {_INPUT: x_feed_val}, as_session=True, premade_placeholders=True)
730730

731+
def test_placeholder_with_default_computed_use_default(self):
732+
x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
733+
y_val = np.array([2.0, -4.0, 6.0, -8.0], dtype=np.float32).reshape((2, 2))
734+
def func():
735+
x = tf_placeholder(tf.float32, x_val.shape, name=_TFINPUT)
736+
y = tf_placeholder(tf.float32, y_val.shape, name=_TFINPUT1)
737+
total = tf.add(x, y)
738+
z = tf_placeholder_with_default(total, x_val.shape, name=_TFINPUT2)
739+
total2 = tf.add(total, z)
740+
return tf.identity(total2, name=_TFOUTPUT)
741+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val}, as_session=True,
742+
premade_placeholders=True, process_args={'use_default': [_TFINPUT2]})
743+
744+
def test_placeholder_with_default_computed_ignore_default(self):
745+
x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
746+
y_val = np.array([2.0, -4.0, 6.0, -8.0], dtype=np.float32).reshape((2, 2))
747+
z_val = np.array([3.0, 6.0, 9.0, 10.0], dtype=np.float32).reshape((2, 2))
748+
def func():
749+
x = tf_placeholder(tf.float32, x_val.shape, name=_TFINPUT)
750+
y = tf_placeholder(tf.float32, y_val.shape, name=_TFINPUT1)
751+
total = tf.add(x, y)
752+
z = tf_placeholder_with_default(total, x_val.shape, name=_TFINPUT2)
753+
total2 = tf.add(total, z)
754+
return tf.identity(total2, name=_TFOUTPUT)
755+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val}, as_session=True,
756+
premade_placeholders=True, process_args={'ignore_default': [_TFINPUT2]})
757+
731758
@check_onnxruntime_incompatibility("Add")
732759
def test_add_bcast(self):
733760
x1_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))

tf2onnx/convert.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ def get_args():
5959
parser.add_argument("--output", help="output model file")
6060
parser.add_argument("--inputs", help="model input_names")
6161
parser.add_argument("--outputs", help="model output_names")
62+
parser.add_argument("--ignore_default", help="comma-separated list of names of PlaceholderWithDefault "
63+
"ops to change into Placeholder ops")
64+
parser.add_argument("--use_default", help="comma-separated list of names of PlaceholderWithDefault ops to "
65+
"change into Identity ops using their default value")
6266
parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain")
6367
parser.add_argument("--custom-ops", help="comma-separated map of custom ops to domains in format OpName:domain")
6468
parser.add_argument("--extra_opset", default=None,
@@ -89,6 +93,10 @@ def get_args():
8993
args.inputs, args.shape_override = utils.split_nodename_and_shape(args.inputs)
9094
if args.outputs:
9195
args.outputs = args.outputs.split(",")
96+
if args.ignore_default:
97+
args.ignore_default = args.ignore_default.split(",")
98+
if args.use_default:
99+
args.use_default = args.use_default.split(",")
92100
if args.inputs_as_nchw:
93101
args.inputs_as_nchw = args.inputs_as_nchw.split(",")
94102
if args.target:
@@ -172,6 +180,8 @@ def main():
172180
input_names=inputs,
173181
output_names=outputs,
174182
inputs_as_nchw=args.inputs_as_nchw,
183+
ignore_default=args.ignore_default,
184+
use_default=args.use_default,
175185
const_node_values=const_node_values,
176186
initialized_tables=initialized_tables)
177187

tf2onnx/graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,8 @@ def make_graph(self, doc, graph_name=None, external_tensor_storage=None):
10771077
if op.type == "PlaceholderWithDefault":
10781078
utils.make_sure(op.inputs[0] is not None, "Cannot find node with output {}".format(op.input[0]))
10791079
utils.make_sure(op.inputs[0].is_const(),
1080-
"non-const default value for PlaceholderWithDefault is not supported.")
1080+
"non-const default value for PlaceholderWithDefault node '%s' is not supported. "
1081+
"Use the --use_default or --ignore_default flags to convert this node.", op.name)
10811082
# copy the tensor value, set its name to current node's output, add as initializer
10821083
value = op.inputs[0].get_tensor_value(as_list=False)
10831084
tensor = numpy_helper.from_array(value, op.output[0])

tf2onnx/tf_utils.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def replace_placeholders_with_tables(graph_def, placeholder_to_table_info):
317317
n.attr['key_dtype'].type = key_dtype
318318
n.attr['value_dtype'].type = val_dtype
319319

320-
def tflist_to_onnx(g, shape_override, const_node_values=None):
320+
def tflist_to_onnx(g, shape_override, const_node_values=None, ignore_default=None, use_default=None):
321321
"""
322322
Convert the tf-node list into an onnx graph with minimal rewrites so
323323
we can use the onnx graph as intermediate graph.
@@ -395,11 +395,20 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
395395
else:
396396
attr[a] = get_tf_node_attr(node, a)
397397

398+
node_type = node.type
399+
input_names = [i.name for i in node.inputs]
400+
output_names = [i.name for i in node.outputs]
401+
402+
if node_type == 'PlaceholderWithDefault':
403+
if ignore_default and node.name in ignore_default:
404+
node_type = 'Placeholder'
405+
input_names = []
406+
elif use_default and node.name in use_default:
407+
node_type = 'Identity'
408+
398409
if takeit:
399410
try:
400-
input_names = [i.name for i in node.inputs]
401-
output_names = [i.name for i in node.outputs]
402-
onnx_node = helper.make_node(node.type, input_names, output_names, name=node.name, **attr)
411+
onnx_node = helper.make_node(node_type, input_names, output_names, name=node.name, **attr)
403412
onnx_nodes.append(onnx_node)
404413
except Exception as ex:
405414
logger.error("pass1 convert failed for %s, ex=%s", node, ex)
@@ -408,8 +417,8 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
408417
return onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions
409418

410419

411-
def tensorflow_to_onnx(graph, shape_override, const_node_values=None):
420+
def tensorflow_to_onnx(graph, shape_override, const_node_values=None, ignore_default=None, use_default=None):
412421
"""
413422
Load tensorflow graph and do a conversion.
414423
"""
415-
return tflist_to_onnx(graph, shape_override, const_node_values)
424+
return tflist_to_onnx(graph, shape_override, const_node_values, ignore_default, use_default)

tf2onnx/tfonnx.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,8 @@ def run_rewriters(g, funcs, continue_on_error):
367367
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
368368
opset=None, custom_op_handlers=None, custom_rewriter=None,
369369
extra_opset=None, shape_override=None, inputs_as_nchw=None,
370-
input_names=None, output_names=None, is_subgraph=False, const_node_values=None,
370+
input_names=None, output_names=None, ignore_default=None, use_default=None,
371+
is_subgraph=False, const_node_values=None,
371372
initialized_tables=None):
372373
"""Convert tensorflow graph to onnx graph.
373374
Args:
@@ -383,6 +384,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
383384
inputs_as_nchw: transpose inputs in list from nchw to nhwc
384385
input_names: list of input node names in graph, input name format as node_name:port_id
385386
output_names: list of output node names in graph, output name format as node_name:port_id
387+
ignore_default: list of node names of PlaceholderWithDefault ops to change into Placeholder ops
388+
use_default: list of node names of PlaceholderWithDefault ops to change into Identity ops using the default
386389
const_node_values: a dict returned by compress_graph_def mapping node names to tensor values
387390
initialized_tables: mapping from table shared_names to tuple of keys and values of table
388391
Return:
@@ -416,7 +419,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
416419
outputs_to_values, outputs_to_dtypes = compute_const_folding_using_tf(tf_graph, const_node_values, output_names)
417420

418421
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
419-
tensorflow_to_onnx(tf_graph, shape_override, const_node_values)
422+
tensorflow_to_onnx(tf_graph, shape_override, const_node_values, ignore_default, use_default)
420423
if not is_subgraph:
421424
# make tf2onnx internal subgraphs from the tensorflow subgraphs
422425
ordered_func = resolve_functions(tf_graph)

0 commit comments

Comments
 (0)