Skip to content

Commit 2beeb0e

Browse files
authored
mlx - stft, istft, and dtype handling in numpy.py (#20907)
* stft and istft, dtype handling in numpy.py and swap jax dependency to numpy * fix bad import * formatting
1 parent f7b4397 commit 2beeb0e

File tree

7 files changed

+495
-72
lines changed

7 files changed

+495
-72
lines changed

keras/src/backend/mlx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""MLX backend APIs."""
22

33
from keras.src.backend.common.name_scope import name_scope
4-
from keras.src.backend.jax.core import random_seed_dtype
54
from keras.src.backend.mlx import core
65
from keras.src.backend.mlx import image
76
from keras.src.backend.mlx import linalg
@@ -19,6 +18,7 @@
1918
from keras.src.backend.mlx.core import convert_to_numpy
2019
from keras.src.backend.mlx.core import convert_to_tensor
2120
from keras.src.backend.mlx.core import is_tensor
21+
from keras.src.backend.mlx.core import random_seed_dtype
2222
from keras.src.backend.mlx.core import scatter
2323
from keras.src.backend.mlx.core import shape
2424
from keras.src.backend.mlx.core import stop_gradient

keras/src/backend/mlx/linalg.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import jax.numpy as jnp
21
import mlx.core as mx
2+
import numpy as np
33

44
from keras.src.backend.common import dtypes
55
from keras.src.backend.common import standardize_dtype
@@ -29,8 +29,8 @@ def det(a):
2929
return _det_3x3(a)
3030
# elif len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]:
3131
# TODO: Swap to mlx.linalg.det when supported
32-
a = jnp.array(a)
33-
output = jnp.linalg.det(a)
32+
a = np.array(a)
33+
output = np.linalg.det(a)
3434
return mx.array(output)
3535

3636

@@ -56,15 +56,26 @@ def solve_triangular(a, b, lower=False):
5656

5757

5858
def qr(x, mode="reduced"):
59-
# TODO: Swap to mlx.linalg.qr when it supports non-square matrices
60-
x = jnp.array(x)
61-
output = jnp.linalg.qr(x, mode=mode)
62-
return mx.array(output[0]), mx.array(output[1])
59+
if mode != "reduced":
60+
raise ValueError(
61+
"`mode` argument value not supported. "
62+
"Only 'reduced' is supported by the mlx backend. "
63+
f"Received: mode={mode}"
64+
)
65+
with mx.stream(mx.cpu):
66+
return mx.linalg.qr(x)
6367

6468

6569
def svd(x, full_matrices=True, compute_uv=True):
6670
with mx.stream(mx.cpu):
67-
return mx.linalg.svd(x)
71+
u, s, vt = mx.linalg.svd(x)
72+
if not compute_uv:
73+
return s
74+
if not full_matrices:
75+
n = min(x.shape[-2:])
76+
return u[..., :n], s, vt[:n, ...]
77+
# mlx returns full matrices by default
78+
return u, s, vt
6879

6980

7081
def cholesky(a):
@@ -78,11 +89,15 @@ def norm(x, ord=None, axis=None, keepdims=False):
7889
dtype = dtypes.result_type(x.dtype, "float32")
7990
x = convert_to_tensor(x, dtype=dtype)
8091
# TODO: swap to mlx.linalg.norm when it support singular value norms
81-
x = jnp.array(x)
82-
output = jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
92+
x = np.array(x)
93+
output = np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
8394
return mx.array(output)
8495

8596

8697
def inv(a):
8798
with mx.stream(mx.cpu):
8899
return mx.linalg.inv(a)
100+
101+
102+
def lstsq(a, b, rcond=None):
103+
raise NotImplementedError("lstsq not yet implemented in mlx.")

0 commit comments

Comments
 (0)