Skip to content

Commit 21cd5c0

Browse files
authored
Fix numpy tests for window ops, angle, identity, eye. (#21596)
JAX is introducing a change whereby some ops, in particular window ops, return a result of the default dtype for the current configuration (in particular based on the `enable_x64` flag). For cross-backend consistency, we want the window ops to return `floatx` regardless of how JAX is configured. Additionally, the unit tests need to be changed to not use the JAX returned dtype as the expected dtype.
1 parent 89a8676 commit 21cd5c0

File tree

2 files changed

+17
-27
lines changed

2 files changed

+17
-27
lines changed

keras/src/backend/jax/numpy.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,17 @@ def add(x1, x2):
3939

4040
def bartlett(x):
4141
x = convert_to_tensor(x)
42-
return jnp.bartlett(x)
42+
return cast(jnp.bartlett(x), config.floatx())
4343

4444

4545
def hamming(x):
4646
x = convert_to_tensor(x)
47-
return jnp.hamming(x)
47+
return cast(jnp.hamming(x), config.floatx())
4848

4949

5050
def hanning(x):
5151
x = convert_to_tensor(x)
52-
return jnp.hanning(x)
52+
return cast(jnp.hanning(x), config.floatx())
5353

5454

5555
def heaviside(x1, x2):
@@ -60,7 +60,7 @@ def heaviside(x1, x2):
6060

6161
def kaiser(x, beta):
6262
x = convert_to_tensor(x)
63-
return jnp.kaiser(x, beta)
63+
return cast(jnp.kaiser(x, beta), config.floatx())
6464

6565

6666
def bincount(x, weights=None, minlength=0, sparse=False):
@@ -497,7 +497,7 @@ def right_shift(x, y):
497497

498498
def blackman(x):
499499
x = convert_to_tensor(x)
500-
return jnp.blackman(x)
500+
return cast(jnp.blackman(x), config.floatx())
501501

502502

503503
def broadcast_to(x, shape):

keras/src/ops/numpy_test.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5765,11 +5765,8 @@ def test_add_python_types(self, dtype):
57655765

57665766
@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
57675767
def test_bartlett(self, dtype):
5768-
import jax.numpy as jnp
5769-
57705768
x = knp.ones((), dtype=dtype)
5771-
x_jax = jnp.ones((), dtype=dtype)
5772-
expected_dtype = standardize_dtype(jnp.bartlett(x_jax).dtype)
5769+
expected_dtype = backend.floatx()
57735770

57745771
self.assertEqual(
57755772
standardize_dtype(knp.bartlett(x).dtype), expected_dtype
@@ -5781,11 +5778,8 @@ def test_bartlett(self, dtype):
57815778

57825779
@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
57835780
def test_blackman(self, dtype):
5784-
import jax.numpy as jnp
5785-
57865781
x = knp.ones((), dtype=dtype)
5787-
x_jax = jnp.ones((), dtype=dtype)
5788-
expected_dtype = standardize_dtype(jnp.blackman(x_jax).dtype)
5782+
expected_dtype = backend.floatx()
57895783

57905784
self.assertEqual(
57915785
standardize_dtype(knp.blackman(x).dtype), expected_dtype
@@ -5797,11 +5791,8 @@ def test_blackman(self, dtype):
57975791

57985792
@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
57995793
def test_hamming(self, dtype):
5800-
import jax.numpy as jnp
5801-
58025794
x = knp.ones((), dtype=dtype)
5803-
x_jax = jnp.ones((), dtype=dtype)
5804-
expected_dtype = standardize_dtype(jnp.hamming(x_jax).dtype)
5795+
expected_dtype = backend.floatx()
58055796

58065797
self.assertEqual(
58075798
standardize_dtype(knp.hamming(x).dtype), expected_dtype
@@ -5813,11 +5804,8 @@ def test_hamming(self, dtype):
58135804

58145805
@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
58155806
def test_hanning(self, dtype):
5816-
import jax.numpy as jnp
5817-
58185807
x = knp.ones((), dtype=dtype)
5819-
x_jax = jnp.ones((), dtype=dtype)
5820-
expected_dtype = standardize_dtype(jnp.hanning(x_jax).dtype)
5808+
expected_dtype = backend.floatx()
58215809

58225810
self.assertEqual(
58235811
standardize_dtype(knp.hanning(x).dtype), expected_dtype
@@ -5829,13 +5817,9 @@ def test_hanning(self, dtype):
58295817

58305818
@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
58315819
def test_kaiser(self, dtype):
5832-
import jax.numpy as jnp
5833-
58345820
x = knp.ones((), dtype=dtype)
58355821
beta = knp.ones((), dtype=dtype)
5836-
x_jax = jnp.ones((), dtype=dtype)
5837-
beta_jax = jnp.ones((), dtype=dtype)
5838-
expected_dtype = standardize_dtype(jnp.kaiser(x_jax, beta_jax).dtype)
5822+
expected_dtype = backend.floatx()
58395823

58405824
self.assertEqual(
58415825
standardize_dtype(knp.kaiser(x, beta).dtype), expected_dtype
@@ -7268,13 +7252,17 @@ def test_eye(self, dtype):
72687252
import jax.numpy as jnp
72697253

72707254
expected_dtype = standardize_dtype(jnp.eye(3, dtype=dtype).dtype)
7255+
if dtype is None:
7256+
expected_dtype = backend.floatx()
72717257

72727258
self.assertEqual(
72737259
standardize_dtype(knp.eye(3, dtype=dtype).dtype),
72747260
expected_dtype,
72757261
)
72767262

72777263
expected_dtype = standardize_dtype(jnp.eye(3, 4, 1, dtype=dtype).dtype)
7264+
if dtype is None:
7265+
expected_dtype = backend.floatx()
72787266

72797267
self.assertEqual(
72807268
standardize_dtype(knp.eye(3, 4, k=1, dtype=dtype).dtype),
@@ -7506,6 +7494,8 @@ def test_identity(self, dtype):
75067494
import jax.numpy as jnp
75077495

75087496
expected_dtype = standardize_dtype(jnp.identity(3, dtype=dtype).dtype)
7497+
if dtype is None:
7498+
expected_dtype = backend.floatx()
75097499

75107500
self.assertEqual(
75117501
standardize_dtype(knp.identity(3, dtype=dtype).dtype),
@@ -9141,7 +9131,7 @@ def test_angle(self, dtype):
91419131
x = knp.ones((1,), dtype=dtype)
91429132
x_jax = jnp.ones((1,), dtype=dtype)
91439133
expected_dtype = standardize_dtype(jnp.angle(x_jax).dtype)
9144-
if dtype == "int64":
9134+
if dtype == "bool" or is_int_dtype(dtype):
91459135
expected_dtype = backend.floatx()
91469136

91479137
self.assertEqual(standardize_dtype(knp.angle(x).dtype), expected_dtype)

0 commit comments

Comments
 (0)