Skip to content

Commit 8b7997d

Browse files
authored
mlx - fix for ops/core_test.py (#19599)
* fix for ops/core_test.py * revert cast int64 to int32 removal
1 parent 817619e commit 8b7997d

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

keras/src/backend/mlx/core.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,22 +230,37 @@ def scatter_update(inputs, indices, updates):
230230

231231
def slice(inputs, start_indices, shape):
232232
inputs = convert_to_tensor(inputs)
233-
233+
if not isinstance(shape, list):
234+
shape = convert_to_tensor(shape, dtype="int32").tolist()
235+
else:
236+
shape = [i if isinstance(i, int) else i.item() for i in shape]
237+
if not isinstance(start_indices, list):
238+
start_indices = convert_to_tensor(start_indices, dtype="int32").tolist()
239+
else:
240+
start_indices = [
241+
i if isinstance(i, int) else i.item() for i in start_indices
242+
]
234243
python_slice = __builtins__["slice"]
235244
slices = tuple(
236-
python_slice(int(start_index), int(start_index + length))
245+
python_slice(start_index, start_index + length)
237246
for start_index, length in zip(start_indices, shape)
238247
)
239248
return inputs[slices]
240249

241250

242251
def slice_update(inputs, start_indices, updates):
243252
inputs = convert_to_tensor(inputs)
253+
if not isinstance(start_indices, list):
254+
start_indices = convert_to_tensor(start_indices, dtype="int32").tolist()
255+
else:
256+
start_indices = [
257+
i if isinstance(i, int) else i.item() for i in start_indices
258+
]
244259
updates = convert_to_tensor(updates)
245260

246261
python_slice = __builtins__["slice"]
247262
slices = tuple(
248-
python_slice(int(start_index), int(start_index + update_length))
263+
python_slice(start_index, start_index + update_length)
249264
for start_index, update_length in zip(start_indices, updates.shape)
250265
)
251266
inputs[slices] = updates

keras/src/ops/core_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,10 @@ def test_cast(self):
468468
@parameterized.named_parameters(
469469
("float8_e4m3fn", "float8_e4m3fn"), ("float8_e5m2", "float8_e5m2")
470470
)
471+
@pytest.mark.skipif(
472+
backend.backend() == "mlx",
473+
reason=f"{backend.backend()} doesn't support `float8`.",
474+
)
471475
def test_cast_float8(self, float8_dtype):
472476
# Cast to float8 and cast back
473477
x = ops.ones((2,), dtype="float32")
@@ -584,6 +588,8 @@ class CoreOpsDtypeTest(testing.TestCase, parameterized.TestCase):
584588
ALL_DTYPES = [
585589
x for x in ALL_DTYPES if x not in ["uint16", "uint32", "uint64"]
586590
]
591+
elif backend.backend() == "mlx":
592+
ALL_DTYPES = [x for x in ALL_DTYPES if x not in ["float64"]]
587593
# Remove float8 dtypes for the following tests
588594
ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES]
589595

0 commit comments

Comments
 (0)