Skip to content

Commit fe4ba17

Browse files
authored
mlx - test updates and signbit implementation (#21180)
* dtype handling * unordered argpartition * passing numpy tests * passing tests * clean tests * attempt python version fix in github actions * patch to updated build behavior in layers
1 parent 65f75bb commit fe4ba17

File tree

23 files changed

+172
-24
lines changed

23 files changed

+172
-24
lines changed

.github/workflows/actions.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
strategy:
1616
fail-fast: false
1717
matrix:
18-
python-version: [3.10]
18+
python-version: ['3.10']
1919
backend: [tensorflow, jax, torch, numpy, openvino, mlx]
2020
name: Run tests
2121
runs-on: ubuntu-latest

keras/api/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@
4141
from keras.src.backend.common.remat import remat
4242
from keras.src.backend.common.stateless_scope import StatelessScope
4343
from keras.src.backend.common.symbolic_scope import SymbolicScope
44+
from keras.src.backend.exports import Variable
45+
from keras.src.backend.exports import device
46+
from keras.src.backend.exports import name_scope
4447
from keras.src.dtype_policies.dtype_policy import DTypePolicy
4548
from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
4649
from keras.src.initializers.initializer import Initializer

keras/api/_tf_keras/keras/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
from keras.src.backend.common.remat import remat
4040
from keras.src.backend.common.stateless_scope import StatelessScope
4141
from keras.src.backend.common.symbolic_scope import SymbolicScope
42+
from keras.src.backend.exports import Variable
43+
from keras.src.backend.exports import device
44+
from keras.src.backend.exports import name_scope
4245
from keras.src.dtype_policies.dtype_policy import DTypePolicy
4346
from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
4447
from keras.src.initializers.initializer import Initializer

keras/src/backend/common/variables_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,11 @@ def test_add(self, dtypes):
943943
x2_jax = jnp.ones((1,), dtype=dtype2)
944944
expected_dtype = standardize_dtype(jnp.add(x1_jax, x2_jax).dtype)
945945

946+
if backend.backend() == "mlx":
947+
if expected_dtype == "complex128":
948+
# mlx backend does not support complex128
949+
expected_dtype = "complex64"
950+
946951
self.assertDType(x1 + x2, expected_dtype)
947952
self.assertDType(x1.__radd__(x2), expected_dtype)
948953

@@ -959,6 +964,11 @@ def test_sub(self, dtypes):
959964
x2_jax = jnp.ones((1,), dtype=dtype2)
960965
expected_dtype = standardize_dtype(jnp.add(x1_jax, x2_jax).dtype)
961966

967+
if backend.backend() == "mlx":
968+
if expected_dtype == "complex128":
969+
# mlx backend does not support complex128
970+
expected_dtype = "complex64"
971+
962972
self.assertDType(x1 - x2, expected_dtype)
963973
self.assertDType(x1.__rsub__(x2), expected_dtype)
964974

@@ -975,6 +985,11 @@ def test_mul(self, dtypes):
975985
x2_jax = jnp.ones((1,), dtype=dtype2)
976986
expected_dtype = standardize_dtype(jnp.add(x1_jax, x2_jax).dtype)
977987

988+
if backend.backend() == "mlx":
989+
if expected_dtype == "complex128":
990+
# mlx backend does not have complex128
991+
expected_dtype = "complex64"
992+
978993
self.assertDType(x1 * x2, expected_dtype)
979994
self.assertDType(x1.__rmul__(x2), expected_dtype)
980995

@@ -1059,13 +1074,26 @@ def test_pow(self, dtypes):
10591074
x2_jax = jnp.ones((1,), dtype=dtype2)
10601075
expected_dtype = standardize_dtype(jnp.power(x1_jax, x2_jax).dtype)
10611076

1077+
if backend.backend() == "mlx":
1078+
if expected_dtype == "complex128":
1079+
# mlx backend does not support complex128
1080+
expected_dtype = "complex64"
1081+
10621082
self.assertDType(x1**x2, expected_dtype)
10631083
self.assertDType(x1.__rpow__(x2), expected_dtype)
10641084

10651085
@parameterized.named_parameters(
10661086
named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
10671087
)
10681088
def test_matmul(self, dtypes):
1089+
if backend.backend() == "mlx":
1090+
result_dtype = backend.result_type(*dtypes)
1091+
if "float" not in result_dtype:
1092+
self.skipTest(
1093+
"mlx backend only supports matmul for real floating point "
1094+
"types"
1095+
)
1096+
10691097
import jax.numpy as jnp
10701098

10711099
dtype1, dtype2 = dtypes

keras/src/backend/mlx/core.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import functools
33
import warnings
44

5+
import ml_dtypes
56
import mlx.core as mx
67
import numpy as np
78

@@ -97,10 +98,14 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
9798
if x.dtype == np.float64:
9899
# mlx backend does not support float64
99100
x = x.astype(np.float32)
100-
if standardize_dtype(x.dtype) == "bfloat16" and mlx_dtype is None:
101+
if standardize_dtype(x.dtype) == "bfloat16":
102+
# mlx currently fails to load a numpy array with dtype=bfloat16
103+
# upcast to float32 to avoid error
104+
x = x.astype(np.float32)
101105
# if a bfloat16 np.ndarray is passed to mx.array with dtype=None
102106
# it casts the output to complex64, so we force cast to bfloat16
103-
mlx_dtype = mx.bfloat16
107+
# (but by upcasting we avoid x.dtype=bfloat16 and mlx_dtype=None)
108+
mlx_dtype = mx.bfloat16 if mlx_dtype is None else mlx_dtype
104109
return mx.array(x, dtype=mlx_dtype)
105110

106111
if isinstance(x, list):
@@ -154,8 +159,9 @@ def convert_to_tensors(*xs):
154159
def convert_to_numpy(x):
155160
# Performs a copy. If we want 0-copy we can pass copy=False
156161
if isinstance(x, mx.array) and x.dtype == mx.bfloat16:
157-
# mlx currently has an error passing bloat16 array to numpy
158-
return np.array(x.astype(mx.float32))
162+
# mlx currently has an error passing bfloat16 array to numpy
163+
# upcast to float32 then downcast to bfloat16
164+
return np.array(x.astype(mx.float32)).astype(ml_dtypes.bfloat16)
159165
return np.array(x)
160166

161167

keras/src/backend/mlx/image.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,3 +655,35 @@ def _compute_weight_mat(
655655
weights,
656656
0,
657657
)
658+
659+
660+
def elastic_transform(
661+
images,
662+
alpha=20.0,
663+
sigma=5.0,
664+
interpolation="bilinear",
665+
fill_mode="reflect",
666+
fill_value=0.0,
667+
seed=None,
668+
data_format=None,
669+
):
670+
raise NotImplementedError("elastic_transform not yet implemented in mlx.")
671+
672+
673+
def perspective_transform(
674+
images,
675+
start_points,
676+
end_points,
677+
interpolation="bilinear",
678+
fill_value=0,
679+
data_format=None,
680+
):
681+
raise NotImplementedError(
682+
"perspective_transform not yet implemented in mlx."
683+
)
684+
685+
686+
def gaussian_blur(
687+
images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None
688+
):
689+
raise NotImplementedError("gaussian_blur not yet implemented in mlx.")

keras/src/backend/mlx/math.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,24 @@ def segment_max(data, segment_ids, num_segments=None, sorted=False):
6060

6161

6262
def top_k(x, k, sorted=True):
63+
# default to sorted=True to match other backends
6364
x = convert_to_tensor(x)
6465
indices = mx.argpartition(mx.negative(x), k, axis=-1)[..., :k]
6566
values = mx.take_along_axis(x, indices, axis=-1)
67+
68+
if sorted:
69+
sort_indices = mx.argsort(mx.negative(values), axis=-1)
70+
values = mx.take_along_axis(values, sort_indices, axis=-1)
71+
indices = mx.take_along_axis(indices, sort_indices, axis=-1)
72+
6673
return values, indices
6774

6875

6976
def in_top_k(targets, predictions, k):
7077
targets = convert_to_tensor(targets)
7178
predictions = convert_to_tensor(predictions)
7279
targets = targets[..., None]
73-
topk_values = top_k(predictions, k)[0]
80+
topk_values = top_k(predictions, k, sorted=False)[0]
7481
targets_values = mx.take_along_axis(predictions, targets, axis=-1)
7582
mask = targets_values >= topk_values
7683
return mx.any(mask, axis=-1)

keras/src/backend/mlx/nn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,6 +1285,7 @@ def dot_product_attention(
12851285
scale=None,
12861286
is_causal=False,
12871287
flash_attention=None,
1288+
attn_logits_soft_cap=None,
12881289
):
12891290
if flash_attention is None:
12901291
flash_attention = False

keras/src/backend/mlx/numpy.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from keras.src.backend.mlx.core import cast
1313
from keras.src.backend.mlx.core import convert_to_tensor
1414
from keras.src.backend.mlx.core import convert_to_tensors
15+
from keras.src.backend.mlx.core import is_tensor
1516
from keras.src.backend.mlx.core import slice
1617
from keras.src.backend.mlx.core import to_mlx_dtype
1718

@@ -272,8 +273,20 @@ def bitwise_xor(x, y):
272273

273274
def bitwise_left_shift(x, y):
274275
x = convert_to_tensor(x)
275-
y = convert_to_tensor(y)
276-
return mx.left_shift(x, y)
276+
if not isinstance(y, int):
277+
y = convert_to_tensor(y)
278+
279+
# handle result dtype to match other backends
280+
types = [x.dtype]
281+
if is_tensor(y):
282+
types.append(y.dtype)
283+
result_dtype = result_type(*types)
284+
mlx_result_dtype = to_mlx_dtype(result_dtype)
285+
286+
result = mx.left_shift(x, y)
287+
if result.dtype != mlx_result_dtype:
288+
return result.astype(mlx_result_dtype)
289+
return result
277290

278291

279292
def left_shift(x, y):
@@ -282,8 +295,20 @@ def left_shift(x, y):
282295

283296
def bitwise_right_shift(x, y):
284297
x = convert_to_tensor(x)
285-
y = convert_to_tensor(y)
286-
return mx.right_shift(x, y)
298+
if not isinstance(y, int):
299+
y = convert_to_tensor(y)
300+
301+
# handle result dtype to match other backends
302+
types = [x.dtype]
303+
if is_tensor(y):
304+
types.append(y.dtype)
305+
result_dtype = result_type(*types)
306+
mlx_result_dtype = to_mlx_dtype(result_dtype)
307+
308+
result = mx.right_shift(x, y)
309+
if result.dtype != mlx_result_dtype:
310+
return result.astype(mlx_result_dtype)
311+
return result
287312

288313

289314
def right_shift(x, y):
@@ -1567,3 +1592,34 @@ def rot90(array, k=1, axes=(0, 1)):
15671592
array = array[tuple(slices)]
15681593

15691594
return array
1595+
1596+
1597+
def signbit(x):
1598+
x = convert_to_tensor(x)
1599+
1600+
if x.dtype in (
1601+
mx.float16,
1602+
mx.float32,
1603+
mx.float64,
1604+
mx.bfloat16,
1605+
mx.complex64,
1606+
):
1607+
if x.dtype == mx.complex64:
1608+
# check sign of real part for complex numbers
1609+
real_part = mx.real(x)
1610+
return signbit(real_part)
1611+
zeros = x == 0
1612+
# this works because in mlx 1/0=inf and 1/-0=-inf
1613+
neg_zeros = (1 / x == mx.array(float("-inf"))) & zeros
1614+
return mx.where(zeros, neg_zeros, x < 0)
1615+
elif x.dtype in (mx.uint8, mx.uint16, mx.uint32, mx.uint64):
1616+
# unsigned integers never negative
1617+
return mx.zeros_like(x).astype(mx.bool_)
1618+
elif x.dtype in (mx.int8, mx.int16, mx.int32, mx.int64):
1619+
# for integers, simple negative check
1620+
return x < 0
1621+
elif x.dtype == mx.bool_:
1622+
# for boolean array, return false
1623+
return mx.zeros_like(x).astype(mx.bool_)
1624+
else:
1625+
raise ValueError(f"Unsupported dtype in `signbit`: {x.dtype}")

keras/src/initializers/constant_initializers_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,10 @@ def test_stft_initializer(self):
8484
args = -2 * pi * time_range * freq_range
8585

8686
tol_kwargs = {}
87-
if backend.backend() == "jax":
87+
if backend.backend() == "jax" or backend.backend() == "mlx":
8888
# TODO(mostafa-mahmoud): investigate the cases
8989
# of non-small error in jax and torch
90+
# for mlx, minor precision differences with float64 on linux
9091
tol_kwargs = {"atol": 1e-4, "rtol": 1e-6}
9192

9293
initializer = initializers.STFT("real", None)

0 commit comments

Comments
 (0)