Skip to content

Commit 5e39d16

Browse files
Allow -1 to indicate unknown dimensions for shape override (#1300)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent a607275 commit 5e39d16

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ The target onnx file path.
168168

169169
#### --inputs, --outputs
170170

171-
TensorFlow model's input/output names, which can be found with [summarize graph tool](#summarize_graph). Those names typically end with ```:0```, for example ```--inputs input0:0,input1:0```. Inputs and outputs are ***not*** needed for models in saved-model format. Some models specify placeholders with unknown ranks and dims which can not be mapped to onnx. In those cases one can add the shape after the input name inside `[]`, for example `--inputs X:0[1,28,28,3]`
171+
TensorFlow model's input/output names, which can be found with [summarize graph tool](#summarize_graph). Those names typically end with ```:0```, for example ```--inputs input0:0,input1:0```. Inputs and outputs are ***not*** needed for models in saved-model format. Some models specify placeholders with unknown ranks and dims which can not be mapped to onnx. In those cases one can add the shape after the input name inside `[]`, for example `--inputs X:0[1,28,28,3]`. Use -1 to indicate unknown dimensions.
172172

173173
#### --inputs-as-nchw
174174

tf2onnx/shape_inference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ def infer_shape_for_op(op):
100100
op.outputs[0].set_shape(new_shape)
101101
logger.debug("set placeholder op [%s] with new shape %s", op.outputs[0].name, new_shape)
102102
return True
103-
logger.warning("Shape of placeholder %s is unknown, treated it as a scalar", op.name)
103+
logger.warning("Shape of placeholder '%s' is unknown, treated it as a scalar. Please use the --input flag "
104+
"and append the shape to the input name if this input is not a scalar.", op.name)
104105
op.outputs[0].set_shape([])
105106
return True
106107

tf2onnx/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ def split_nodename_and_shape(name):
107107
for i in range(1, len(splits), 3):
108108
inputs.append(splits[i])
109109
if splits[i + 1] is not None:
110-
shapes[splits[i]] = [int(n) for n in splits[i + 1][1:-1].split(",")]
110+
shape = [int(n) for n in splits[i + 1][1:-1].split(",")]
111+
shape = [n if n >= 0 else None for n in shape]
112+
shapes[splits[i]] = shape
111113
if not shapes:
112114
shapes = None
113115
return inputs, shapes

0 commit comments

Comments
 (0)