Skip to content

Commit 3d7ef85

Browse files
author
AyoubMDL
committed
fix(index_put): handle multidimensional indices
1 parent d75122f commit 3d7ef85

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import math
1515
from typing import Any, Optional, Sequence, Tuple, Union
1616

17+
import numpy as np
18+
1719
from onnxscript import (
1820
BFLOAT16,
1921
BOOL,
@@ -4303,17 +4305,21 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
43034305
while len(reshape_list) > len(values_shape) and 1 in reshape_list:
43044306
reshape_list.remove(1)
43054307

4308+
# Or add ones until the rank of reshape_list matches values_shape.
4309+
while len(reshape_list) < len(values_shape):
4310+
reshape_list.append(1)
4311+
43064312
# Now ensure each dimension is broadcastable:
43074313
# This is mandatory when mixing basic and advanced indexing
43084314
# Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3)
43094315
# the reshape list should be : [[2, 1], [1, 3], [2, 1]]
43104316
for i, r in enumerate(reshape_list):
4311-
if r != 1 and r != values_shape[i]:
4312-
one_index = reshape_list.index(1)
4317+
if r not in (1, values_shape[i]):
4318+
value_index = values_shape.index(r)
43134319
# Swap elements
43144320
# For the example above the current reshape list is [1, 2] for last dim,
43154321
# to make it broadcastable, we swap the elements
4316-
reshape_list[one_index], reshape_list[i] = reshape_list[i], 1
4322+
reshape_list[value_index], reshape_list[i] = r, 1
43174323

43184324
return reshape_list
43194325

@@ -4322,8 +4328,8 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
43224328
if len(indices) < self_rank:
43234329
indices = list(indices) + [None] * (self_rank - len(indices))
43244330

4325-
# Get values shape (we use .numpy to make it hashable)
4326-
values_shape = values.shape.numpy()
4331+
# Get values shape
4332+
values_shape = tuple(values.shape)
43274333

43284334
index_vectors = []
43294335
for i in range(self_rank):
@@ -4333,7 +4339,15 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
43334339
reshape_update = self.shape[i]
43344340
else:
43354341
idx = indices[i]
4336-
reshape_update = indices[i].shape[0]
4342+
reshape_update = np.prod(idx.shape).item()
4343+
# when the index is more than 1D, flatten it and also flatten the values shape
4344+
# Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
4345+
# Indices -> (2*4,) and values shape (2*4, 3)
4346+
if len(idx.shape) > 1:
4347+
values_shape = (reshape_update,) + values_shape[len(idx.shape) :]
4348+
4349+
# Flatten index (always working with 1D index in each dim)
4350+
idx = op.Reshape(idx, [-1])
43374351

43384352
# Create a reshape pattern: one value per index dimension,
43394353
# with the current dimension set to the update size.

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,12 @@ def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
834834
],
835835
(1,),
836836
),
837+
# Cases: Multidimensional index
838+
(
839+
(10, 3),
840+
[torch.arange(8, dtype=torch.int64, device=device).reshape((-1, 4))],
841+
(2, 4, 3),
842+
),
837843
]
838844

839845
for data_shape, indices, values_shape in cases: # type: ignore[misc]

0 commit comments

Comments
 (0)