@@ -149,8 +149,8 @@ def test_celu(self):
         self.assertEqual(knn.celu(x).shape, (None, 2, 3))

     def test_glu(self):
-        x = KerasTensor([None, 2, 3])
-        self.assertEqual(knn.glu(x).shape, (None, 2, 3))
+        x = KerasTensor([None, 2, 4])
+        self.assertEqual(knn.glu(x).shape, (None, 2, 2))

     def test_tanh_shrink(self):
         x = KerasTensor([None, 2, 3])
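For context, the corrected expectation follows GLU's shape rule: the op splits its input in half along the last axis and gates one half with the sigmoid of the other, so a last dimension of 4 produces 2. A minimal sketch, assuming the op is exposed as `keras.ops.glu` with the conventional split-and-gate semantics:

```python
import numpy as np
from keras import ops

x = np.ones((2, 2, 4), dtype="float32")
y = ops.glu(x)  # splits the last axis: a = x[..., :2], b = x[..., 2:]
print(y.shape)  # (2, 2, 2): the last dimension is halved

# glu(x) == a * sigmoid(b); an odd-sized split axis cannot be halved
a, b = x[..., :2], x[..., 2:]
np.testing.assert_allclose(
    ops.convert_to_numpy(y), a / (1.0 + np.exp(-b)), rtol=1e-6
)
```

This is also why the static-shape test below uses the same (1, 2, 4) -> (1, 2, 2) expectation.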
@@ -851,8 +851,8 @@ def test_celu(self):
         self.assertEqual(knn.celu(x).shape, (1, 2, 3))

     def test_glu(self):
-        x = KerasTensor([1, 2, 3])
-        self.assertEqual(knn.glu(x).shape, (1, 2, 3))
+        x = KerasTensor([1, 2, 4])
+        self.assertEqual(knn.glu(x).shape, (1, 2, 2))

     def test_tanh_shrink(self):
         x = KerasTensor([1, 2, 3])
@@ -2734,9 +2734,6 @@ def test_glu(self, dtype):
         import jax.nn as jnn
         import jax.numpy as jnp

-        if dtype == "bfloat16":
-            self.skipTest("Weirdness with numpy")
-
         x = knp.ones((2), dtype=dtype)
         x_jax = jnp.ones((2), dtype=dtype)
         expected_dtype = standardize_dtype(jnn.glu(x_jax).dtype)
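The removed skip means the dtype-parity check now runs for bfloat16 as well; the expected dtype is taken directly from `jax.nn.glu`. A minimal sketch of that reference computation, assuming a JAX install with bfloat16 support:

```python
import jax.nn as jnn
import jax.numpy as jnp

# jax.nn.glu splits the size-2 last axis into two halves of size 1
# and returns a * sigmoid(b); elementwise ops preserve dtype in JAX,
# so a bfloat16 input should come back as bfloat16.
x_jax = jnp.ones((2,), dtype="bfloat16")
print(jnn.glu(x_jax).shape)  # (1,)
print(jnn.glu(x_jax).dtype)  # bfloat16
```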