Skip to content

Commit 1e5e377

Browse files
committed
fix bug of shape_inference when node is TensorArrayGatherV3
TensorArrayGatherV3 node should get shape from TensorArrayV3 instead of input "flow_in", shape of TensorArrayV3 can be got from "value" which is the input of TensorArrayWriteV3.
1 parent 4e7b167 commit 1e5e377

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

tf2onnx/shape_inference.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,13 +221,32 @@ def infer_output_shapes_with_partial_inputs(g, node):
221221
return True
222222

223223
if node.type == "TensorArrayGatherV3":
224-
# TensorArrayGatherV3's output: all of the elements in the TensorArray,
225-
# concatenated along a new axis (the new dimension 0)
226-
flow_in_node = node.inputs[2]
227-
if flow_in_node.type != "Exit":
224+
# TensorArrayGatherV3's output: all of the elem in the TensorArray,
225+
# concatenated along a new axis (the new dimension 0), so shape of TensorArray should be found first.
226+
# And TensorArrayWrite will write elem to TensorArray, so shape of TensorArray can be got from TensorArrayWrite
227+
# so the process is: first find TensorArrayWrite and then get TensorArray's shape,
228+
# and finally add one dim to the shape is shape of TensorArrayGather
229+
230+
handle_node = node.inputs[0]
231+
if handle_node.type != "TensorArrayV3":
228232
return False
229233

230-
shape = g.get_shape(flow_in_node.output[0])
234+
# find TensorArrayWrite
235+
tensor_array_consumers = g.find_output_consumers(handle_node.output[0])
236+
tensor_array_write_found = False
237+
for i in tensor_array_consumers:
238+
if tensor_array_write_found:
239+
break
240+
consumer_nodes = g.find_output_consumers(i.output[0])
241+
for j in consumer_nodes:
242+
if i.type == "Enter" and j.type == "TensorArrayWriteV3":
243+
tensor_array_write_node = j
244+
tensor_array_write_found = True
245+
break
246+
# get TensorArray shape from input tensor of the found TensorArrayWrite node
247+
value_node = tensor_array_write_node.inputs[2]
248+
shape = g.get_shape(value_node.output[0])
249+
# update TensorArray's shape info
231250
if shape is not None:
232251
new_shape = [-1] + shape
233252
g.set_shape(node.output[0], new_shape)

0 commit comments

Comments
 (0)