@@ -53,6 +53,67 @@ def select(
53
53
return layer .get_output (0 )
54
54
55
55
56
def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool:
    """Return True when ``tensor`` should be treated as a boolean mask.

    Args:
        tensor: An index operand; a TRT tensor, a NumPy array or a torch
            tensor.

    Returns:
        True if
        * ``tensor`` is a ``TRTTensor`` whose ``meta["val"]`` entry exists and
          has dtype ``torch.bool``; or
        * ``tensor`` is a ``torch.Tensor`` with dtype ``torch.bool``; or
        * ``tensor`` is an ``np.ndarray`` with dtype ``np.bool_``.
        False otherwise.
    """
    if isinstance(tensor, TRTTensor):
        # NOTE(review): this relies on the TRT tensor exposing a `.meta`
        # mapping - confirm that every TRTTensor reaching this point has one.
        val = tensor.meta.get("val")
        if val is not None and val.dtype is torch.bool:
            return True
    if isinstance(tensor, torch.Tensor):
        return tensor.dtype == torch.bool
    if isinstance(tensor, np.ndarray):
        # A NumPy dtype never compares equal to `torch.bool`, so boolean
        # arrays must be checked against NumPy's own boolean dtype.
        return bool(tensor.dtype == np.bool_)
    return False
62
+
63
+
64
def expand_boolean_indices(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]:
    """Replace every boolean-mask index with indices obtained via nonzero.

    Each entry of ``indices`` for which ``is_boolean_tensor`` returns True is
    converted to a TRT tensor, passed through a non-zero layer and then either
    flattened (single index) or gathered at position ``i`` along axis 1
    (several indices). All other entries, including ``None``, are passed
    through untouched.

    Args:
        ctx: Conversion context; its ``net`` receives the new layers.
        target: Forwarded to ``set_layer_name``.
        source_ir: Forwarded to ``set_layer_name``.
        name: Prefix for the names of the layers and constants created here.
        input: Tensor being indexed. Not used by this function.
        indices: The index operands, one per indexed position.

    Returns:
        A new list of the same length as ``indices`` in which boolean masks
        have been replaced by the TRT tensors computed here. The ``indices``
        argument itself is not modified.
    """
    # Collect results in a copy: rebinding the loop variable alone would leave
    # the returned sequence unchanged and discard the converted indices.
    expanded = list(indices)

    for i, ind in enumerate(indices):
        if ind is None or not is_boolean_tensor(ind):
            continue

        _LOGGER.debug(
            f"Boolean index detected at position { i } , converting with nonzero()"
        )

        # Convert to TRT tensor if not already
        mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{ i } ")

        # Apply nonzero
        nonzero_layer = ctx.net.add_non_zero(mask_tensor)
        set_layer_name(
            nonzero_layer, target, name + f"_bool_nonzero_{ i } ", source_ir
        )
        nonzero_indices = nonzero_layer.get_output(0)

        if len(indices) == 1:
            # x[mask] with a single mask: flatten the nonzero output to 1-D.
            squeeze_layer = ctx.net.add_shuffle(nonzero_indices)
            squeeze_layer.reshape_dims = (-1,)
            set_layer_name(
                squeeze_layer,
                target,
                name + f"_bool_nonzero_squeeze_{ i } ",
                source_ir,
            )
            expanded[i] = squeeze_layer.get_output(0)
        else:
            # Several indices: pick entry i along axis 1 of the nonzero output.
            # NOTE(review): this assumes the nonzero output is laid out as
            # [N, D]. TensorRT's non-zero layer appears to document [D, N];
            # confirm the gather axis against the TensorRT documentation.
            gather_axis = 1  # dim index
            gather_layer = ctx.net.add_gather(
                nonzero_indices,
                get_trt_tensor(ctx, i, name + f"_dim_index_{ i } "),
                gather_axis,
            )
            set_layer_name(
                gather_layer, target, name + f"_bool_nonzero_extract_{ i } ", source_ir
            )
            expanded[i] = gather_layer.get_output(0)

    return expanded
115
+
116
+
56
117
def index (
57
118
ctx : ConversionContext ,
58
119
target : Target ,
@@ -63,8 +124,6 @@ def index(
63
124
) -> TRTTensor :
64
125
adv_indx_indices = []
65
126
tensor_indices = []
66
- # check if the input is dynamic
67
- dynamic_shape = has_dynamic_shape (input .shape )
68
127
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
69
128
# If any is not this flag will be set to False
70
129
_LOGGER .debug (
@@ -78,6 +137,7 @@ def index(
78
137
# here we need to check if all the index are broadcastable
79
138
# if no, then we need to broadcast
80
139
last_index = None
140
+ indices = expand_boolean_indices (ctx , target , source_ir , name , input , indices )
81
141
for i , ind in enumerate (indices ):
82
142
if ind is not None :
83
143
_LOGGER .debug (f"Shape of { i } index is { ind .shape } " )
0 commit comments