Skip to content

Commit 1b23905

Browse files
authored
Index converter dynamic cases fix (#3694)
1 parent ebfade7 commit 1b23905

File tree

3 files changed

+146
-5
lines changed

3 files changed

+146
-5
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,11 @@ def index_dtype_validator(
415415
for ind in index:
416416
if ind is not None:
417417
val = ind.meta.get("val")
418-
if val is not None and val.dtype not in (torch.int32, torch.int64):
418+
if val is not None and val.dtype not in (
419+
torch.int32,
420+
torch.int64,
421+
torch.bool,
422+
):
419423
return False
420424
return True
421425

@@ -424,6 +428,7 @@ def index_dtype_validator(
424428
torch.ops.aten.index.Tensor,
425429
capability_validator=index_dtype_validator,
426430
supports_dynamic_shapes=True,
431+
requires_output_allocator=True,
427432
)
428433
@enforce_tensor_types(
429434
{

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
cast_trt_tensor,
1515
get_positive_dim,
1616
get_trt_tensor,
17-
has_dynamic_shape,
1817
set_layer_name,
1918
to_numpy,
2019
)
@@ -51,6 +50,77 @@ def select(
5150
return layer.get_output(0)
5251

5352

53+
def is_boolean_tensor(
54+
tensor: Union[TRTTensor, np.ndarray, torch.Tensor, torch.fx.Node],
55+
) -> bool:
56+
if isinstance(tensor, torch.Tensor):
57+
return bool(tensor.dtype == torch.bool)
58+
elif isinstance(tensor, np.ndarray):
59+
return bool(tensor.dtype == np.bool_)
60+
elif isinstance(tensor, TRTTensor):
61+
return bool(tensor.dtype == trt.DataType.BOOL)
62+
# when index is a node
63+
else:
64+
val = tensor.meta.get("val")
65+
if val is not None and val.dtype is torch.bool:
66+
return True
67+
68+
return False
69+
70+
71+
def expand_boolean_indices(
72+
ctx: ConversionContext,
73+
target: Target,
74+
source_ir: Optional[SourceIR],
75+
name: str,
76+
input: TRTTensor,
77+
indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
78+
) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]:
79+
new_indices = []
80+
for i, ind in enumerate(indices):
81+
if ind is not None and is_boolean_tensor(ind):
82+
_LOGGER.debug(
83+
f"Boolean index detected at position {i}, converting with nonzero()"
84+
)
85+
mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}")
86+
87+
nonzero_layer = ctx.net.add_non_zero(mask_tensor)
88+
set_layer_name(
89+
nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir
90+
)
91+
nonzero_indices = nonzero_layer.get_output(0)
92+
93+
# nonzero returns shape [N, dims], we need to extract dim i
94+
if len(indices) == 1:
95+
# x[mask] — 1D mask
96+
to_squeeze = nonzero_indices
97+
else:
98+
# Advanced multi-axis mask: extract index i from shape [N, D]
99+
gather_axis = 1 # dim index
100+
gather_layer = ctx.net.add_gather(
101+
nonzero_indices,
102+
get_trt_tensor(ctx, i, name + f"_dim_index_{i}"),
103+
gather_axis,
104+
)
105+
set_layer_name(
106+
gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir
107+
)
108+
to_squeeze = gather_layer.get_output(0)
109+
squeeze_layer = ctx.net.add_shuffle(to_squeeze)
110+
squeeze_layer.reshape_dims = (-1,)
111+
set_layer_name(
112+
squeeze_layer,
113+
target,
114+
name + f"_bool_mask_squeeze_{i}",
115+
source_ir,
116+
)
117+
squeezed_index = squeeze_layer.get_output(0)
118+
new_indices.append(squeezed_index)
119+
else:
120+
new_indices.append(ind)
121+
return new_indices
122+
123+
54124
def index(
55125
ctx: ConversionContext,
56126
target: Target,
@@ -61,13 +131,12 @@ def index(
61131
) -> TRTTensor:
62132
adv_indx_indices = []
63133
tensor_indices = []
64-
# check if the input is dynamic
65-
dynamic_shape = has_dynamic_shape(input.shape)
66134
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
67135
# If any is not this flag will be set to False
68136
_LOGGER.debug(
69137
"Determining whether aten.index constant-index optimization can be invoked"
70138
)
139+
indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices)
71140
is_numpy = all(
72141
isinstance(ind, (torch.Tensor, np.ndarray))
73142
for ind in indices

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,27 @@ class TestIndexConstantConverter(DispatchTestCase):
7171
[None, torch.tensor([0, 0, 1, 1]), None, torch.tensor([0, 0, 1, 1])],
7272
torch.randn(2, 4, 4, 2),
7373
),
74+
(
75+
"mask_index_three_dim",
76+
[None, torch.tensor([True, False]), None],
77+
torch.randn(2, 2, 2),
78+
),
79+
(
80+
"mask_index_two_dim",
81+
[torch.tensor([True, False])],
82+
torch.randn(2, 2),
83+
),
84+
(
85+
# covers multi axis and discontinuous indices
86+
"mask_index_multi_axis",
87+
[
88+
None,
89+
torch.tensor([True, False]), # axis 1
90+
None,
91+
torch.tensor([True, False]), # axis 3
92+
],
93+
torch.randn(2, 2, 2, 2),
94+
),
7495
]
7596
)
7697
def test_index_constant(self, _, index, input):
@@ -104,6 +125,17 @@ def forward(self, x, index0):
104125
[input, index0],
105126
)
106127

128+
def test_index_zero_two_dim_ITensor_mask(self):
129+
class TestModule(nn.Module):
130+
def forward(self, x, index0):
131+
indices = [None, index0]
132+
out = torch.ops.aten.index.Tensor(x, indices)
133+
return out
134+
135+
input = torch.randn(2, 2)
136+
index0 = torch.tensor([True, False])
137+
self.run_test(TestModule(), [input, index0], enable_passes=True)
138+
107139
def test_index_zero_index_three_dim_ITensor(self):
108140
class TestModule(nn.Module):
109141
def forward(self, x, index0):
@@ -116,6 +148,17 @@ def forward(self, x, index0):
116148
index0 = index0.to(torch.int32)
117149
self.run_test(TestModule(), [input, index0])
118150

151+
def test_index_zero_index_three_dim_mask_ITensor(self):
152+
class TestModule(nn.Module):
153+
def forward(self, x, index0):
154+
indices = [None, index0, None]
155+
out = torch.ops.aten.index.Tensor(x, indices)
156+
return out
157+
158+
input = torch.randn(2, 2, 2)
159+
index0 = torch.tensor([True, False])
160+
self.run_test(TestModule(), [input, index0])
161+
119162

120163
class TestIndexDynamicConstantConverter(DispatchTestCase):
121164
@parameterized.expand(
@@ -168,7 +211,31 @@ def forward(self, input):
168211
dtype=torch.float32,
169212
),
170213
]
171-
self.run_test_with_dynamic_shape(TestModule(), input_specs)
214+
self.run_test_with_dynamic_shape(
215+
TestModule(), input_specs, use_dynamo_tracer=True
216+
)
217+
218+
219+
class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase):
220+
def test_index_input_non_dynamic_index_dynamic(self):
221+
class TestIndexWithRuntimeIndex(torch.nn.Module):
222+
def forward(self, x):
223+
mask = x > 0
224+
idx = torch.nonzero(mask, as_tuple=True)
225+
return torch.ops.aten.index.Tensor(x, idx)
226+
227+
input_specs = [
228+
Input(
229+
min_shape=(2, 2),
230+
opt_shape=(2, 2),
231+
max_shape=(8, 8),
232+
dtype=torch.float32,
233+
),
234+
]
235+
# In this case the index args[1] gets itself converted to a List of TRTTensors with use_dynamo_tracer=True
236+
self.run_test_with_dynamic_shape(
237+
TestIndexWithRuntimeIndex(), input_specs, use_dynamo_tracer=True
238+
)
172239

173240

174241
if __name__ == "__main__":

0 commit comments

Comments
 (0)