@@ -158,13 +158,13 @@ def arctanh(x):
158
158
return mx .arctanh (x )
159
159
160
160
161
- def argmax (x , axis = None ):
161
+ def argmax (x , axis = None , keepdims = False ):
162
162
x = convert_to_tensor (x )
163
- return mx .argmax (x , axis = axis )
163
+ return mx .argmax (x , axis = axis , keepdims = keepdims )
164
164
165
165
166
- def argmin (x , axis = None ):
167
- return mx .argmin (x , axis = axis )
166
+ def argmin (x , axis = None , keepdims = False ):
167
+ return mx .argmin (x , axis = axis , keepdims = keepdims )
168
168
169
169
170
170
def argsort (x , axis = - 1 ):
@@ -206,7 +206,7 @@ def average(x, axis=None, weights=None):
206
206
return mx .mean (x , axis = axis )
207
207
208
208
209
- def bincount (x , weights = None , minlength = 0 ):
209
+ def bincount (x , weights = None , minlength = 0 , sparse = False ):
210
210
raise NotImplementedError ("The MLX backend doesn't support bincount yet" )
211
211
212
212
@@ -822,7 +822,7 @@ def tensordot(x1, x2, axes=2):
822
822
823
823
824
824
def round (x , decimals = 0 ):
825
- raise NotImplementedError ( "The MLX backend doesn't support round yet" )
825
+ return mx . round ( x , decimals = decimals )
826
826
827
827
828
828
def tile (x , repeats ):
0 commit comments