Skip to content

Commit b062368

Browse files
Fix the shape of KerasTensor for glu. (#21696)
1 parent 26d7166 commit b062368

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

keras/src/backend/numpy/nn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,14 @@ def celu(x, alpha=1.0):
164164

165165
def glu(x, axis=-1):
166166
x = convert_to_tensor(x)
167+
dtype = x.dtype
167168
if x.shape[axis] % 2 != 0:
168169
raise ValueError(
169170
"axis size must be divisible by 2. "
170171
f"Received: x.shape={x.shape} with axis={axis}"
171172
)
172173
x1, x2 = np.split(x, 2, axis)
173-
return x1 * (1 / (1 + np.exp(-x2)))
174+
return (x1 * sigmoid(x2)).astype(dtype)
174175

175176

176177
def hard_tanh(x):

keras/src/ops/nn.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,15 @@ def call(self, x):
704704
return backend.nn.glu(x, axis=self.axis)
705705

706706
def compute_output_spec(self, x):
707-
return KerasTensor(x.shape, dtype=x.dtype)
707+
output_shape = list(x.shape)
708+
if output_shape[self.axis] is not None:
709+
if output_shape[self.axis] % 2 != 0:
710+
raise ValueError(
711+
"axis size must be divisible by 2. "
712+
f"Received: x.shape={x.shape} with axis={self.axis}"
713+
)
714+
output_shape[self.axis] = output_shape[self.axis] // 2
715+
return KerasTensor(output_shape, dtype=x.dtype)
708716

709717

710718
@keras_export(["keras.ops.glu", "keras.ops.nn.glu"])

keras/src/ops/nn_test.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ def test_celu(self):
149149
self.assertEqual(knn.celu(x).shape, (None, 2, 3))
150150

151151
def test_glu(self):
152-
x = KerasTensor([None, 2, 3])
153-
self.assertEqual(knn.glu(x).shape, (None, 2, 3))
152+
x = KerasTensor([None, 2, 4])
153+
self.assertEqual(knn.glu(x).shape, (None, 2, 2))
154154

155155
def test_tanh_shrink(self):
156156
x = KerasTensor([None, 2, 3])
@@ -851,8 +851,8 @@ def test_celu(self):
851851
self.assertEqual(knn.celu(x).shape, (1, 2, 3))
852852

853853
def test_glu(self):
854-
x = KerasTensor([1, 2, 3])
855-
self.assertEqual(knn.glu(x).shape, (1, 2, 3))
854+
x = KerasTensor([1, 2, 4])
855+
self.assertEqual(knn.glu(x).shape, (1, 2, 2))
856856

857857
def test_tanh_shrink(self):
858858
x = KerasTensor([1, 2, 3])
@@ -2734,9 +2734,6 @@ def test_glu(self, dtype):
27342734
import jax.nn as jnn
27352735
import jax.numpy as jnp
27362736

2737-
if dtype == "bfloat16":
2738-
self.skipTest("Weirdness with numpy")
2739-
27402737
x = knp.ones((2), dtype=dtype)
27412738
x_jax = jnp.ones((2), dtype=dtype)
27422739
expected_dtype = standardize_dtype(jnn.glu(x_jax).dtype)

0 commit comments

Comments
 (0)