Skip to content

Commit 7e5dc5e

Browse files
committed
shape inference support ExpandDims
1 parent e3cbb3f commit 7e5dc5e

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tf2onnx/shape_inference.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,22 @@ def infer_shape_for_node(g, node):
175175

176176
g.set_shape(node.output[0], new_shape)
177177
log.debug("set %s node [%s] with new shape %s", node.type, node.output[0], new_shape)
178+
return True
179+
180+
if node.type == "ExpandDims":
181+
# https://www.tensorflow.org/api_docs/python/tf/expand_dims
182+
input_shape = g.get_shape(node.input[0])
183+
dim_node = node.inputs[1]
184+
if input_shape is None or not dim_node.is_const():
185+
return False
178186

187+
dim = dim_node.get_tensor_value()
188+
if dim < 0:
189+
dim = dim + len(input_shape) + 1
190+
191+
new_shape = input_shape[:dim] + [1] + input_shape[dim:]
192+
g.set_shape(node.output[0], new_shape)
193+
log.debug("set [%s] with new shape %s", node.output[0], new_shape)
179194
return True
180195

181196
return False

0 commit comments

Comments
 (0)