Skip to content

Commit d630199

Browse files
authored
various numpy_test changes (#21020)
* various numpy_test changes * Applied review changes
1 parent dc9ed28 commit d630199

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

keras/src/ops/numpy_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,6 +1146,10 @@ def test_argmax(self):
11461146
keras.config.backend() == "openvino",
11471147
reason="OpenVINO doesn't support this change",
11481148
)
1149+
@pytest.mark.skipif(
1150+
keras.config.backend() == "mlx",
1151+
reason="Wrong results due to MLX flushing denormal numbers to 0 on GPU",
1152+
)
11491153
def test_argmax_negative_zero(self):
11501154
input_data = np.array(
11511155
[-1.0, -0.0, 1.401298464324817e-45], dtype=np.float32
@@ -1161,6 +1165,10 @@ def test_argmax_negative_zero(self):
11611165
evaluation and may change within this PR
11621166
""",
11631167
)
1168+
@pytest.mark.skipif(
1169+
keras.config.backend() == "mlx",
1170+
reason="Wrong results due to MLX flushing denormal numbers to 0 on GPU",
1171+
)
11641172
def test_argmin_negative_zero(self):
11651173
input_data = np.array(
11661174
[
@@ -5391,10 +5399,16 @@ def setUp(self):
53915399

53925400
self.jax_enable_x64 = enable_x64()
53935401
self.jax_enable_x64.__enter__()
5402+
5403+
if backend.backend() == "mlx":
5404+
self.mlx_cpu_context = backend.core.enable_float64()
5405+
self.mlx_cpu_context.__enter__()
53945406
return super().setUp()
53955407

53965408
def tearDown(self):
53975409
self.jax_enable_x64.__exit__(None, None, None)
5410+
if backend.backend() == "mlx":
5411+
self.mlx_cpu_context.__exit__(None, None, None)
53985412
return super().tearDown()
53995413

54005414
@parameterized.named_parameters(
@@ -5598,6 +5612,13 @@ def test_matmul(self, dtypes):
55985612
import jax.numpy as jnp
55995613

56005614
dtype1, dtype2 = dtypes
5615+
if (
5616+
all(dtype not in self.FLOAT_DTYPES for dtype in dtypes)
5617+
and backend.backend() == "mlx"
5618+
):
5619+
# This must be removed once mlx.core.matmul supports integer dtypes
5620+
self.skipTest("mlx doesn't support integer dot product")
5621+
56015622
# The shape of the matrix needs to meet the requirements of
56025623
# torch._int_mm to test hardware-accelerated matmul
56035624
x1 = knp.ones((17, 16), dtype=dtype1)
@@ -6620,6 +6641,13 @@ def test_dot(self, dtypes):
66206641
import jax.numpy as jnp
66216642

66226643
dtype1, dtype2 = dtypes
6644+
if (
6645+
all(dtype not in self.FLOAT_DTYPES for dtype in dtypes)
6646+
and backend.backend() == "mlx"
6647+
):
6648+
# This must be removed once mlx.core.matmul supports integer dtypes
6649+
self.skipTest("mlx doesn't support integer dot product")
6650+
66236651
x1 = knp.ones((2, 3, 4), dtype=dtype1)
66246652
x2 = knp.ones((4, 3), dtype=dtype2)
66256653
x1_jax = jnp.ones((2, 3, 4), dtype=dtype1)
@@ -6648,6 +6676,13 @@ def get_input_shapes(subscripts):
66486676
return x1_shape, x2_shape
66496677

66506678
dtype1, dtype2 = dtypes
6679+
if (
6680+
all(dtype not in self.FLOAT_DTYPES for dtype in dtypes)
6681+
and backend.backend() == "mlx"
6682+
):
6683+
# This must be removed once mlx.core.matmul supports integer dtypes
6684+
self.skipTest("mlx doesn't support integer dot product")
6685+
66516686
subscripts = "ijk,lkj->il"
66526687
x1_shape, x2_shape = get_input_shapes(subscripts)
66536688
x1 = knp.ones(x1_shape, dtype=dtype1)
@@ -8312,6 +8347,13 @@ def test_tensordot(self, dtypes):
83128347
import jax.numpy as jnp
83138348

83148349
dtype1, dtype2 = dtypes
8350+
if (
8351+
all(dtype not in self.FLOAT_DTYPES for dtype in dtypes)
8352+
and backend.backend() == "mlx"
8353+
):
8354+
# This must be removed once mlx.core.matmul supports integer dtypes
8355+
self.skipTest("mlx doesn't support integer dot product")
8356+
83158357
x1 = knp.ones((1, 1), dtype=dtype1)
83168358
x2 = knp.ones((1, 1), dtype=dtype2)
83178359
x1_jax = jnp.ones((1, 1), dtype=dtype1)
@@ -8522,6 +8564,13 @@ def test_inner(self, dtypes):
85228564
import jax.numpy as jnp
85238565

85248566
dtype1, dtype2 = dtypes
8567+
if (
8568+
all(dtype not in self.FLOAT_DTYPES for dtype in dtypes)
8569+
and backend.backend() == "mlx"
8570+
):
8571+
# This must be removed once mlx.core.matmul supports integer dtypes
8572+
self.skipTest("mlx doesn't support integer dot product")
8573+
85258574
x1 = knp.ones((1,), dtype=dtype1)
85268575
x2 = knp.ones((1,), dtype=dtype2)
85278576
x1_jax = jnp.ones((1,), dtype=dtype1)

0 commit comments

Comments
 (0)