Skip to content

Commit fa53986

Browse files
committed
Fixed a typo in the converter. Covered the discontinuous tests
1 parent 651f4c8 commit fa53986

File tree

2 files changed

+53
-15
lines changed

2 files changed

+53
-15
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ def index_put_converter(
644644
for dim in range(rank):
645645
unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}"
646646
if dim in I:
647-
idx_tensor = I_combined[i]
647+
idx_tensor = I_combined[i_idx]
648648
ii_list.append(idx_tensor)
649649
i_idx += 1
650650
else:
@@ -702,7 +702,12 @@ def index_put_converter(
702702
)
703703
else: # Non-scalar case
704704
values_shape = list(values.shape)
705-
if K > 0 and N in values_shape:
705+
if (
706+
K > 0
707+
and N in values_shape
708+
and (len(F) > 1 and max(F) - min(F) + 1 == len(F))
709+
):
710+
# Continuous case
706711
n_idx = values_shape.index(N)
707712
permute_order = [n_idx] + [
708713
i for i in range(len(values_shape)) if i != n_idx
@@ -748,6 +753,7 @@ def index_put_converter(
748753
tuple(broadcast_shape),
749754
)
750755
else:
756+
# Discontinuous case
751757
values_shape_padded = [1] * (
752758
len(expected_shape) - len(values.shape)
753759
) + list(values.shape)
@@ -759,20 +765,20 @@ def index_put_converter(
759765
raise ValueError(
760766
f"Cannot broadcast {values.shape} to {expected_shape}"
761767
)
762-
values_reshaped = impl.shuffle.reshape(
763-
ctx,
764-
target,
765-
source_ir,
766-
f"{name}_reshape_values",
767-
values,
768-
tuple(broadcast_shape),
769-
)
768+
# values_reshaped = impl.shuffle.reshape(
769+
# ctx,
770+
# target,
771+
# source_ir,
772+
# f"{name}_reshape_values",
773+
# values,
774+
# tuple(broadcast_shape),
775+
# )
770776
values_expanded = impl.slice.expand(
771777
ctx,
772778
target,
773779
source_ir,
774780
f"{name}_expand_values",
775-
values_reshaped,
781+
values,
776782
expected_shape,
777783
)
778784

tests/py/dynamo/conversion/test_index_put_aten.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,43 @@ class TestIndexPutConverter(DispatchTestCase):
195195
dtype=torch.int32,
196196
),
197197
),
198+
# param(
199+
# test_name="4d_indices_none_none_multiple_idx_broadcast_error",
200+
# source_tensor=torch.zeros([1, 2, 5, 3], dtype=torch.float32),
201+
# indices_tensor=(None, None, torch.tensor([0, 1, 2], dtype=torch.int64)),
202+
# value_tensor=torch.randn([2, 3, 3], dtype=torch.float32),
203+
# ),
204+
param(
205+
test_name="discontinuous_test",
206+
source_tensor=torch.zeros([2, 4, 4], dtype=torch.float32),
207+
indices_tensor=(
208+
torch.tensor([0, 0, 1], dtype=torch.int64),
209+
None,
210+
torch.tensor([0, 0, 1], dtype=torch.int64),
211+
),
212+
value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32),
213+
),
198214
param(
199-
test_name="4d_indices_none_none_multiple_idx_broadcast_error",
200-
source_tensor=torch.zeros([1, 2, 5, 3], dtype=torch.float32),
201-
indices_tensor=(None, None, torch.tensor([0, 1, 2], dtype=torch.int64)),
202-
value_tensor=torch.randn([2, 3, 3], dtype=torch.float32),
215+
test_name="discontinuous_test_two",
216+
source_tensor=torch.zeros([2, 4, 4, 2], dtype=torch.float32),
217+
indices_tensor=(
218+
None,
219+
torch.tensor([0, 0, 1, 1], dtype=torch.int64),
220+
None,
221+
torch.tensor([0, 0, 1, 1], dtype=torch.int64),
222+
),
223+
value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32),
224+
),
225+
param(
226+
test_name="continuous_test",
227+
source_tensor=torch.zeros([2, 4, 4, 2], dtype=torch.float32),
228+
indices_tensor=(
229+
None,
230+
None,
231+
torch.tensor([0, 0, 1, 1], dtype=torch.int64),
232+
torch.tensor([0, 0, 1, 1], dtype=torch.int64),
233+
),
234+
value_tensor=torch.tensor([2, 3, 3, 4], dtype=torch.float32),
203235
),
204236
# param(
205237
# test_name="2d_indices_accumulate_True",

0 commit comments

Comments
 (0)