Skip to content

Commit 587b4c9

Browse files
authored
added tensordot (#19566)
1 parent 3859fb4 commit 587b4c9

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

keras/src/backend/mlx/numpy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,9 @@ def tanh(x):
809809

810810

811811
def tensordot(x1, x2, axes=2):
812-
raise NotImplementedError("The MLX backend doesn't support tensordot yet")
812+
x1 = convert_to_tensor(x1)
813+
x2 = convert_to_tensor(x2)
814+
return mx.tensordot(x1, x2, axes=axes)
813815

814816

815817
def round(x, decimals=0):

0 commit comments

Comments
 (0)