Skip to content

Commit d3cc8e9

Browse files
authored
added reshape to fix shape mismatch between flatten output and model … (#125)
* added reshape to fix shape mismatch between flatten output and model final output * remove flatten before reshape
1 parent 708d309 commit d3cc8e9

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

onnxmltools/convert/keras/operator_converters/Bidirectional.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,9 @@ def convert_bidirectional(scope, operator, container):
209209
transposed_y_name = scope.get_unique_variable_name(operator.full_name + '_Y_transposed')
210210
apply_transpose(scope, lstm_y_name_fixed, transposed_y_name, container, perm=[0, 2, 1, 3])
211211

212-
# Flatten ONNX (T, N, D, C') into (T, N, D * C')
213-
container.add_node('Flatten', transposed_y_name, operator.outputs[0].full_name,
214-
name=scope.get_unique_variable_name('Flatten'), axis=2)
212+
# Change shape (T, N, D, C') to (N, T, D * C') to meet Keras spec
213+
apply_reshape(scope, transposed_y_name, operator.outputs[0].full_name, container,
214+
desired_shape=[-1, seq_length, 2 * hidden_size])
215215
else:
216216
# If merge_mode=None, two tensors should be generated. The first/second tensor is the output of
217217
# forward/backward pass.

0 commit comments

Comments
 (0)