Skip to content

Commit 89f6103

Browse files
authored
Support diagflat, solve and fix trunc function (#20948)
1 parent c77a339 commit 89f6103

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

keras/src/backend/mlx/linalg.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@ def lu_factor(a):
4848

4949

5050
def solve(a, b):
51-
raise NotImplementedError("solve_triangular not yet implemented in mlx.")
51+
with mx.stream(mx.cpu):
52+
# [linalg::solve] This op is not yet supported on the GPU.
53+
# Explicitly pass a CPU stream to run it.
54+
return mx.linalg.solve(a, b)
5255

5356

5457
def solve_triangular(a, b, lower=False):

keras/src/backend/mlx/numpy.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,7 +1151,7 @@ def trunc(x):
11511151
dtype = standardize_dtype(x.dtype)
11521152
if "int" in dtype or "bool" == dtype:
11531153
return x
1154-
return mx.floor(x)
1154+
return mx.where(x < 0, mx.ceil(x), mx.floor(x))
11551155

11561156

11571157
def vdot(x1, x2):
@@ -1514,7 +1514,13 @@ def searchsorted(sorted_sequence, values, side="left"):
15141514

15151515

15161516
def diagflat(x, k=0):
1517-
raise NotImplementedError("diagflat not yet implemented in mlx.")
1517+
x = convert_to_tensor(x)
1518+
1519+
# GPU scatter does not yet support int64 or complex64
1520+
# for the input or updates.
1521+
stream = mx.cpu if x.dtype in [mx.int64, mx.complex64] else None
1522+
1523+
return mx.diag(mx.reshape(x, [-1]), k, stream=stream)
15181524

15191525

15201526
def rot90(array, k=1, axes=(0, 1)):

0 commit comments

Comments
 (0)