Skip to content

Commit 446494e

Browse files
correct input/output name parsing and placeholder shape for tfjs (#1723)
Signed-off-by: Tom Wildenhain <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent 5db12e0 commit 446494e

File tree

1 file changed

+29
-6
lines changed

1 file changed

+29
-6
lines changed

tf2onnx/tfjs_utils.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,16 @@ def get_output_shapes(node_def, input_dtypes, input_shapes, inp_consts):
178178
# The second output of merge is a scalar int indicating which input was selected
179179
return [non_none, []]
180180

181+
if node_def.op == "Placeholder":
182+
shape = None
183+
if 'shape' in node_def.attr:
184+
shape = [d.size for d in node_def.attr['shape'].shape.dim]
185+
shape = [None if d == -1 else d for d in shape]
186+
if len(shape) == 0:
187+
# According to TF docs, "If the shape has 0 dimensions, the shape is unconstrained."
188+
shape = None
189+
return [shape]
190+
181191
del node_def.input[:]
182192
node_def.name = "node"
183193
if "_class" in node_def.attr:
@@ -283,11 +293,19 @@ def graphs_from_tfjs(model_path, input_names=None, output_names=None, shape_over
283293
utils.make_sure(len(weights_data) == i, "Total weight bytes %d doesn't match read bytes %d", len(weights_data), i)
284294
topology = model['modelTopology']
285295

296+
tensors_to_rename = {}
286297
if output_names is None and 'signature' in model:
287-
output_names = list(model['signature']['outputs'].keys())
298+
outputs = model['signature'].get('outputs')
299+
inputs = model['signature'].get('inputs')
300+
if outputs is not None:
301+
output_names = [v['name'] for v in outputs.values()]
302+
tensors_to_rename.update({v['name']: k for k, v in outputs.items()})
303+
if inputs is not None:
304+
tensors_to_rename.update({v['name']: k for k, v in inputs.items()})
288305

289306
main_g = read_tfjs_graph(topology['node'], weights, None, input_names, output_names, shape_override,
290307
ignore_default, use_default)
308+
main_g.rename_tensors(tensors_to_rename)
291309
subgraphs = []
292310
funcs = sort_tfjs_functions(topology.get('library', {}).get('function', []))
293311
for func in funcs:
@@ -303,7 +321,7 @@ def read_tfjs_weight(weight, weights_data, offset):
303321
name = weight['name']
304322
count = np.product(weight['shape'], dtype=np.int64)
305323
if weight['dtype'] == 'string':
306-
num_strings = np.product(weight['shape'])
324+
num_strings = np.prod(weight['shape'], dtype=np.int64)
307325
string_list, num_bytes = read_string_weight(weights_data, offset, num_strings)
308326
np_arr = np.array(string_list).reshape(weight['shape'])
309327
return name, np_arr, num_bytes
@@ -428,10 +446,11 @@ def update_shapes(new_shapes):
428446
# This op isn't in tensorflow but can be converted to a TF op
429447
op_type = "_FusedDepthwiseConv2dNative"
430448
err_msg = "explicit_paddings for supported for _FusedDepthwiseConv2dNative"
431-
utils.make_sure(len(tf_attr['explicit_paddings']) == 0, err_msg)
432-
del tf_attr['explicit_paddings']
433-
del onnx_attr['explicit_paddings']
434-
del node_def.attr['explicit_paddings']
449+
if "explicit_paddings" in tf_attr:
450+
utils.make_sure(len(tf_attr['explicit_paddings']) == 0, err_msg)
451+
del tf_attr['explicit_paddings']
452+
del onnx_attr['explicit_paddings']
453+
del node_def.attr['explicit_paddings']
435454
node_def.op = op_type
436455

437456
input_names = [inp for inp in node.get('input', []) if not inp.startswith('^')]
@@ -465,6 +484,10 @@ def update_shapes(new_shapes):
465484
onnx_node = helper.make_node(op_type, input_names, output_names, name=node_name, **onnx_attr)
466485
onnx_nodes.append(onnx_node)
467486

487+
for inp in graph_inputs:
488+
if output_shapes[inp] is None:
489+
logger.warning("Input %s has unknown shape. Specify shape with --inputs flag.", inp)
490+
468491
dtypes = {k: tf_utils.map_tf_dtype(v) for k, v in tf_dtypes.items()}
469492
if graph_outputs is None:
470493
output_to_node = {out: node.name for node in onnx_nodes for out in node.output}

0 commit comments

Comments
 (0)