Skip to content

Commit c3f04d4

Browse files
committed
enhance shape inference for op "Gather"
1 parent 66801b9 commit c3f04d4

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tf2onnx/shape_inference.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,17 @@ def infer_shape_for_node(g, node):
141141
log.debug("set ConcatV2 node [%s] with new shape %s", node.output[0], new_shape)
142142
return True
143143

144+
if node.type == "Gather":
145+
# uses the follwing link to know how to infer shape of output
146+
# https://www.tensorflow.org/api_docs/python/tf/gather
147+
shape_params = g.get_shape(node.input[0])
148+
shape_indices = g.get_shape(node.input[1])
149+
axis = node.input[2].get_tensor_value()
150+
151+
shape = shape_params[:axis] + shape_indices + shape_indices[axis+1:]
152+
g.set_shape(node.output[0], shape)
153+
return True
154+
144155
if node.type in ["All", "Any", "Min"]:
145156
axis_node = node.inputs[1]
146157
axis = axis_node.get_tensor_value()

0 commit comments

Comments
 (0)