@@ -775,6 +775,30 @@ def func(x):
775
775
return tf .identity (y [0 ], name = "output" ), tf .identity (y [1 ], name = "output1" )
776
776
self .run_test_case (func , {"input:0" : x_val }, [], ["output:0" , "output1:0" ], rtol = 1e-05 , atol = 1e-06 )
777
777
778
+ @check_tf_min_version ("2.0" )
779
+ def test_keras_bilstm_recurrent_activation_is_hard_sigmoid (self ):
780
+ in_shape = [10 , 3 ]
781
+ x_val = np .random .uniform (size = [2 , 10 , 3 ]).astype (np .float32 )
782
+
783
+ model_in = tf .keras .layers .Input (tuple (in_shape ), batch_size = 2 )
784
+ x = tf .keras .layers .Bidirectional (
785
+ tf .keras .layers .LSTM (
786
+ units = 5 ,
787
+ return_sequences = True ,
788
+ return_state = True ,
789
+ kernel_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 42 ),
790
+ recurrent_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 44 ),
791
+ bias_initializer = tf .random_uniform_initializer (0.0 , 1.0 , seed = 43 ),
792
+ recurrent_activation = "hard_sigmoid" ,
793
+ )
794
+ )(model_in )
795
+ model = tf .keras .models .Model (inputs = model_in , outputs = x )
796
+
797
+ def func (x ):
798
+ y = model (x )
799
+ return tf .identity (y [0 ], name = "output" ), tf .identity (y [1 ], name = "output1" )
800
+ self .run_test_case (func , {"input:0" : x_val }, [], ["output:0" , "output1:0" ], rtol = 1e-05 , atol = 1e-06 )
801
+
778
802
@check_tf_min_version ("2.0" )
779
803
@skip_tfjs ("TFJS converts model incorrectly" )
780
804
def test_keras_lstm_sigmoid_dropout (self ):
0 commit comments