@@ -1379,7 +1379,28 @@ def histogram(x, bins, range):
1379
1379
1380
1380
1381
1381
def unravel_index (x , shape ):
1382
- raise NotImplementedError ("unravel_index not yet implemented in mlx." )
1382
+ x = convert_to_tensor (x )
1383
+ input_dtype = x .dtype
1384
+
1385
+ if None in shape :
1386
+ raise ValueError (
1387
+ "`shape` argument cannot contain `None`. Received: shape={shape}"
1388
+ )
1389
+
1390
+ if x .ndim == 1 :
1391
+ coords = []
1392
+ for dim in reversed (shape ):
1393
+ coords .append ((x % dim ).astype (input_dtype ))
1394
+ x = x // dim
1395
+ return tuple (reversed (coords ))
1396
+
1397
+ x_shape = x .shape
1398
+ coords = []
1399
+ for dim in shape :
1400
+ coords .append (mx .reshape ((x % dim ).astype (input_dtype ), x_shape ))
1401
+ x = x // dim
1402
+
1403
+ return tuple (reversed (coords ))
1383
1404
1384
1405
1385
1406
def searchsorted (sorted_sequence , values , side = "left" ):
@@ -1391,4 +1412,46 @@ def diagflat(x, k=0):
1391
1412
1392
1413
1393
1414
def rot90 (array , k = 1 , axes = (0 , 1 )):
1394
- raise NotImplementedError ("rot90 not yet implemented in mlx." )
1415
+ array = convert_to_tensor (array )
1416
+
1417
+ if array .ndim < 2 :
1418
+ raise ValueError (
1419
+ f"Input array must have at least 2 dimensions. "
1420
+ f"Received: array.ndim={ array .ndim } "
1421
+ )
1422
+ if len (axes ) != 2 or axes [0 ] == axes [1 ]:
1423
+ raise ValueError (
1424
+ f"Invalid axes: { axes } . Axes must be a tuple of "
1425
+ "two different dimensions."
1426
+ )
1427
+
1428
+ array_axes = list (range (array .ndim ))
1429
+ # Swap axes
1430
+ array_axes [axes [0 ]], array_axes [axes [1 ]] = (
1431
+ array_axes [axes [1 ]],
1432
+ array_axes [axes [0 ]],
1433
+ )
1434
+
1435
+ if k < 0 :
1436
+ axes = (axes [1 ], axes [0 ])
1437
+ k *= - 1
1438
+
1439
+ k = k % 4
1440
+
1441
+ if k > 0 :
1442
+ slices = [builtins .slice (None ) for _ in range (array .ndim )]
1443
+ if k == 2 :
1444
+ # 180 deg rotation => reverse elements along both axes
1445
+ slices [axes [0 ]] = builtins .slice (None , None , - 1 )
1446
+ slices [axes [1 ]] = builtins .slice (None , None , - 1 )
1447
+ else :
1448
+ # 90 or 270 deg rotation => transpose and reverse along one axis
1449
+ array = mx .transpose (array , axes = array_axes )
1450
+ if k == 1 :
1451
+ slices [axes [0 ]] = builtins .slice (None , None , - 1 )
1452
+ else :
1453
+ slices [axes [1 ]] = builtins .slice (None , None , - 1 )
1454
+
1455
+ array = array [tuple (slices )]
1456
+
1457
+ return array
0 commit comments