@@ -485,26 +485,42 @@ def find_bidirectional_rnns(g, ops, rnn_type):
485
485
input_id = n .input [0 ]
486
486
temp = n .inputs [0 ]
487
487
is_bw = False
488
+ is_transposed = False
488
489
if temp .type == "Transpose" :
489
490
input_id = temp .input [0 ]
490
491
temp = temp .inputs [0 ]
492
+ is_transposed = True
491
493
492
494
if utils .is_tf_reverse_op (temp ):
493
495
input_id = temp .input [0 ]
496
+ temp = temp .inputs [0 ]
494
497
is_bw = True
495
498
499
+ if (not is_transposed ) and temp .type == "Transpose" :
500
+ input_id = temp .input [0 ]
501
+ temp = temp .inputs [0 ]
502
+
503
+ input_ids = [input_id ]
504
+ if temp .type == "Identity" :
505
+ input_ids .append (temp .input [0 ])
506
+ temp = temp .inputs [0 ]
507
+ if temp .type == "Identity" :
508
+ input_ids .append (temp .input [0 ])
509
+
496
510
if is_bw :
497
511
# if output 0 is consumed and there is no reverse after the 1st output.
498
512
# it's not backward rnn.
499
- if g .find_output_consumers (n .output [0 ]) and not get_reverse_nodes_after_y_output (g , n ):
513
+ if g .find_output_consumers (n .output [0 ]) and not get_reverse_or_slice_nodes_after_y_output (g , n ):
500
514
logger .warning ("rnn %s following Reverse op isn't the part of bi-rnn." , n .name )
501
515
continue
502
516
503
- logger .debug ("find bw rnn %s" , input_id )
504
- bw_rnns [input_id ].append (n )
517
+ logger .debug ("find bw rnn %s" , input_ids )
518
+ for input_id in input_ids :
519
+ bw_rnns [input_id ].append (n )
505
520
else :
506
- logger .debug ("find fw rnn %s" , input_id )
507
- fw_rnns [input_id ].append (n )
521
+ logger .debug ("find fw rnn %s" , input_ids )
522
+ for input_id in input_ids :
523
+ fw_rnns [input_id ].append (n )
508
524
509
525
# fw_rnn and bw_rnn must share the same input
510
526
birnn_input = list (set (fw_rnns .keys ()).intersection (bw_rnns .keys ()))
@@ -554,27 +570,40 @@ def belong_to_birnn(g, fw_rnn, bw_rnn, rnn_type):
554
570
return True
555
571
556
572
557
- def get_reverse_nodes_after_y_output (g , rnn_bw ):
573
def is_tail_slice_op(node):
    """Tell whether ``node`` is a StridedSlice that picks the last element on axis 0.

    The pattern matched is ``x[-1]``: begin ``[-1]``, end ``[0]``, strides ``[1]``
    and ``shrink_axis_mask == 1``.

    Args:
        node: graph node exposing ``type``, ``inputs`` and ``get_attr``.

    Returns:
        True only when every part of the pattern is present; False otherwise,
        including when the slice operands are not constants or the
        ``shrink_axis_mask`` attribute is missing.
    """
    if node.type != 'StridedSlice':
        return False

    operands = node.inputs[1:4]
    if len(operands) != 3:
        return False
    # get_tensor_value() is only meaningful on constant nodes; a begin/end/strides
    # computed at runtime can never be proven to be the tail pattern, so treat it
    # as "no match" rather than letting the lookup fail.
    if not all(operand is not None and operand.is_const() for operand in operands):
        return False

    begin, end, strides = (operand.get_tensor_value() for operand in operands)
    if begin != [-1] or end != [0] or strides != [1]:
        return False

    shrink_axis_mask = node.get_attr('shrink_axis_mask')
    return shrink_axis_mask is not None and shrink_axis_mask.i == 1
581
+
582
+
583
def get_reverse_or_slice_nodes_after_y_output(g, rnn_bw):
    """Collect the reverse / tail-slice nodes that consume the backward rnn's Y output.

    Starting from output 0 of ``rnn_bw``, the walk requires exactly one Squeeze
    consumer with exactly one consumer of its own, steps over an optional
    Transpose and up to two chained Identity nodes, and then inspects the
    consumers reached.

    Args:
        g: graph providing ``find_output_consumers``.
        rnn_bw: the candidate backward rnn node.

    Returns:
        The list of consumers reached when every one of them is a reverse op or
        a tail slice; an empty list otherwise.
    """
    # todo: figure out a better way to remove reverse op
    squeezes = [node for node in g.find_output_consumers(rnn_bw.output[0]) if node.type == "Squeeze"]

    if len(squeezes) == 1:
        followers = g.find_output_consumers(squeezes[0].output[0])
        if len(followers) == 1:
            if followers[0].type == "Transpose":
                followers = g.find_output_consumers(followers[0].output[0])

            # look through at most two Identity nodes chained one after the other
            for _ in range(2):
                if len(followers) != 1 or followers[0].type != "Identity":
                    break
                followers = g.find_output_consumers(followers[0].output[0])

            # evaluate every candidate (no short-circuit) before deciding
            verdicts = [
                utils.is_tf_reverse_op(candidate) or is_tail_slice_op(candidate)
                for candidate in followers
            ]
            if all(verdicts):
                return followers

    logger.debug("bw y output is used followed by reverse node")
    return []
@@ -619,13 +648,28 @@ def slice_birnn_for_original_rnn_consumers(g, rnn_fw, rnn_bw, bi_rnn, rnn_output
619
648
620
649
if rnn_output_index == 0 :
621
650
axis = 1
622
- # remove reverse op for rnn_bw
623
- reverse_nodes = get_reverse_nodes_after_y_output (g , rnn_bw )
624
-
625
- for r_op in reverse_nodes :
626
- logger .debug ("remove reverse op %s" , r_op .name )
627
- g .replace_all_inputs (r_op .output [0 ], r_op .input [0 ], ops = all_nodes )
628
- to_remove .append (r_op .name )
651
+ # remove reverse(return_sequence=True) or tail slice(return_sequence=False) op for rnn_bw
652
+ reverse_or_slice_nodes = get_reverse_or_slice_nodes_after_y_output (g , rnn_bw )
653
+
654
+ for r_op in reverse_or_slice_nodes :
655
+ if utils .is_tf_reverse_op (r_op ):
656
+ logger .debug ("remove reverse op %s" , r_op .name )
657
+ g .replace_all_inputs (r_op .output [0 ], r_op .input [0 ], ops = all_nodes )
658
+ to_remove .append (r_op .name )
659
+ elif is_tail_slice_op (r_op ):
660
+ # in case of return_sequence=False
661
+ # replace output[-1:] to output[0:1]
662
+ attr = {"axes" : [0 ], "starts" : [0 ], "ends" : [1 ]}
663
+ inputs_map = {"data" : r_op .input [0 ], ** attr }
664
+ slice_node_bw = GraphBuilder (g ).make_slice (inputs_map )
665
+ all_nodes .append (g .get_node_by_output (slice_node_bw ))
666
+
667
+ inputs_map = {"data" : slice_node_bw , "axes" : [0 ]}
668
+ squeeze_node_bw = GraphBuilder (g ).make_squeeze (inputs_map )
669
+ all_nodes .append (g .get_node_by_output (squeeze_node_bw ))
670
+
671
+ g .replace_all_inputs (r_op .output [0 ], squeeze_node_bw , ops = all_nodes )
672
+ to_remove .append (r_op .name )
629
673
elif rnn_output_index in [1 , 2 ]:
630
674
axis = 0
631
675
else :
0 commit comments