Skip to content

Commit f97d164

Browse files
committed
mask test cases and correction
1 parent e0474e4 commit f97d164

File tree

3 files changed

+26
-7
lines changed

3 files changed

+26
-7
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,7 @@ def index_dtype_validator(
427427
torch.ops.aten.index.Tensor,
428428
capability_validator=index_dtype_validator,
429429
supports_dynamic_shapes=True,
430+
requires_output_allocator=True,
430431
)
431432
@enforce_tensor_types(
432433
{

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
cast_trt_tensor,
1515
get_positive_dim,
1616
get_trt_tensor,
17-
has_dynamic_shape,
1817
set_layer_name,
1918
to_numpy,
2019
)
@@ -52,10 +51,14 @@ def select(
5251

5352

5453
def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool:
55-
if isinstance(tensor, (TRTTensor)):
54+
if isinstance(tensor, (torch.Tensor, np.ndarray, TRTTensor)):
55+
return bool(tensor.dtype == torch.bool)
56+
# when index is a node
57+
else:
5658
val = tensor.meta.get("val")
5759
if val is not None and val.dtype is torch.bool:
5860
return True
61+
5962
return isinstance(tensor, (torch.Tensor, np.ndarray)) and tensor.dtype == torch.bool
6063

6164

@@ -67,12 +70,12 @@ def expand_boolean_indices(
6770
input: TRTTensor,
6871
indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
6972
) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]:
73+
new_indices = []
7074
for i, ind in enumerate(indices):
7175
if ind is not None and is_boolean_tensor(ind):
7276
_LOGGER.debug(
7377
f"Boolean index detected at position {i}, converting with nonzero()"
7478
)
75-
7679
mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}")
7780

7881
nonzero_layer = ctx.net.add_non_zero(mask_tensor)
@@ -93,7 +96,7 @@ def expand_boolean_indices(
9396
source_ir,
9497
)
9598
squeezed_index = squeeze_layer.get_output(0)
96-
ind = squeezed_index
99+
new_indices.append(squeezed_index)
97100
else:
98101
# Advanced multi-axis mask: extract index i from shape [N, D]
99102
gather_axis = 1 # dim index
@@ -106,8 +109,13 @@ def expand_boolean_indices(
106109
gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir
107110
)
108111
extracted_index = gather_layer.get_output(0)
109-
ind = extracted_index
110-
return indices
112+
squeeze_layer = ctx.net.add_shuffle(extracted_index)
113+
squeeze_layer.reshape_dims = (-1,)
114+
squeezed_index = squeeze_layer.get_output(0)
115+
new_indices.append(squeezed_index)
116+
else:
117+
new_indices.append(ind)
118+
return new_indices
111119

112120

113121
def index(
@@ -125,6 +133,7 @@ def index(
125133
_LOGGER.debug(
126134
"Determining whether aten.index constant-index optimization can be invoked"
127135
)
136+
indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices)
128137
is_numpy = all(
129138
isinstance(ind, (torch.Tensor, np.ndarray))
130139
for ind in indices
@@ -133,7 +142,6 @@ def index(
133142
# here we need to check if all the index are broadcastable
134143
# if no, then we need to broadcast
135144
last_index = None
136-
indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices)
137145
for i, ind in enumerate(indices):
138146
if ind is not None:
139147
_LOGGER.debug(f"Shape of {i} index is {ind.shape}")

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,16 @@ class TestIndexConstantConverter(DispatchTestCase):
7171
[None, torch.tensor([0, 0, 1, 1]), None, torch.tensor([0, 0, 1, 1])],
7272
torch.randn(2, 4, 4, 2),
7373
),
74+
(
75+
"mask_index_three_dim",
76+
[None, torch.tensor([True, False]), None],
77+
torch.randn(2, 2, 2),
78+
),
79+
(
80+
"mask_index_two_dim",
81+
[torch.tensor([True, False])],
82+
torch.randn(2, 2),
83+
),
7484
]
7585
)
7686
def test_index_constant(self, _, index, input):

0 commit comments

Comments
 (0)