7 changes: 6 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -414,7 +414,11 @@ def index_dtype_validator(
for ind in index:
if ind is not None:
val = ind.meta.get("val")
if val is not None and val.dtype not in (torch.int32, torch.int64):
if val is not None and val.dtype not in (
torch.int32,
torch.int64,
torch.bool,
):
return False
return True

@@ -423,6 +427,7 @@ def index_dtype_validator(
torch.ops.aten.index.Tensor,
capability_validator=index_dtype_validator,
supports_dynamic_shapes=True,
requires_output_allocator=True,
)
@enforce_tensor_types(
{
69 changes: 66 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -14,7 +14,6 @@
cast_trt_tensor,
get_positive_dim,
get_trt_tensor,
has_dynamic_shape,
set_layer_name,
to_numpy,
)
@@ -51,6 +50,71 @@ def select(
return layer.get_output(0)


def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool:
if isinstance(tensor, (torch.Tensor, np.ndarray, TRTTensor)):
return bool(tensor.dtype == torch.bool)
    # when the index is an FX node, check the fake tensor recorded in its metadata
    else:
        val = tensor.meta.get("val")
        return val is not None and val.dtype is torch.bool


def expand_boolean_indices(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]:
new_indices = []
for i, ind in enumerate(indices):
if ind is not None and is_boolean_tensor(ind):
_LOGGER.debug(
f"Boolean index detected at position {i}, converting with nonzero()"
)
mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}")

nonzero_layer = ctx.net.add_non_zero(mask_tensor)
set_layer_name(
nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir
)
nonzero_indices = nonzero_layer.get_output(0)

            # nonzero returns shape [N, dims]; extract the column for dim i
            if len(indices) == 1:
                # x[mask]: a single boolean mask
                to_squeeze = nonzero_indices
            else:
                # advanced multi-axis mask: extract index i from the [N, D] result
gather_axis = 1 # dim index
gather_layer = ctx.net.add_gather(
nonzero_indices,
get_trt_tensor(ctx, i, name + f"_dim_index_{i}"),
gather_axis,
)
set_layer_name(
gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir
)
to_squeeze = gather_layer.get_output(0)
squeeze_layer = ctx.net.add_shuffle(to_squeeze)
squeeze_layer.reshape_dims = (-1,)
set_layer_name(
squeeze_layer,
target,
name + f"_bool_mask_squeeze_{i}",
source_ir,
)
squeezed_index = squeeze_layer.get_output(0)
new_indices.append(squeezed_index)
else:
new_indices.append(ind)
return new_indices
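
For reference, here is a minimal PyTorch-level sketch of the equivalence this helper reproduces with TensorRT layers (illustrative only; the tensor names are made up, and the actual conversion goes through the nonzero/gather/shuffle layers above):

import torch

x = torch.randn(4, 3)
mask = torch.tensor([True, False, True, True])

# A boolean mask is equivalent to indexing with the positions of its True entries.
# nonzero() returns shape [N, 1] for a 1-D mask; reshape(-1) plays the role of the
# shuffle layer above and yields a plain 1-D integer index.
int_index = torch.nonzero(mask).reshape(-1)
assert torch.equal(x[mask], x[int_index])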


def index(
ctx: ConversionContext,
target: Target,
@@ -61,13 +125,12 @@ def index(
) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
# check if the input is dynamic
dynamic_shape = has_dynamic_shape(input.shape)
    # is_numpy is a flag specifying whether all the indices are numpy arrays or torch.Tensors;
    # if any index is not, the flag is set to False.
_LOGGER.debug(
"Determining whether aten.index constant-index optimization can be invoked"
)
indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices)
is_numpy = all(
isinstance(ind, (torch.Tensor, np.ndarray))
for ind in indices
47 changes: 46 additions & 1 deletion tests/py/dynamo/conversion/test_index_aten.py
@@ -71,6 +71,27 @@ class TestIndexConstantConverter(DispatchTestCase):
[None, torch.tensor([0, 0, 1, 1]), None, torch.tensor([0, 0, 1, 1])],
torch.randn(2, 4, 4, 2),
),
(
"mask_index_three_dim",
[None, torch.tensor([True, False]), None],
torch.randn(2, 2, 2),
),
(
"mask_index_two_dim",
[torch.tensor([True, False])],
torch.randn(2, 2),
),
(
            # covers boolean masks on multiple, non-adjacent axes
"mask_index_multi_axis",
[
None,
torch.tensor([True, False]), # axis 1
None,
torch.tensor([True, False]), # axis 3
],
torch.randn(2, 2, 2, 2),
),
]
)
def test_index_constant(self, _, index, input):
@@ -168,7 +189,31 @@ def forward(self, input):
dtype=torch.float32,
),
]
self.run_test_with_dynamic_shape(TestModule(), input_specs)
self.run_test_with_dynamic_shape(
TestModule(), input_specs, use_dynamo_tracer=True
)


class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase):
def test_index_input_non_dynamic_index_dynamic(self):
class TestIndexWithRuntimeIndex(torch.nn.Module):
def forward(self, x):
mask = x > 0
idx = torch.nonzero(mask, as_tuple=True)
return torch.ops.aten.index.Tensor(x, idx)

input_specs = [
Input(
min_shape=(2, 2),
opt_shape=(2, 2),
max_shape=(8, 8),
dtype=torch.float32,
),
]
        # In this case the index args[1] is itself converted to a list of TRTTensors when use_dynamo_tracer=True is set
self.run_test_with_dynamic_shape(
TestIndexWithRuntimeIndex(), input_specs, use_dynamo_tracer=True
)
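
For a rough end-to-end picture of what this change enables, here is a hedged usage sketch (not part of the test suite; assumes a CUDA device and a Torch-TensorRT build that includes this change, with module and tensor names chosen for illustration):

import torch
import torch_tensorrt

class BoolMaskModule(torch.nn.Module):
    def forward(self, x):
        # Boolean mask indexing: the number of selected elements is only known at
        # runtime, which is why the converter is registered with
        # requires_output_allocator=True.
        return x[x > 0]

model = BoolMaskModule().eval().cuda()
inputs = [torch.randn(2, 4, device="cuda")]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))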


if __name__ == "__main__":