Skip to content

Commit 10e6ec5

Browse files
committed
Fix issue with shape of LSTM node
With the recent support of stacked LSTM, this issue is introduced. context.hidden_size is now a list. So for ith LSTM, we need to assign the ith value instead of entire list.
1 parent be2ccaf commit 10e6ec5

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tf2onnx/rewriter/lstm_rewriter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,9 +354,9 @@ def create_single_rnn_node(self, context, i):
354354
out_dtype = self.g.get_dtype(lstm_inputs[0])
355355

356356
lstm_node = self.g.make_node("LSTM", lstm_inputs, attr=context.attributes[i], output_count=3,
357-
shapes=[[x_seq_length, num_direction, x_batch_size, context.hidden_size],
358-
[num_direction, x_batch_size, context.hidden_size],
359-
[num_direction, x_batch_size, context.hidden_size]],
357+
shapes=[[x_seq_length, num_direction, x_batch_size, context.hidden_size[i]],
358+
[num_direction, x_batch_size, context.hidden_size[i]],
359+
[num_direction, x_batch_size, context.hidden_size[i]]],
360360
dtypes=[out_dtype, out_dtype, out_dtype], op_name_scope=context.rnn_scope)
361361
return lstm_node
362362

0 commit comments

Comments
 (0)