File tree Expand file tree Collapse file tree 1 file changed +11
-0
lines changed Expand file tree Collapse file tree 1 file changed +11
-0
lines changed Original file line number Diff line number Diff line change @@ -141,6 +141,17 @@ def infer_shape_for_node(g, node):
141
141
log .debug ("set ConcatV2 node [%s] with new shape %s" , node .output [0 ], new_shape )
142
142
return True
143
143
144
+ if node .type == "Gather" :
145
+ # uses the follwing link to know how to infer shape of output
146
+ # https://www.tensorflow.org/api_docs/python/tf/gather
147
+ shape_params = g .get_shape (node .input [0 ])
148
+ shape_indices = g .get_shape (node .input [1 ])
149
+ axis = node .input [2 ].get_tensor_value ()
150
+
151
+ shape = shape_params [:axis ] + shape_indices + shape_indices [axis + 1 :]
152
+ g .set_shape (node .output [0 ], shape )
153
+ return True
154
+
144
155
if node .type in ["All" , "Any" , "Min" ]:
145
156
axis_node = node .inputs [1 ]
146
157
axis = axis_node .get_tensor_value ()
You can’t perform that action at this time.
0 commit comments