Skip to content

Commit 45e5771

Browse files
authored
Implementation of rot90 and unravel_index functions (#20909)
1 parent 2beeb0e commit 45e5771

File tree

1 file changed

+65
-2
lines changed

1 file changed

+65
-2
lines changed

keras/src/backend/mlx/numpy.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,7 +1379,28 @@ def histogram(x, bins, range):
13791379

13801380

13811381
def unravel_index(x, shape):
1382-
raise NotImplementedError("unravel_index not yet implemented in mlx.")
1382+
x = convert_to_tensor(x)
1383+
input_dtype = x.dtype
1384+
1385+
if None in shape:
1386+
raise ValueError(
1387+
"`shape` argument cannot contain `None`. Received: shape={shape}"
1388+
)
1389+
1390+
if x.ndim == 1:
1391+
coords = []
1392+
for dim in reversed(shape):
1393+
coords.append((x % dim).astype(input_dtype))
1394+
x = x // dim
1395+
return tuple(reversed(coords))
1396+
1397+
x_shape = x.shape
1398+
coords = []
1399+
for dim in shape:
1400+
coords.append(mx.reshape((x % dim).astype(input_dtype), x_shape))
1401+
x = x // dim
1402+
1403+
return tuple(reversed(coords))
13831404

13841405

13851406
def searchsorted(sorted_sequence, values, side="left"):
@@ -1391,4 +1412,46 @@ def diagflat(x, k=0):
13911412

13921413

13931414
def rot90(array, k=1, axes=(0, 1)):
1394-
raise NotImplementedError("rot90 not yet implemented in mlx.")
1415+
array = convert_to_tensor(array)
1416+
1417+
if array.ndim < 2:
1418+
raise ValueError(
1419+
f"Input array must have at least 2 dimensions. "
1420+
f"Received: array.ndim={array.ndim}"
1421+
)
1422+
if len(axes) != 2 or axes[0] == axes[1]:
1423+
raise ValueError(
1424+
f"Invalid axes: {axes}. Axes must be a tuple of "
1425+
"two different dimensions."
1426+
)
1427+
1428+
array_axes = list(range(array.ndim))
1429+
# Swap axes
1430+
array_axes[axes[0]], array_axes[axes[1]] = (
1431+
array_axes[axes[1]],
1432+
array_axes[axes[0]],
1433+
)
1434+
1435+
if k < 0:
1436+
axes = (axes[1], axes[0])
1437+
k *= -1
1438+
1439+
k = k % 4
1440+
1441+
if k > 0:
1442+
slices = [builtins.slice(None) for _ in range(array.ndim)]
1443+
if k == 2:
1444+
# 180 deg rotation => reverse elements along both axes
1445+
slices[axes[0]] = builtins.slice(None, None, -1)
1446+
slices[axes[1]] = builtins.slice(None, None, -1)
1447+
else:
1448+
# 90 or 270 deg rotation => transpose and reverse along one axis
1449+
array = mx.transpose(array, axes=array_axes)
1450+
if k == 1:
1451+
slices[axes[0]] = builtins.slice(None, None, -1)
1452+
else:
1453+
slices[axes[1]] = builtins.slice(None, None, -1)
1454+
1455+
array = array[tuple(slices)]
1456+
1457+
return array

0 commit comments

Comments
 (0)