@@ -1970,6 +1970,9 @@ def tearDown(self) -> None:
1970
1970
1971
1971
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
1972
1972
def test_elu (self , dtype ):
1973
+ if backend .backend () == "mlx" and dtype == "float64" :
1974
+ pytest .skip ("Backend does not support float64" )
1975
+
1973
1976
import jax .nn as jnn
1974
1977
import jax .numpy as jnp
1975
1978
@@ -1988,6 +1991,9 @@ def test_elu(self, dtype):
1988
1991
1989
1992
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
1990
1993
def test_gelu (self , dtype ):
1994
+ if backend .backend () == "mlx" and dtype == "float64" :
1995
+ pytest .skip ("Backend does not support float64" )
1996
+
1991
1997
import jax .nn as jnn
1992
1998
import jax .numpy as jnp
1993
1999
@@ -2019,6 +2025,9 @@ def test_gelu(self, dtype):
2019
2025
2020
2026
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
2021
2027
def test_hard_sigmoid (self , dtype ):
2028
+ if backend .backend () == "mlx" and dtype == "float64" :
2029
+ pytest .skip ("Backend does not support float64" )
2030
+
2022
2031
import jax .nn as jnn
2023
2032
import jax .numpy as jnp
2024
2033
@@ -2037,6 +2046,9 @@ def test_hard_sigmoid(self, dtype):
2037
2046
2038
2047
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
2039
2048
def test_hard_silu (self , dtype ):
2049
+ if backend .backend () == "mlx" and dtype == "float64" :
2050
+ pytest .skip ("Backend does not support float64" )
2051
+
2040
2052
import jax .nn as jnn
2041
2053
import jax .numpy as jnp
2042
2054
@@ -2055,6 +2067,9 @@ def test_hard_silu(self, dtype):
2055
2067
2056
2068
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
2057
2069
def test_leaky_relu (self , dtype ):
2070
+ if backend .backend () == "mlx" and dtype == "float64" :
2071
+ pytest .skip ("Backend does not support float64" )
2072
+
2058
2073
import jax .nn as jnn
2059
2074
import jax .numpy as jnp
2060
2075
@@ -2073,6 +2088,9 @@ def test_leaky_relu(self, dtype):
2073
2088
2074
2089
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
2075
2090
def test_log_sigmoid (self , dtype ):
2091
+ if backend .backend () == "mlx" and dtype == "float64" :
2092
+ pytest .skip ("Backend does not support float64" )
2093
+
2076
2094
import jax .nn as jnn
2077
2095
import jax .numpy as jnp
2078
2096
@@ -2091,6 +2109,9 @@ def test_log_sigmoid(self, dtype):
2091
2109
2092
2110
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
2093
2111
def test_log_softmax (self , dtype ):
2112
+ if backend .backend () == "mlx" and dtype == "float64" :
2113
+ pytest .skip ("Backend does not support float64" )
2114
+
2094
2115
import jax .nn as jnn
2095
2116
import jax .numpy as jnp
2096
2117
@@ -2109,6 +2130,9 @@ def test_log_softmax(self, dtype):
2109
2130
2110
2131
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
2111
2132
def test_relu (self , dtype ):
2133
+ if backend .backend () == "mlx" and dtype == "float64" :
2134
+ pytest .skip ("Backend does not support float64" )
2135
+
2112
2136
import jax .nn as jnn
2113
2137
import jax .numpy as jnp
2114
2138
@@ -2127,6 +2151,9 @@ def test_relu(self, dtype):
2127
2151
2128
2152
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
2129
2153
def test_relu6 (self , dtype ):
2154
+ if backend .backend () == "mlx" and dtype == "float64" :
2155
+ pytest .skip ("Backend does not support float64" )
2156
+
2130
2157
import jax .nn as jnn
2131
2158
import jax .numpy as jnp
2132
2159
@@ -2145,6 +2172,9 @@ def test_relu6(self, dtype):
2145
2172
2146
2173
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
2147
2174
def test_selu (self , dtype ):
2175
+ if backend .backend () == "mlx" and dtype == "float64" :
2176
+ pytest .skip ("Backend does not support float64" )
2177
+
2148
2178
import jax .nn as jnn
2149
2179
import jax .numpy as jnp
2150
2180
@@ -2163,6 +2193,9 @@ def test_selu(self, dtype):
2163
2193
2164
2194
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
2165
2195
def test_sigmoid (self , dtype ):
2196
+ if backend .backend () == "mlx" and dtype == "float64" :
2197
+ pytest .skip ("Backend does not support float64" )
2198
+
2166
2199
import jax .nn as jnn
2167
2200
import jax .numpy as jnp
2168
2201
@@ -2181,6 +2214,9 @@ def test_sigmoid(self, dtype):
2181
2214
2182
2215
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
2183
2216
def test_silu (self , dtype ):
2217
+ if backend .backend () == "mlx" and dtype == "float64" :
2218
+ pytest .skip ("Backend does not support float64" )
2219
+
2184
2220
import jax .nn as jnn
2185
2221
import jax .numpy as jnp
2186
2222
@@ -2199,6 +2235,9 @@ def test_silu(self, dtype):
2199
2235
2200
2236
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
2201
2237
def test_softplus (self , dtype ):
2238
+ if backend .backend () == "mlx" and dtype == "float64" :
2239
+ pytest .skip ("Backend does not support float64" )
2240
+
2202
2241
import jax .nn as jnn
2203
2242
import jax .numpy as jnp
2204
2243
@@ -2217,6 +2256,9 @@ def test_softplus(self, dtype):
2217
2256
2218
2257
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
2219
2258
def test_softmax (self , dtype ):
2259
+ if backend .backend () == "mlx" and dtype == "float64" :
2260
+ pytest .skip ("Backend does not support float64" )
2261
+
2220
2262
import jax .nn as jnn
2221
2263
import jax .numpy as jnp
2222
2264
@@ -2235,6 +2277,9 @@ def test_softmax(self, dtype):
2235
2277
2236
2278
@parameterized .named_parameters (named_product (dtype = FLOAT_DTYPES ))
2237
2279
def test_softsign (self , dtype ):
2280
+ if backend .backend () == "mlx" and dtype == "float64" :
2281
+ pytest .skip ("Backend does not support float64" )
2282
+
2238
2283
import jax .nn as jnn
2239
2284
import jax .numpy as jnp
2240
2285
0 commit comments