Commit de7e5af
kota-row and fatcat-z authored

Fix keras bidirectional merge failures (#1869)

* add test to check keras bidirectional recurrent is merged
* fix keras bidirectional merge failures; support the cases below:
  - there are one or two Identity layers between input/output and RNN
  - Transpose-Reverse-backward (previously, only Reverse-Transpose-backward was supported)
  - return_sequences=False with no Reverse after the backward
* apply review comments for the Bidirectional fix:
  - consecutive Identity checks changed to nested if
  - update comment for the remove-Reverse-or-tail-slice op

Signed-off-by: Kotaro Yamamoto <[email protected]>
Co-authored-by: Jay Zhang <[email protected]>
1 parent 0e3720c commit de7e5af
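For context, a minimal sketch (not from the PR) of the kind of model this commit targets: a Keras Bidirectional RNN with return_sequences=False, one of the patterns that previously failed to merge into a single bidirectional ONNX node. It assumes tf2onnx's public tf2onnx.convert.from_keras API; layer sizes and the opset are arbitrary.

import tensorflow as tf
import tf2onnx

# Bidirectional LSTM returning only the final step (return_sequences=False).
# In the lowered TF graph, the backward half then ends in a tail StridedSlice
# rather than a Reverse op, which the rewriter previously did not recognize.
model = tf.keras.Sequential([
    tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(16, return_sequences=False),
        input_shape=(10, 8)),                       # (seq_len, features)
    tf.keras.layers.Dense(4, activation='softmax'),
])

onnx_model, _ = tf2onnx.convert.from_keras(model, opset=13)

# With the fix, the two Keras directions should collapse into a single LSTM
# node whose 'direction' attribute is the string b'bidirectional'.
directions = [attr.s for node in onnx_model.graph.node if node.op_type == 'LSTM'
              for attr in node.attribute if attr.name == 'direction']
assert directions == [b'bidirectional']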

File tree

3 files changed: +88 −27 lines

tests/keras2onnx_unit_tests/test_layers.py

Lines changed: 9 additions & 1 deletion

@@ -7,7 +7,7 @@
 from mock_keras2onnx.proto import (keras, is_tf_keras,
                                    is_tensorflow_older_than, is_tensorflow_later_than,
                                    is_keras_older_than, is_keras_later_than)
-from test_utils import no_loops_in_tf2
+from test_utils import no_loops_in_tf2, all_recurrents_should_bidirectional

 K = keras.backend
 Activation = keras.layers.Activation
@@ -2073,6 +2073,7 @@ def test_bidirectional(runner, rnn_class, return_sequences):
     model.add(Activation('softmax'))
     model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
     onnx_model = convert_keras(model, 'test', target_opset=op_version)
+    assert all_recurrents_should_bidirectional(onnx_model)
     for batch in batch_list:
         data = np.random.rand(batch, sequence_len, input_dim).astype(np.float32)
         expected = model.predict(data)
@@ -2084,6 +2085,7 @@ def test_bidirectional(runner, rnn_class, return_sequences):
                            input_shape=(5, 10), merge_mode=merge_mode)(sub_input1)
     keras_model = keras.Model(inputs=sub_input1, outputs=sub_mapped1)
     onnx_model = convert_keras(keras_model, 'test_2', target_opset=op_version)
+    assert all_recurrents_should_bidirectional(onnx_model)
     for batch in batch_list:
         data = np.random.rand(batch, sequence_len, input_dim).astype(np.float32)
         expected = keras_model.predict(data)
@@ -2102,6 +2104,7 @@ def test_bidirectional_with_bias(runner, rnn_class):
     # Test with the default bias
     expected = model.predict(x)
     onnx_model = convert_keras(model, model.name)
+    assert all_recurrents_should_bidirectional(onnx_model)
     assert runner(onnx_model.graph.name, onnx_model, x, expected)

     # Set bias values to random floats
@@ -2114,6 +2117,7 @@ def test_bidirectional_with_bias(runner, rnn_class):
     # Test with random bias
     expected = model.predict(x)
     onnx_model = convert_keras(model, model.name)
+    assert all_recurrents_should_bidirectional(onnx_model)
     assert runner(onnx_model.graph.name, onnx_model, x, expected)

@@ -2141,6 +2145,7 @@ def test_bidirectional_time_major_true(runner, rnn_class):

     expected = model.predict(x)
     onnx_model = convert_keras(model, model.name)
+    assert all_recurrents_should_bidirectional(onnx_model)
     assert runner(onnx_model.graph.name, onnx_model, x, expected)

@@ -2155,6 +2160,7 @@ def test_bidirectional_with_initial_states(runner, rnn_class):

     expected = model.predict(inputs)
     onnx_model = convert_keras(model, model.name)
+    assert all_recurrents_should_bidirectional(onnx_model)
     assert runner(onnx_model.graph.name, onnx_model, inputs, expected)

     input2 = Input(shape=(None, 5))
@@ -2165,6 +2171,7 @@ def test_bidirectional_with_initial_states(runner, rnn_class):

     expected = model.predict(inputs)
     onnx_model = convert_keras(model, model.name)
+    assert all_recurrents_should_bidirectional(onnx_model)
     assert runner(onnx_model.graph.name, onnx_model, inputs, expected, atol=1e-5)

@@ -2180,6 +2187,7 @@ def test_bidirectional_seqlen_none(runner, rnn_class):
     model.add(Dense(44))

     onnx_model = convert_keras(model, model.name)
+    assert all_recurrents_should_bidirectional(onnx_model)
     for batch in [1, 4]:
         x = np.random.rand(batch, 50).astype(np.float32)
         expected = model.predict(x)

tests/keras2onnx_unit_tests/test_utils.py

Lines changed: 9 additions & 0 deletions

@@ -3,6 +3,7 @@
 import os
 import sys
 import onnx
+from onnx import helper
 import numpy as np
 import mock_keras2onnx
 from mock_keras2onnx.proto import keras, is_keras_older_than
@@ -161,6 +162,14 @@ def no_loops_in_tf2(onnx_model):
     return not is_tf2 or all(n.op_type != "Loop" for n in onnx_model.graph.node)


+def all_recurrents_should_bidirectional(onnx_model):
+    return all([
+        helper.get_attribute_value(attr) == b'bidirectional'
+        for node in onnx_model.graph.node if node.op_type in ['GRU', 'LSTM', 'RNN']
+        for attr in node.attribute if attr.name == 'direction'
+    ])
+
+
 def run_onnx_runtime(case_name, onnx_model, data, expected, model_files, rtol=1.e-3, atol=1.e-6,
                      compare_perf=False, enable_profiling=False):
     if not os.path.exists(tmp_path):
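The new helper reads the direction attribute off every recurrent node in the graph proto; ONNX string attributes come back as bytes, hence the b'bidirectional' comparison. A standalone sketch of what it inspects, using a hand-built node (names and sizes are illustrative):

from onnx import helper

# An LSTM node with an explicit direction attribute, mirroring what a
# successful bidirectional merge emits.
lstm = helper.make_node(
    'LSTM', inputs=['X', 'W', 'R'], outputs=['Y', 'Y_h', 'Y_c'],
    direction='bidirectional', hidden_size=16)

attr = [a for a in lstm.attribute if a.name == 'direction'][0]
assert helper.get_attribute_value(attr) == b'bidirectional'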

tf2onnx/rewriter/rnn_utils.py

Lines changed: 70 additions & 26 deletions

@@ -485,26 +485,42 @@ def find_bidirectional_rnns(g, ops, rnn_type):
         input_id = n.input[0]
         temp = n.inputs[0]
         is_bw = False
+        is_transposed = False
         if temp.type == "Transpose":
             input_id = temp.input[0]
             temp = temp.inputs[0]
+            is_transposed = True

         if utils.is_tf_reverse_op(temp):
             input_id = temp.input[0]
+            temp = temp.inputs[0]
             is_bw = True

+        if (not is_transposed) and temp.type == "Transpose":
+            input_id = temp.input[0]
+            temp = temp.inputs[0]
+
+        input_ids = [input_id]
+        if temp.type == "Identity":
+            input_ids.append(temp.input[0])
+            temp = temp.inputs[0]
+            if temp.type == "Identity":
+                input_ids.append(temp.input[0])
+
         if is_bw:
             # if output 0 is consumed and there is no reverse after the 1st output.
             # it's not backward rnn.
-            if g.find_output_consumers(n.output[0]) and not get_reverse_nodes_after_y_output(g, n):
+            if g.find_output_consumers(n.output[0]) and not get_reverse_or_slice_nodes_after_y_output(g, n):
                 logger.warning("rnn %s following Reverse op isn't the part of bi-rnn.", n.name)
                 continue

-            logger.debug("find bw rnn %s", input_id)
-            bw_rnns[input_id].append(n)
+            logger.debug("find bw rnn %s", input_ids)
+            for input_id in input_ids:
+                bw_rnns[input_id].append(n)
         else:
-            logger.debug("find fw rnn %s", input_id)
-            fw_rnns[input_id].append(n)
+            logger.debug("find fw rnn %s", input_ids)
+            for input_id in input_ids:
+                fw_rnns[input_id].append(n)

     # fw_rnn and bw_rnn must share the same input
     birnn_input = list(set(fw_rnns.keys()).intersection(bw_rnns.keys()))
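The reason the Identity walk above records every tensor id along the chain: the forward and backward RNNs may reach the shared input through a different number of Identity ops, so only the full set of ids guarantees the two maps intersect. A toy sketch of that matching step, with hypothetical tensor and node names:

from collections import defaultdict

fw_rnns = defaultdict(list)
bw_rnns = defaultdict(list)

# Hypothetical traces: the forward RNN reads the graph input directly, while
# the backward RNN sees it through one Identity op, so its chain records both
# the Identity's output and the underlying input tensor.
for input_id in ['input:0']:                    # forward chain
    fw_rnns[input_id].append('lstm_fw')
for input_id in ['identity:0', 'input:0']:      # backward chain, via Identity
    bw_rnns[input_id].append('lstm_bw')

# fw and bw RNNs must share an input to qualify as a bi-rnn pair; walking the
# Identity chain makes 'input:0' the common key.
birnn_input = list(set(fw_rnns.keys()).intersection(bw_rnns.keys()))
assert birnn_input == ['input:0']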
@@ -554,27 +570,40 @@ def belong_to_birnn(g, fw_rnn, bw_rnn, rnn_type):
     return True


-def get_reverse_nodes_after_y_output(g, rnn_bw):
+def is_tail_slice_op(node):
+    return (
+        node.type == 'StridedSlice' and
+        node.inputs[1].get_tensor_value() == [-1] and
+        node.inputs[2].get_tensor_value() == [0] and
+        node.inputs[3].get_tensor_value() == [1] and
+        node.get_attr('shrink_axis_mask').i == 1
+    )
+
+
+def get_reverse_or_slice_nodes_after_y_output(g, rnn_bw):
     bw_consumers = g.find_output_consumers(rnn_bw.output[0])

     # todo: figure out a better way to remove reverse op
     squeeze_nodes = [c for c in bw_consumers if c.type == "Squeeze"]
     s_cnt = len(squeeze_nodes)
     if s_cnt == 1:
         s = squeeze_nodes[0]
-        trans_nodes = g.find_output_consumers(s.output[0])
-        if len(trans_nodes) == 1:
-            if trans_nodes[0].type == "Transpose":
-                reverse_nodes = g.find_output_consumers(trans_nodes[0].output[0])
-            elif utils.is_tf_reverse_op(trans_nodes[0]):
-                reverse_nodes = trans_nodes
-            else:
-                logger.debug("not found reverse op, unexpected")
-                return []
-
-            are_all_reverse = all([utils.is_tf_reverse_op(r_op) for r_op in reverse_nodes])
-            if are_all_reverse:
-                return reverse_nodes
+        reverse_or_slice_nodes = g.find_output_consumers(s.output[0])
+        if len(reverse_or_slice_nodes) == 1:
+            if reverse_or_slice_nodes[0].type == "Transpose":
+                reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])
+
+        if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
+            reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])
+            if len(reverse_or_slice_nodes) == 1 and reverse_or_slice_nodes[0].type == "Identity":
+                reverse_or_slice_nodes = g.find_output_consumers(reverse_or_slice_nodes[0].output[0])
+
+        are_all_reverse_or_slice = all([
+            utils.is_tf_reverse_op(r_op) or is_tail_slice_op(r_op)
+            for r_op in reverse_or_slice_nodes
+        ])
+        if are_all_reverse_or_slice:
+            return reverse_or_slice_nodes

     logger.debug("bw y output is used followed by reverse node")
     return []
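is_tail_slice_op matches the exact StridedSlice that TF emits for a Python x[-1] index: begin [-1], end [0], strides [1], with the shrink_axis_mask bit set so the sliced axis is dropped. This is how the lowered graph takes the last time step when return_sequences=False. A quick way to observe that lowering (assuming TF 2.x; names are illustrative):

import tensorflow as tf

@tf.function
def take_last(x):
    return x[-1]    # last time step, as Keras takes it for return_sequences=False

graph = take_last.get_concrete_function(
    tf.TensorSpec([None, 4], tf.float32)).graph
slices = [n for n in graph.as_graph_def().node if n.op == 'StridedSlice']

# The single-element index is lowered to begin=[-1], end=[0], strides=[1],
# with the shrink_axis_mask bit set for axis 0.
assert slices and slices[0].attr['shrink_axis_mask'].i == 1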
@@ -619,13 +648,28 @@ def slice_birnn_for_original_rnn_consumers(g, rnn_fw, rnn_bw, bi_rnn, rnn_output

     if rnn_output_index == 0:
         axis = 1
-        # remove reverse op for rnn_bw
-        reverse_nodes = get_reverse_nodes_after_y_output(g, rnn_bw)
-
-        for r_op in reverse_nodes:
-            logger.debug("remove reverse op %s", r_op.name)
-            g.replace_all_inputs(r_op.output[0], r_op.input[0], ops=all_nodes)
-            to_remove.append(r_op.name)
+        # remove reverse(return_sequence=True) or tail slice(return_sequence=False) op for rnn_bw
+        reverse_or_slice_nodes = get_reverse_or_slice_nodes_after_y_output(g, rnn_bw)
+
+        for r_op in reverse_or_slice_nodes:
+            if utils.is_tf_reverse_op(r_op):
+                logger.debug("remove reverse op %s", r_op.name)
+                g.replace_all_inputs(r_op.output[0], r_op.input[0], ops=all_nodes)
+                to_remove.append(r_op.name)
+            elif is_tail_slice_op(r_op):
+                # in case of return_sequence=False
+                # replace output[-1:] to output[0:1]
+                attr = {"axes": [0], "starts": [0], "ends": [1]}
+                inputs_map = {"data": r_op.input[0], **attr}
+                slice_node_bw = GraphBuilder(g).make_slice(inputs_map)
+                all_nodes.append(g.get_node_by_output(slice_node_bw))
+
+                inputs_map = {"data": slice_node_bw, "axes": [0]}
+                squeeze_node_bw = GraphBuilder(g).make_squeeze(inputs_map)
+                all_nodes.append(g.get_node_by_output(squeeze_node_bw))
+
+                g.replace_all_inputs(r_op.output[0], squeeze_node_bw, ops=all_nodes)
+                to_remove.append(r_op.name)
     elif rnn_output_index in [1, 2]:
         axis = 0
     else:
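Why the tail slice turns into output[0:1] plus a Squeeze rather than being dropped: the TF backward RNN runs over a reversed sequence, so its last emitted step (index -1) is its final state, while the merged ONNX RNN stores the reverse direction in original time order, putting that same final state at time index 0. A pure-indexing sketch of the equivalence (stand-in arrays, not real RNN outputs):

import numpy as np

seq_len, batch, hidden = 5, 2, 3
# Stand-in for the TF backward RNN's per-step output, in run (reversed) order.
y_bw = np.random.rand(seq_len, batch, hidden).astype(np.float32)

tail = y_bw[-1]    # what the tail StridedSlice extracted before the merge

# The merged bidirectional ONNX RNN reports the reverse direction in original
# time order, i.e. the TF output flipped along the time axis.
y_rev = y_bw[::-1]
head = np.squeeze(y_rev[0:1], axis=0)    # the rewritten Slice(0:1) + Squeeze

assert np.array_equal(tail, head)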
