Skip to content

Commit e0474e4

Browse files
committed
support for boolean indices
1 parent 6a27021 commit e0474e4

File tree

2 files changed

+65
-3
lines changed

2 files changed

+65
-3
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,11 @@ def index_dtype_validator(
414414
for ind in index:
415415
if ind is not None:
416416
val = ind.meta.get("val")
417-
if val is not None and val.dtype not in (torch.int32, torch.int64):
417+
if val is not None and val.dtype not in (
418+
torch.int32,
419+
torch.int64,
420+
torch.bool,
421+
):
418422
return False
419423
return True
420424

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,65 @@ def select(
5151
return layer.get_output(0)
5252

5353

54+
def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool:
    """Return True when ``tensor`` holds boolean data.

    Args:
        tensor: An index operand. Torch tensors and NumPy arrays are
            checked against their own framework's bool dtype. Any other
            object is inspected by duck typing (see the comments below).

    Returns:
        True if the operand is recognised as boolean, False otherwise.
    """
    if isinstance(tensor, torch.Tensor):
        return bool(tensor.dtype == torch.bool)
    if isinstance(tensor, np.ndarray):
        # A NumPy dtype never compares equal to ``torch.bool``; compare
        # against NumPy's own bool type instead.
        return bool(tensor.dtype == np.bool_)

    # Objects carrying FX-style metadata: look at the recorded value's dtype.
    # ``getattr`` is used because not every operand type exposes ``meta``.
    meta = getattr(tensor, "meta", None)
    if meta is not None:
        val = meta.get("val")
        if val is not None and val.dtype is torch.bool:
            return True

    # NOTE(review): assumes a TensorRT tensor exposes ``dtype`` as an enum
    # whose member for booleans is named "BOOL" (trt.DataType.BOOL) — confirm.
    # Comparing by name avoids requiring a new import in this block.
    dtype = getattr(tensor, "dtype", None)
    return getattr(dtype, "name", None) == "BOOL"
60+
61+
62+
def expand_boolean_indices(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]:
    """Replace boolean mask indices with index tensors derived from NonZero.

    Each entry of ``indices`` that ``is_boolean_tensor`` reports as boolean
    is converted to a TensorRT tensor, passed through a NonZero layer, and
    then either flattened (single index) or gathered (several indices).
    ``None`` entries and non-boolean entries are passed through unchanged.

    Args:
        ctx: Conversion context; ``ctx.net`` is used to add layers.
        target: Target passed to ``set_layer_name`` for the new layers.
        source_ir: Source IR passed to ``set_layer_name``.
        name: Base name used to build the names of the added layers.
        input: The tensor being indexed. Not read by this function.
        indices: The index operands, possibly containing ``None``.

    Returns:
        A new list with the same length and order as ``indices`` in which
        every boolean entry has been replaced by the output of the layers
        added for it. The input sequence itself is not modified.
    """
    # Collect results in a fresh list: rebinding the loop variable alone
    # would leave the returned sequence unchanged.
    expanded_indices = []
    for i, ind in enumerate(indices):
        if ind is None or not is_boolean_tensor(ind):
            expanded_indices.append(ind)
            continue

        _LOGGER.debug(
            f"Boolean index detected at position {i}, converting with nonzero()"
        )

        mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}")

        nonzero_layer = ctx.net.add_non_zero(mask_tensor)
        set_layer_name(nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir)
        nonzero_indices = nonzero_layer.get_output(0)

        if len(indices) == 1:
            # x[mask] with a single mask: flatten the NonZero output to 1-D.
            squeeze_layer = ctx.net.add_shuffle(nonzero_indices)
            squeeze_layer.reshape_dims = (-1,)
            set_layer_name(
                squeeze_layer,
                target,
                name + f"_bool_nonzero_squeeze_{i}",
                source_ir,
            )
            expanded_indices.append(squeeze_layer.get_output(0))
        else:
            # Several indices: pick entry ``i`` along axis 1 of the NonZero
            # output.
            # NOTE(review): this assumes the NonZero output is laid out as
            # [N, dims]; verify against the TensorRT NonZero layer docs, since
            # a [dims, N] layout would require gathering along axis 0 instead.
            gather_axis = 1  # dim index
            gather_layer = ctx.net.add_gather(
                nonzero_indices,
                get_trt_tensor(ctx, i, name + f"_dim_index_{i}"),
                gather_axis,
            )
            set_layer_name(
                gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir
            )
            expanded_indices.append(gather_layer.get_output(0))
    return expanded_indices
111+
112+
54113
def index(
55114
ctx: ConversionContext,
56115
target: Target,
@@ -61,8 +120,6 @@ def index(
61120
) -> TRTTensor:
62121
adv_indx_indices = []
63122
tensor_indices = []
64-
# check if the input is dynamic
65-
dynamic_shape = has_dynamic_shape(input.shape)
66123
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
67124
# If any is not this flag will be set to False
68125
_LOGGER.debug(
@@ -76,6 +133,7 @@ def index(
76133
# here we need to check if all the index are broadcastable
77134
# if no, then we need to broadcast
78135
last_index = None
136+
indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices)
79137
for i, ind in enumerate(indices):
80138
if ind is not None:
81139
_LOGGER.debug(f"Shape of {i} index is {ind.shape}")

0 commit comments

Comments
 (0)