Skip to content

Commit 9aad343

Browse files
authored
Merge pull request #489 from lucienwang1009/rs_bug
fix custom rnn shape bug in opset 10
2 parents ceb6eb7 + a935a89 commit 9aad343

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

tf2onnx/rewriter/custom_rnn_rewriter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,11 @@ def _create_scan_node(self, context, scan_props, init_values):
112112
loop_outputs_dtypes = []
113113
for tensor_value_info in scan_props.state_outputs_exits + scan_props.scan_outputs_exits:
114114
if tensor_value_info.id:
115-
loop_outputs_shapes.append([1] + tensor_value_info.shape)
115+
# in opset 8, the first dim of scan output must be batch
116+
if self.g.opset == 8:
117+
loop_outputs_shapes.append([1] + tensor_value_info.shape)
118+
else:
119+
loop_outputs_shapes.append(tensor_value_info.shape)
116120
loop_outputs_dtypes.append(tensor_value_info.dtype)
117121
n = self.g.get_node_by_output(tensor_value_info.id)
118122
self.g.remove_node(n.name)

0 commit comments

Comments
 (0)