1
- import jax .numpy as jnp
2
1
import mlx .core as mx
2
+ import numpy as np
3
3
4
4
from keras .src .backend .common import dtypes
5
5
from keras .src .backend .common import standardize_dtype
@@ -29,8 +29,8 @@ def det(a):
29
29
return _det_3x3 (a )
30
30
# elif len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]:
31
31
# TODO: Swap to mlx.linalg.det when supported
32
- a = jnp .array (a )
33
- output = jnp .linalg .det (a )
32
+ a = np .array (a )
33
+ output = np .linalg .det (a )
34
34
return mx .array (output )
35
35
36
36
@@ -56,15 +56,26 @@ def solve_triangular(a, b, lower=False):
56
56
57
57
58
58
def qr (x , mode = "reduced" ):
59
- # TODO: Swap to mlx.linalg.qr when it supports non-square matrices
60
- x = jnp .array (x )
61
- output = jnp .linalg .qr (x , mode = mode )
62
- return mx .array (output [0 ]), mx .array (output [1 ])
59
+ if mode != "reduced" :
60
+ raise ValueError (
61
+ "`mode` argument value not supported. "
62
+ "Only 'reduced' is supported by the mlx backend. "
63
+ f"Received: mode={ mode } "
64
+ )
65
+ with mx .stream (mx .cpu ):
66
+ return mx .linalg .qr (x )
63
67
64
68
65
69
def svd (x , full_matrices = True , compute_uv = True ):
66
70
with mx .stream (mx .cpu ):
67
- return mx .linalg .svd (x )
71
+ u , s , vt = mx .linalg .svd (x )
72
+ if not compute_uv :
73
+ return s
74
+ if not full_matrices :
75
+ n = min (x .shape [- 2 :])
76
+ return u [..., :n ], s , vt [:n , ...]
77
+ # mlx returns full matrices by default
78
+ return u , s , vt
68
79
69
80
70
81
def cholesky (a ):
@@ -78,11 +89,15 @@ def norm(x, ord=None, axis=None, keepdims=False):
78
89
dtype = dtypes .result_type (x .dtype , "float32" )
79
90
x = convert_to_tensor (x , dtype = dtype )
80
91
# TODO: swap to mlx.linalg.norm when it support singular value norms
81
- x = jnp .array (x )
82
- output = jnp .linalg .norm (x , ord = ord , axis = axis , keepdims = keepdims )
92
+ x = np .array (x )
93
+ output = np .linalg .norm (x , ord = ord , axis = axis , keepdims = keepdims )
83
94
return mx .array (output )
84
95
85
96
86
97
def inv (a ):
87
98
with mx .stream (mx .cpu ):
88
99
return mx .linalg .inv (a )
100
+
101
+
102
+ def lstsq (a , b , rcond = None ):
103
+ raise NotImplementedError ("lstsq not yet implemented in mlx." )
0 commit comments