Skip to content

Commit 6ea89ae

Browse files
committed
Fixed a typo in the converter. Covered the discontinuous tests
1 parent a016bc0 commit 6ea89ae

File tree

2 files changed

+53
-15
lines changed

2 files changed

+53
-15
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,7 @@ def index_put_converter(
713713
for dim in range(rank):
714714
unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}"
715715
if dim in I:
716-
idx_tensor = I_combined[i]
716+
idx_tensor = I_combined[i_idx]
717717
ii_list.append(idx_tensor)
718718
i_idx += 1
719719
else:
@@ -771,7 +771,12 @@ def index_put_converter(
771771
)
772772
else: # Non-scalar case
773773
values_shape = list(values.shape)
774-
if K > 0 and N in values_shape:
774+
if (
775+
K > 0
776+
and N in values_shape
777+
and (len(F) > 1 and max(F) - min(F) + 1 == len(F))
778+
):
779+
# Continuous case
775780
n_idx = values_shape.index(N)
776781
permute_order = [n_idx] + [
777782
i for i in range(len(values_shape)) if i != n_idx
@@ -817,6 +822,7 @@ def index_put_converter(
817822
tuple(broadcast_shape),
818823
)
819824
else:
825+
# Discontinuous case
820826
values_shape_padded = [1] * (
821827
len(expected_shape) - len(values.shape)
822828
) + list(values.shape)
@@ -828,20 +834,20 @@ def index_put_converter(
828834
raise ValueError(
829835
f"Cannot broadcast {values.shape} to {expected_shape}"
830836
)
831-
values_reshaped = impl.shuffle.reshape(
832-
ctx,
833-
target,
834-
source_ir,
835-
f"{name}_reshape_values",
836-
values,
837-
tuple(broadcast_shape),
838-
)
837+
# values_reshaped = impl.shuffle.reshape(
838+
# ctx,
839+
# target,
840+
# source_ir,
841+
# f"{name}_reshape_values",
842+
# values,
843+
# tuple(broadcast_shape),
844+
# )
839845
values_expanded = impl.slice.expand(
840846
ctx,
841847
target,
842848
source_ir,
843849
f"{name}_expand_values",
844-
values_reshaped,
850+
values,
845851
expected_shape,
846852
)
847853

tests/py/dynamo/conversion/test_index_put_aten.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,43 @@ class TestIndexPutConverter(DispatchTestCase):
195195
dtype=torch.int32,
196196
),
197197
),
198+
# param(
199+
# test_name="4d_indices_none_none_multiple_idx_broadcast_error",
200+
# source_tensor=torch.zeros([1, 2, 5, 3], dtype=torch.float32),
201+
# indices_tensor=(None, None, torch.tensor([0, 1, 2], dtype=torch.int64)),
202+
# value_tensor=torch.randn([2, 3, 3], dtype=torch.float32),
203+
# ),
204+
param(
205+
test_name="discontinuous_test",
206+
source_tensor=torch.zeros([2, 4, 4], dtype=torch.float32),
207+
indices_tensor=(
208+
torch.tensor([0, 0, 1], dtype=torch.int64),
209+
None,
210+
torch.tensor([0, 0, 1], dtype=torch.int64),
211+
),
212+
value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32),
213+
),
198214
param(
199-
test_name="4d_indices_none_none_multiple_idx_broadcast_error",
200-
source_tensor=torch.zeros([1, 2, 5, 3], dtype=torch.float32),
201-
indices_tensor=(None, None, torch.tensor([0, 1, 2], dtype=torch.int64)),
202-
value_tensor=torch.randn([2, 3, 3], dtype=torch.float32),
215+
test_name="discontinuous_test_two",
216+
source_tensor=torch.zeros([2, 4, 4, 2], dtype=torch.float32),
217+
indices_tensor=(
218+
None,
219+
torch.tensor([0, 0, 1, 1], dtype=torch.int64),
220+
None,
221+
torch.tensor([0, 0, 1, 1], dtype=torch.int64),
222+
),
223+
value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32),
224+
),
225+
param(
226+
test_name="continuous_test",
227+
source_tensor=torch.zeros([2, 4, 4, 2], dtype=torch.float32),
228+
indices_tensor=(
229+
None,
230+
None,
231+
torch.tensor([0, 0, 1, 1], dtype=torch.int64),
232+
torch.tensor([0, 0, 1, 1], dtype=torch.int64),
233+
),
234+
value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32),
203235
),
204236
# param(
205237
# test_name="2d_indices_accumulate_True",

0 commit comments

Comments
 (0)