@@ -51,6 +51,65 @@ def select(
51
51
return layer .get_output (0 )
52
52
53
53
54
+ def is_boolean_tensor (tensor : Union [TRTTensor , np .ndarray , torch .Tensor ]) -> bool :
55
+ if isinstance (tensor , (TRTTensor )):
56
+ val = tensor .meta .get ("val" )
57
+ if val is not None and val .dtype is torch .bool :
58
+ return True
59
+ return isinstance (tensor , (torch .Tensor , np .ndarray )) and tensor .dtype == torch .bool
60
+
61
+
62
+ def expand_boolean_indices (
63
+ ctx : ConversionContext ,
64
+ target : Target ,
65
+ source_ir : Optional [SourceIR ],
66
+ name : str ,
67
+ input : TRTTensor ,
68
+ indices : Sequence [Union [TRTTensor , np .ndarray , torch .Tensor ]],
69
+ ) -> Sequence [Union [TRTTensor , np .ndarray , torch .Tensor ]]:
70
+ for i , ind in enumerate (indices ):
71
+ if ind is not None and is_boolean_tensor (ind ):
72
+ _LOGGER .debug (
73
+ f"Boolean index detected at position { i } , converting with nonzero()"
74
+ )
75
+
76
+ mask_tensor = get_trt_tensor (ctx , ind , name + f"_bool_mask_{ i } " )
77
+
78
+ nonzero_layer = ctx .net .add_non_zero (mask_tensor )
79
+ set_layer_name (
80
+ nonzero_layer , target , name + f"_bool_nonzero_{ i } " , source_ir
81
+ )
82
+ nonzero_indices = nonzero_layer .get_output (0 )
83
+
84
+ # nonzero returns shape [N, dims], we need to extract dim i
85
+ if len (indices ) == 1 :
86
+ # x[mask] — 1D mask
87
+ squeeze_layer = ctx .net .add_shuffle (nonzero_indices )
88
+ squeeze_layer .reshape_dims = (- 1 ,)
89
+ set_layer_name (
90
+ squeeze_layer ,
91
+ target ,
92
+ name + f"_bool_nonzero_squeeze_{ i } " ,
93
+ source_ir ,
94
+ )
95
+ squeezed_index = squeeze_layer .get_output (0 )
96
+ ind = squeezed_index
97
+ else :
98
+ # Advanced multi-axis mask: extract index i from shape [N, D]
99
+ gather_axis = 1 # dim index
100
+ gather_layer = ctx .net .add_gather (
101
+ nonzero_indices ,
102
+ get_trt_tensor (ctx , i , name + f"_dim_index_{ i } " ),
103
+ gather_axis ,
104
+ )
105
+ set_layer_name (
106
+ gather_layer , target , name + f"_bool_nonzero_extract_{ i } " , source_ir
107
+ )
108
+ extracted_index = gather_layer .get_output (0 )
109
+ ind = extracted_index
110
+ return indices
111
+
112
+
54
113
def index (
55
114
ctx : ConversionContext ,
56
115
target : Target ,
@@ -61,8 +120,6 @@ def index(
61
120
) -> TRTTensor :
62
121
adv_indx_indices = []
63
122
tensor_indices = []
64
- # check if the input is dynamic
65
- dynamic_shape = has_dynamic_shape (input .shape )
66
123
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
67
124
# If any is not this flag will be set to False
68
125
_LOGGER .debug (
@@ -76,6 +133,7 @@ def index(
76
133
# here we need to check if all the index are broadcastable
77
134
# if no, then we need to broadcast
78
135
last_index = None
136
+ indices = expand_boolean_indices (ctx , target , source_ir , name , input , indices )
79
137
for i , ind in enumerate (indices ):
80
138
if ind is not None :
81
139
_LOGGER .debug (f"Shape of { i } index is { ind .shape } " )
0 commit comments