@@ -71,8 +71,6 @@ def convert_to_tensor(x, dtype=None, sparse=None):
71
71
return x .value
72
72
73
73
if isinstance (x , np .ndarray ):
74
- if x .dtype == np .int64 :
75
- x = x .astype (np .int32 )
76
74
x = x .astype (standardize_dtype (x .dtype ))
77
75
return mx .array (x , dtype = mlx_dtype )
78
76
@@ -211,6 +209,10 @@ def vectorized_map(function, elements):
211
209
def scatter (indices , values , shape ):
212
210
indices = convert_to_tensor (indices )
213
211
values = convert_to_tensor (values )
212
+ if values .dtype == mx .int64 :
213
+ values = values .astype (mx .int32 )
214
+ elif values .dtype == mx .uint64 :
215
+ values = values .astype (mx .uint32 )
214
216
zeros = mx .zeros (shape , dtype = values .dtype )
215
217
indices = tuple (indices [..., i ] for i in range (indices .shape [- 1 ]))
216
218
zeros = zeros .at [indices ].add (values )
@@ -222,6 +224,10 @@ def scatter_update(inputs, indices, updates):
222
224
inputs = convert_to_tensor (inputs )
223
225
indices = convert_to_tensor (indices )
224
226
updates = convert_to_tensor (updates )
227
+ if inputs .dtype == mx .int64 :
228
+ inputs = inputs .astype (mx .int32 )
229
+ elif inputs .dtype == mx .uint64 :
230
+ inputs = inputs .astype (mx .uint32 )
225
231
indices = tuple (indices [..., i ] for i in range (indices .shape [- 1 ]))
226
232
inputs [indices ] = updates
227
233
0 commit comments