Skip to content

Commit 6424c61

Browse files
authored
add missing convert_to_tensor (#19578)
1 parent d22e827 commit 6424c61

File tree

3 files changed

+11
-0
lines changed

3 files changed

+11
-0
lines changed

keras/src/backend/mlx/math.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def in_top_k(targets, predictions, k):
3131

3232

3333
def logsumexp(x, axis=None, keepdims=False):
34+
x = convert_to_tensor(x)
3435
return mx.logsumexp(x, axis, keepdims)
3536

3637

@@ -87,10 +88,12 @@ def istft(
8788

8889

8990
def rsqrt(x):
91+
x = convert_to_tensor(x)
9092
return mx.rsqrt(x)
9193

9294

9395
def erf(x):
96+
x = convert_to_tensor(x)
9497
return mx.erf(x)
9598

9699

keras/src/backend/mlx/nn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,8 @@ def moments(x, axes, keepdims=False, synchronized=False):
308308
def batch_normalization(
309309
x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3
310310
):
311+
x = convert_to_tensor(x)
312+
mean = convert_to_tensor(x)
311313
shape = [1] * len(x.shape)
312314
shape[axis] = mean.shape[0]
313315
mean = mx.reshape(mean, shape)

keras/src/backend/mlx/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def absolute(x):
8282

8383

8484
def abs(x):
85+
x = convert_to_tensor(x)
8586
return absolute(x)
8687

8788

@@ -164,10 +165,12 @@ def argmax(x, axis=None, keepdims=False):
164165

165166

166167
def argmin(x, axis=None, keepdims=False):
168+
x = convert_to_tensor(x)
167169
return mx.argmin(x, axis=axis, keepdims=keepdims)
168170

169171

170172
def argsort(x, axis=-1):
173+
x = convert_to_tensor(x)
171174
return mx.argsort(x, axis=axis)
172175

173176

@@ -235,6 +238,7 @@ def conjugate(x):
235238

236239

237240
def conj(x):
241+
x = convert_to_tensor(x)
238242
return conjugate(x)
239243

240244

@@ -916,6 +920,8 @@ def divide_no_nan(x1, x2):
916920

917921

918922
def true_divide(x1, x2):
923+
x1 = convert_to_tensor(x1)
924+
x2 = convert_to_tensor(x2)
919925
return divide(x1, x2)
920926

921927

0 commit comments

Comments
 (0)