Skip to content

Commit ae9bccd

Browse files
authored
add mx.round and fix numpy args (#19569)
1 parent 1edc652 commit ae9bccd

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

keras/src/backend/mlx/numpy.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,13 @@ def arctanh(x):
158158
return mx.arctanh(x)
159159

160160

161-
def argmax(x, axis=None):
161+
def argmax(x, axis=None, keepdims=False):
162162
x = convert_to_tensor(x)
163-
return mx.argmax(x, axis=axis)
163+
return mx.argmax(x, axis=axis, keepdims=keepdims)
164164

165165

166-
def argmin(x, axis=None):
167-
return mx.argmin(x, axis=axis)
166+
def argmin(x, axis=None, keepdims=False):
167+
return mx.argmin(x, axis=axis, keepdims=keepdims)
168168

169169

170170
def argsort(x, axis=-1):
@@ -206,7 +206,7 @@ def average(x, axis=None, weights=None):
206206
return mx.mean(x, axis=axis)
207207

208208

209-
def bincount(x, weights=None, minlength=0):
209+
def bincount(x, weights=None, minlength=0, sparse=False):
210210
raise NotImplementedError("The MLX backend doesn't support bincount yet")
211211

212212

@@ -822,7 +822,7 @@ def tensordot(x1, x2, axes=2):
822822

823823

824824
def round(x, decimals=0):
825-
raise NotImplementedError("The MLX backend doesn't support round yet")
825+
return mx.round(x, decimals=decimals)
826826

827827

828828
def tile(x, repeats):

0 commit comments

Comments
 (0)