Skip to content

Commit 1834500

Browse files
committed
support for boolean indices
1 parent a14fc4c commit 1834500

File tree

2 files changed

+67
-3
lines changed

2 files changed

+67
-3
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,11 @@ def index_dtype_validator(
386386
for ind in index:
387387
if ind is not None:
388388
val = ind.meta.get("val")
389-
if val is not None and val.dtype not in (torch.int32, torch.int64):
389+
if val is not None and val.dtype not in (
390+
torch.int32,
391+
torch.int64,
392+
torch.bool,
393+
):
390394
return False
391395
return True
392396

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,67 @@ def select(
5353
return layer.get_output(0)
5454

5555

56+
def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool:
    """Return True when ``tensor`` is a boolean (mask) index.

    Each supported index representation is checked against the boolean
    dtype of its own framework, because dtypes from different frameworks
    do not compare equal to one another (a NumPy ``bool`` dtype is never
    ``== torch.bool``).

    Args:
        tensor: The index to inspect. May be a ``torch.Tensor``, a
            ``np.ndarray``, a ``TRTTensor``, or any object carrying a
            ``meta`` mapping whose ``"val"`` entry exposes a ``dtype``.

    Returns:
        True if the index is boolean, False otherwise.
    """
    # Constant indices: compare with the dtype of the matching framework.
    if isinstance(tensor, torch.Tensor):
        return bool(tensor.dtype == torch.bool)
    if isinstance(tensor, np.ndarray):
        return bool(tensor.dtype == np.bool_)

    # Objects that carry traced metadata: the recorded value is authoritative.
    # Read it defensively, since not every index type has a ``meta`` attribute.
    meta = getattr(tensor, "meta", None)
    val = meta.get("val") if isinstance(meta, dict) else None
    if val is not None:
        return bool(val.dtype == torch.bool)

    if isinstance(tensor, TRTTensor):
        # NOTE(review): assumes the TensorRT dtype is an enum whose boolean
        # member is named "BOOL"; comparing by name avoids importing tensorrt
        # here — confirm against the TensorRT version in use.
        return getattr(tensor.dtype, "name", None) == "BOOL"

    return False
62+
63+
64+
def expand_boolean_indices(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]:
    """Replace boolean mask indices with integer indices built from nonzero.

    Every entry of ``indices`` that ``is_boolean_tensor`` reports as boolean
    is converted to a TensorRT tensor, passed through a NonZero layer, and
    then reduced to a single index tensor. ``None`` entries and non-boolean
    entries are passed through untouched.

    Args:
        ctx: Conversion context; its ``net`` is used to add layers.
        target: Target forwarded to ``set_layer_name``.
        source_ir: Source IR forwarded to ``set_layer_name``.
        name: Base name used to derive the names of the created layers.
        input: The tensor being indexed. Not used by this function; kept so
            the signature stays stable for callers.
        indices: The per-dimension indices, possibly containing ``None``.

    Returns:
        A new list with the same length and order as ``indices`` in which
        each boolean index has been replaced by its converted tensor. The
        input sequence itself is not modified.
    """
    # Collect results in a fresh list: rebinding the loop variable would not
    # change ``indices``, and the caller must receive the converted tensors.
    expanded = []
    for i, ind in enumerate(indices):
        if ind is None or not is_boolean_tensor(ind):
            expanded.append(ind)
            continue

        _LOGGER.debug(
            f"Boolean index detected at position {i}, converting with nonzero()"
        )

        # Convert to TRT tensor if not already
        mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}")

        # Apply nonzero
        nonzero_layer = ctx.net.add_non_zero(mask_tensor)
        set_layer_name(
            nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir
        )
        nonzero_indices = nonzero_layer.get_output(0)

        if len(indices) == 1:
            # x[mask]: flatten the 2-D nonzero output into a 1-D index.
            # NOTE(review): flattening only yields plain positions when the
            # mask is 1-D — confirm multi-dimensional masks are not routed here.
            squeeze_layer = ctx.net.add_shuffle(nonzero_indices)
            squeeze_layer.reshape_dims = (-1,)
            set_layer_name(
                squeeze_layer,
                target,
                name + f"_bool_nonzero_squeeze_{i}",
                source_ir,
            )
            expanded.append(squeeze_layer.get_output(0))
        else:
            # Several indices: pick entry ``i`` along axis 1 of the nonzero
            # output.
            # NOTE(review): this presumes a [N, dims] layout; TensorRT's
            # NonZero layer is documented as producing [dims, N] — confirm
            # the gather axis before relying on this path.
            gather_axis = 1  # dim index
            gather_layer = ctx.net.add_gather(
                nonzero_indices,
                get_trt_tensor(ctx, i, name + f"_dim_index_{i}"),
                gather_axis,
            )
            set_layer_name(
                gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir
            )
            expanded.append(gather_layer.get_output(0))
    return expanded
115+
116+
56117
def index(
57118
ctx: ConversionContext,
58119
target: Target,
@@ -63,8 +124,6 @@ def index(
63124
) -> TRTTensor:
64125
adv_indx_indices = []
65126
tensor_indices = []
66-
# check if the input is dynamic
67-
dynamic_shape = has_dynamic_shape(input.shape)
68127
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
69128
# If any is not this flag will be set to False
70129
_LOGGER.debug(
@@ -78,6 +137,7 @@ def index(
78137
# here we need to check if all the index are broadcastable
79138
# if no, then we need to broadcast
80139
last_index = None
140+
indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices)
81141
for i, ind in enumerate(indices):
82142
if ind is not None:
83143
_LOGGER.debug(f"Shape of {i} index is {ind.shape}")

0 commit comments

Comments
 (0)