@@ -674,7 +674,8 @@ def func(x):
674
674
feed_dict = {"input_1:0" : x_val }
675
675
input_names_with_port = ["input_1:0" ]
676
676
output_names_with_port = ["output_1:0" , "cell_state_1:0" , "output_2:0" , "cell_state_2:0" ]
677
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 )
677
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 ,
678
+ require_lstm_count = 2 )
678
679
679
680
@check_opset_after_tf_version ("1.15" , 10 , "might need ReverseV2" )
680
681
@skip_tf_versions ("2.1" , "Bug in TF 2.1" )
@@ -721,7 +722,8 @@ def func(x, y1, y2):
721
722
feed_dict = {"input_1:0" : x_val , "input_2:0" : seq_len_val , "input_3:0" : seq_len_val }
722
723
input_names_with_port = ["input_1:0" , "input_2:0" , "input_3:0" ]
723
724
output_names_with_port = ["output_1:0" , "cell_state_1:0" , "output_2:0" , "cell_state_2:0" ]
724
- self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 )
725
+ self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port , rtol = 1e-3 , atol = 1e-06 ,
726
+ require_lstm_count = 2 )
725
727
726
728
727
729
if __name__ == '__main__' :
0 commit comments