@@ -1146,6 +1146,10 @@ def test_argmax(self):
1146
1146
keras .config .backend () == "openvino" ,
1147
1147
reason = "OpenVINO doesn't support this change" ,
1148
1148
)
1149
+ @pytest .mark .skipif (
1150
+ keras .config .backend () == "mlx" ,
1151
+ reason = "Wrong results due to MLX flushing denormal numbers to 0 on GPU" ,
1152
+ )
1149
1153
def test_argmax_negative_zero (self ):
1150
1154
input_data = np .array (
1151
1155
[- 1.0 , - 0.0 , 1.401298464324817e-45 ], dtype = np .float32
@@ -1161,6 +1165,10 @@ def test_argmax_negative_zero(self):
1161
1165
evaluation and may change within this PR
1162
1166
""" ,
1163
1167
)
1168
+ @pytest .mark .skipif (
1169
+ keras .config .backend () == "mlx" ,
1170
+ reason = "Wrong results due to MLX flushing denormal numbers to 0 on GPU" ,
1171
+ )
1164
1172
def test_argmin_negative_zero (self ):
1165
1173
input_data = np .array (
1166
1174
[
@@ -5391,10 +5399,16 @@ def setUp(self):
5391
5399
5392
5400
self .jax_enable_x64 = enable_x64 ()
5393
5401
self .jax_enable_x64 .__enter__ ()
5402
+
5403
+ if backend .backend () == "mlx" :
5404
+ self .mlx_cpu_context = backend .core .enable_float64 ()
5405
+ self .mlx_cpu_context .__enter__ ()
5394
5406
return super ().setUp ()
5395
5407
5396
5408
def tearDown (self ):
5397
5409
self .jax_enable_x64 .__exit__ (None , None , None )
5410
+ if backend .backend () == "mlx" :
5411
+ self .mlx_cpu_context .__exit__ (None , None , None )
5398
5412
return super ().tearDown ()
5399
5413
5400
5414
@parameterized .named_parameters (
@@ -5598,6 +5612,13 @@ def test_matmul(self, dtypes):
5598
5612
import jax .numpy as jnp
5599
5613
5600
5614
dtype1 , dtype2 = dtypes
5615
+ if (
5616
+ all (dtype not in self .FLOAT_DTYPES for dtype in dtypes )
5617
+ and backend .backend () == "mlx"
5618
+ ):
5619
+ # This must be removed once mlx.core.matmul supports integer dtypes
5620
+ self .skipTest ("mlx doesn't support integer dot product" )
5621
+
5601
5622
# The shape of the matrix needs to meet the requirements of
5602
5623
# torch._int_mm to test hardware-accelerated matmul
5603
5624
x1 = knp .ones ((17 , 16 ), dtype = dtype1 )
@@ -6620,6 +6641,13 @@ def test_dot(self, dtypes):
6620
6641
import jax .numpy as jnp
6621
6642
6622
6643
dtype1 , dtype2 = dtypes
6644
+ if (
6645
+ all (dtype not in self .FLOAT_DTYPES for dtype in dtypes )
6646
+ and backend .backend () == "mlx"
6647
+ ):
6648
+ # This must be removed once mlx.core.matmul supports integer dtypes
6649
+ self .skipTest ("mlx doesn't support integer dot product" )
6650
+
6623
6651
x1 = knp .ones ((2 , 3 , 4 ), dtype = dtype1 )
6624
6652
x2 = knp .ones ((4 , 3 ), dtype = dtype2 )
6625
6653
x1_jax = jnp .ones ((2 , 3 , 4 ), dtype = dtype1 )
@@ -6648,6 +6676,13 @@ def get_input_shapes(subscripts):
6648
6676
return x1_shape , x2_shape
6649
6677
6650
6678
dtype1 , dtype2 = dtypes
6679
+ if (
6680
+ all (dtype not in self .FLOAT_DTYPES for dtype in dtypes )
6681
+ and backend .backend () == "mlx"
6682
+ ):
6683
+ # This must be removed once mlx.core.matmul supports integer dtypes
6684
+ self .skipTest ("mlx doesn't support integer dot product" )
6685
+
6651
6686
subscripts = "ijk,lkj->il"
6652
6687
x1_shape , x2_shape = get_input_shapes (subscripts )
6653
6688
x1 = knp .ones (x1_shape , dtype = dtype1 )
@@ -8312,6 +8347,13 @@ def test_tensordot(self, dtypes):
8312
8347
import jax .numpy as jnp
8313
8348
8314
8349
dtype1 , dtype2 = dtypes
8350
+ if (
8351
+ all (dtype not in self .FLOAT_DTYPES for dtype in dtypes )
8352
+ and backend .backend () == "mlx"
8353
+ ):
8354
+ # This must be removed once mlx.core.matmul supports integer dtypes
8355
+ self .skipTest ("mlx doesn't support integer dot product" )
8356
+
8315
8357
x1 = knp .ones ((1 , 1 ), dtype = dtype1 )
8316
8358
x2 = knp .ones ((1 , 1 ), dtype = dtype2 )
8317
8359
x1_jax = jnp .ones ((1 , 1 ), dtype = dtype1 )
@@ -8522,6 +8564,13 @@ def test_inner(self, dtypes):
8522
8564
import jax .numpy as jnp
8523
8565
8524
8566
dtype1 , dtype2 = dtypes
8567
+ if (
8568
+ all (dtype not in self .FLOAT_DTYPES for dtype in dtypes )
8569
+ and backend .backend () == "mlx"
8570
+ ):
8571
+ # This must be removed once mlx.core.matmul supports integer dtypes
8572
+ self .skipTest ("mlx doesn't support integer dot product" )
8573
+
8525
8574
x1 = knp .ones ((1 ,), dtype = dtype1 )
8526
8575
x2 = knp .ones ((1 ,), dtype = dtype2 )
8527
8576
x1_jax = jnp .ones ((1 ,), dtype = dtype1 )
0 commit comments