Skip to content

Commit 98fdd0e

Browse files
authored
skip float64 for mlx (#19576)
1 parent 6424c61 commit 98fdd0e

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

keras/src/ops/nn_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1970,6 +1970,9 @@ def tearDown(self) -> None:
19701970

19711971
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
19721972
def test_elu(self, dtype):
1973+
if backend.backend() == "mlx" and dtype == "float64":
1974+
pytest.skip("Backend does not support float64")
1975+
19731976
import jax.nn as jnn
19741977
import jax.numpy as jnp
19751978

@@ -1988,6 +1991,9 @@ def test_elu(self, dtype):
19881991

19891992
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
19901993
def test_gelu(self, dtype):
1994+
if backend.backend() == "mlx" and dtype == "float64":
1995+
pytest.skip("Backend does not support float64")
1996+
19911997
import jax.nn as jnn
19921998
import jax.numpy as jnp
19931999

@@ -2019,6 +2025,9 @@ def test_gelu(self, dtype):
20192025

20202026
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
20212027
def test_hard_sigmoid(self, dtype):
2028+
if backend.backend() == "mlx" and dtype == "float64":
2029+
pytest.skip("Backend does not support float64")
2030+
20222031
import jax.nn as jnn
20232032
import jax.numpy as jnp
20242033

@@ -2037,6 +2046,9 @@ def test_hard_sigmoid(self, dtype):
20372046

20382047
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
20392048
def test_hard_silu(self, dtype):
2049+
if backend.backend() == "mlx" and dtype == "float64":
2050+
pytest.skip("Backend does not support float64")
2051+
20402052
import jax.nn as jnn
20412053
import jax.numpy as jnp
20422054

@@ -2055,6 +2067,9 @@ def test_hard_silu(self, dtype):
20552067

20562068
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
20572069
def test_leaky_relu(self, dtype):
2070+
if backend.backend() == "mlx" and dtype == "float64":
2071+
pytest.skip("Backend does not support float64")
2072+
20582073
import jax.nn as jnn
20592074
import jax.numpy as jnp
20602075

@@ -2073,6 +2088,9 @@ def test_leaky_relu(self, dtype):
20732088

20742089
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
20752090
def test_log_sigmoid(self, dtype):
2091+
if backend.backend() == "mlx" and dtype == "float64":
2092+
pytest.skip("Backend does not support float64")
2093+
20762094
import jax.nn as jnn
20772095
import jax.numpy as jnp
20782096

@@ -2091,6 +2109,9 @@ def test_log_sigmoid(self, dtype):
20912109

20922110
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
20932111
def test_log_softmax(self, dtype):
2112+
if backend.backend() == "mlx" and dtype == "float64":
2113+
pytest.skip("Backend does not support float64")
2114+
20942115
import jax.nn as jnn
20952116
import jax.numpy as jnp
20962117

@@ -2109,6 +2130,9 @@ def test_log_softmax(self, dtype):
21092130

21102131
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
21112132
def test_relu(self, dtype):
2133+
if backend.backend() == "mlx" and dtype == "float64":
2134+
pytest.skip("Backend does not support float64")
2135+
21122136
import jax.nn as jnn
21132137
import jax.numpy as jnp
21142138

@@ -2127,6 +2151,9 @@ def test_relu(self, dtype):
21272151

21282152
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
21292153
def test_relu6(self, dtype):
2154+
if backend.backend() == "mlx" and dtype == "float64":
2155+
pytest.skip("Backend does not support float64")
2156+
21302157
import jax.nn as jnn
21312158
import jax.numpy as jnp
21322159

@@ -2145,6 +2172,9 @@ def test_relu6(self, dtype):
21452172

21462173
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
21472174
def test_selu(self, dtype):
2175+
if backend.backend() == "mlx" and dtype == "float64":
2176+
pytest.skip("Backend does not support float64")
2177+
21482178
import jax.nn as jnn
21492179
import jax.numpy as jnp
21502180

@@ -2163,6 +2193,9 @@ def test_selu(self, dtype):
21632193

21642194
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
21652195
def test_sigmoid(self, dtype):
2196+
if backend.backend() == "mlx" and dtype == "float64":
2197+
pytest.skip("Backend does not support float64")
2198+
21662199
import jax.nn as jnn
21672200
import jax.numpy as jnp
21682201

@@ -2181,6 +2214,9 @@ def test_sigmoid(self, dtype):
21812214

21822215
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
21832216
def test_silu(self, dtype):
2217+
if backend.backend() == "mlx" and dtype == "float64":
2218+
pytest.skip("Backend does not support float64")
2219+
21842220
import jax.nn as jnn
21852221
import jax.numpy as jnp
21862222

@@ -2199,6 +2235,9 @@ def test_silu(self, dtype):
21992235

22002236
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
22012237
def test_softplus(self, dtype):
2238+
if backend.backend() == "mlx" and dtype == "float64":
2239+
pytest.skip("Backend does not support float64")
2240+
22022241
import jax.nn as jnn
22032242
import jax.numpy as jnp
22042243

@@ -2217,6 +2256,9 @@ def test_softplus(self, dtype):
22172256

22182257
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
22192258
def test_softmax(self, dtype):
2259+
if backend.backend() == "mlx" and dtype == "float64":
2260+
pytest.skip("Backend does not support float64")
2261+
22202262
import jax.nn as jnn
22212263
import jax.numpy as jnp
22222264

@@ -2235,6 +2277,9 @@ def test_softmax(self, dtype):
22352277

22362278
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))
22372279
def test_softsign(self, dtype):
2280+
if backend.backend() == "mlx" and dtype == "float64":
2281+
pytest.skip("Backend does not support float64")
2282+
22382283
import jax.nn as jnn
22392284
import jax.numpy as jnp
22402285

0 commit comments

Comments
 (0)