14
14
cast_trt_tensor ,
15
15
get_positive_dim ,
16
16
get_trt_tensor ,
17
- has_dynamic_shape ,
18
17
set_layer_name ,
19
18
to_numpy ,
20
19
)
@@ -52,10 +51,14 @@ def select(
52
51
53
52
54
53
def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool:
    """Return True when ``tensor`` is a boolean mask usable as an index.

    Each supported index kind carries its dtype in a different type system,
    so each is compared against the boolean dtype of its own library:

    * ``torch.Tensor``  -> ``torch.bool``
    * ``np.ndarray``    -> ``np.bool_``
    * ``TRTTensor``     -> the TensorRT dtype enum member named ``BOOL``
    * anything else is treated as a graph node and its ``meta["val"]``
      entry (if present) is inspected.

    Args:
        tensor: The index object to classify.

    Returns:
        ``True`` if the index holds boolean data, ``False`` otherwise.
    """
    if isinstance(tensor, torch.Tensor):
        return bool(tensor.dtype == torch.bool)

    if isinstance(tensor, np.ndarray):
        # A NumPy dtype never compares equal to ``torch.bool``; comparing
        # against it would silently classify boolean masks as non-boolean.
        return bool(tensor.dtype == np.bool_)

    if isinstance(tensor, TRTTensor):
        # TensorRT dtypes are their own enum and likewise never equal
        # ``torch.bool``. Compare by member name so no extra import is needed.
        # NOTE(review): assumes the TRT dtype enum exposes ``.name`` with the
        # boolean member called "BOOL" - confirm against the TensorRT version
        # in use.
        return getattr(tensor.dtype, "name", None) == "BOOL"

    # when index is a node
    val = tensor.meta.get("val")
    return val is not None and val.dtype is torch.bool
60
63
61
64
@@ -67,12 +70,12 @@ def expand_boolean_indices(
67
70
input : TRTTensor ,
68
71
indices : Sequence [Union [TRTTensor , np .ndarray , torch .Tensor ]],
69
72
) -> Sequence [Union [TRTTensor , np .ndarray , torch .Tensor ]]:
73
+ new_indices = []
70
74
for i , ind in enumerate (indices ):
71
75
if ind is not None and is_boolean_tensor (ind ):
72
76
_LOGGER .debug (
73
77
f"Boolean index detected at position { i } , converting with nonzero()"
74
78
)
75
-
76
79
mask_tensor = get_trt_tensor (ctx , ind , name + f"_bool_mask_{ i } " )
77
80
78
81
nonzero_layer = ctx .net .add_non_zero (mask_tensor )
@@ -93,7 +96,7 @@ def expand_boolean_indices(
93
96
source_ir ,
94
97
)
95
98
squeezed_index = squeeze_layer .get_output (0 )
96
- ind = squeezed_index
99
+ new_indices . append ( squeezed_index )
97
100
else :
98
101
# Advanced multi-axis mask: extract index i from shape [N, D]
99
102
gather_axis = 1 # dim index
@@ -106,8 +109,13 @@ def expand_boolean_indices(
106
109
gather_layer , target , name + f"_bool_nonzero_extract_{ i } " , source_ir
107
110
)
108
111
extracted_index = gather_layer .get_output (0 )
109
- ind = extracted_index
110
- return indices
112
+ squeeze_layer = ctx .net .add_shuffle (extracted_index )
113
+ squeeze_layer .reshape_dims = (- 1 ,)
114
+ squeezed_index = squeeze_layer .get_output (0 )
115
+ new_indices .append (squeezed_index )
116
+ else :
117
+ new_indices .append (ind )
118
+ return new_indices
111
119
112
120
113
121
def index (
@@ -125,6 +133,7 @@ def index(
125
133
_LOGGER .debug (
126
134
"Determining whether aten.index constant-index optimization can be invoked"
127
135
)
136
+ indices = expand_boolean_indices (ctx , target , source_ir , name , input , indices )
128
137
is_numpy = all (
129
138
isinstance (ind , (torch .Tensor , np .ndarray ))
130
139
for ind in indices
@@ -133,7 +142,6 @@ def index(
133
142
# here we need to check if all the index are broadcastable
134
143
# if no, then we need to broadcast
135
144
last_index = None
136
- indices = expand_boolean_indices (ctx , target , source_ir , name , input , indices )
137
145
for i , ind in enumerate (indices ):
138
146
if ind is not None :
139
147
_LOGGER .debug (f"Shape of { i } index is { ind .shape } " )
0 commit comments