Skip to content

Commit e68ea2e

Browse files
authored
Merge pull request #371 from nbcsm/expanddims
shape inference support ExpandDims
2 parents f9d7afb + e9569f4 commit e68ea2e

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

tf2onnx/shape_inference.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,22 @@ def infer_shape_for_node(g, node):
175175

176176
g.set_shape(node.output[0], new_shape)
177177
log.debug("set %s node [%s] with new shape %s", node.type, node.output[0], new_shape)
178+
return True
179+
180+
if node.type == "ExpandDims":
181+
# https://www.tensorflow.org/api_docs/python/tf/expand_dims
182+
input_shape = g.get_shape(node.input[0])
183+
dim_node = node.inputs[1]
184+
if input_shape is None or not dim_node.is_const():
185+
return False
178186

187+
dim = dim_node.get_tensor_value()
188+
if dim < 0:
189+
dim = dim + len(input_shape) + 1
190+
191+
new_shape = input_shape[:dim] + [1] + input_shape[dim:]
192+
g.set_shape(node.output[0], new_shape)
193+
log.debug("set [%s] with new shape %s", node.output[0], new_shape)
179194
return True
180195

181196
return False

tf2onnx/tfonnx.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -881,9 +881,10 @@ def expanddims_op(ctx, node, name, args):
881881
dim_node = node.inputs[1]
882882
if dim_node.is_const():
883883
node.type = "Unsqueeze"
884-
input_rank = len(ctx.get_shape(node.input[0]))
885884
dim = dim_node.get_tensor_value()
886-
dim = dim + input_rank + 1 if dim < 0 else dim
885+
if dim < 0:
886+
input_rank = len(ctx.get_shape(node.input[0]))
887+
dim = dim + input_rank + 1
887888
node.set_attr("axes", [dim])
888889
ctx.remove_input(node, node.input[1])
889890
return
@@ -907,9 +908,10 @@ def expanddims_op7(ctx, node, name, args):
907908
dim_node = node.inputs[1]
908909
if dim_node.is_const():
909910
node.type = "Unsqueeze"
910-
input_rank = len(ctx.get_shape(node.input[0]))
911911
dim = dim_node.get_tensor_value()
912-
dim = dim + input_rank + 1 if dim < 0 else dim
912+
if dim < 0:
913+
input_rank = len(ctx.get_shape(node.input[0]))
914+
dim = dim + input_rank + 1
913915
node.set_attr("axes", [dim])
914916
ctx.remove_input(node, node.input[1])
915917
return

0 commit comments

Comments
 (0)