@@ -149,7 +149,12 @@ def embedding_bag_with_ITensor_offsets(
149
149
mode : int ,
150
150
include_last_offset : bool ,
151
151
) -> Tuple [TRTTensor , TRTTensor , TRTTensor , TRTTensor ]:
152
- len_embed = embed .shape [0 ]
152
+ # prepare some tensors for future use
153
+ constant_0 = get_trt_tensor (ctx , 0 , f"{ name } _constant_tensor_0" )
154
+ constant_1 = get_trt_tensor (ctx , 1 , f"{ name } _constant_tensor_1" )
155
+
156
+ embed_shape = ctx .net .add_shape (embed ).get_output (0 )
157
+ len_embed = ctx .net .add_gather (embed_shape , constant_0 , 0 ).get_output (0 )
153
158
154
159
if include_last_offset :
155
160
# modify the last index of offsets to the end index
@@ -166,32 +171,38 @@ def embedding_bag_with_ITensor_offsets(
166
171
# create a placeholder tensor whose shape is the same as the `embed` tensor
167
172
# if mode is 0 (sum) or 1 (mean), the placeholder tensor is filled with zeros
168
173
# if mode is 2 (max), the placeholder tensor is filled with negative infinity
169
- placeholder_tensor = (
170
- get_trt_tensor (
174
+ placeholder_tensor = impl .elementwise .mul (
175
+ ctx , target , source_ir , f"{ name } _zero_tensors" , embed , 0
176
+ )
177
+ zero_tensor = ctx .net .add_gather (placeholder_tensor , constant_0 , 0 ).get_output (0 )
178
+ zero_tensor = impl .shuffle .reshape (
179
+ ctx , target , source_ir , f"{ name } _reshape_zero_tensor" , zero_tensor , (- 1 ,)
180
+ )
181
+
182
+ if mode == 2 :
183
+ placeholder_tensor = impl .elementwise .add (
171
184
ctx ,
172
- np .full (embed .shape , - np .inf , dtype = np .float32 ),
185
+ target ,
186
+ source_ir ,
173
187
f"{ name } _negative_inf_tensor" ,
188
+ placeholder_tensor ,
189
+ - np .inf ,
174
190
)
175
- if mode == 2
176
- else get_trt_tensor (
177
- ctx , np .zeros (embed .shape , dtype = np .float32 ), f"{ name } _zero_tensors"
178
- )
179
- )
180
-
181
- # prepare some tensors for future use
182
- zero_tensor = get_trt_tensor (
183
- ctx , np .zeros ((embed .shape [1 ],), dtype = np .float32 ), f"{ name } _zero_tensor"
184
- )
185
- constant_0 = get_trt_tensor (ctx , 0 , f"{ name } _constant_tensor_0" )
186
- constant_1 = get_trt_tensor (ctx , 1 , f"{ name } _constant_tensor_1" )
187
191
188
192
# Use two for loops to calculate the embedding of each bag
189
193
###### Outer loop: traverse offsets ######
190
194
loop1 = ctx .net .add_loop ()
191
- trip_limit1 = ctx .net .add_constant (
192
- shape = (),
193
- weights = trt .Weights (np .array ([offsets .shape [0 ] - 1 ], dtype = np .int32 )),
194
- ).get_output (0 )
195
+
196
+ offsets_shape = ctx .net .add_shape (offsets ).get_output (0 )
197
+ offsets_size = ctx .net .add_gather (offsets_shape , constant_0 , 0 ).get_output (0 )
198
+ trip_limit1 = impl .elementwise .sub (
199
+ ctx , target , source_ir , f"{ name } _trip_limit1" , offsets_size , 1
200
+ )
201
+ # reshape to a 0-d (scalar) tensor so it can be used as the loop trip limit
202
+ trip_limit1 = impl .shuffle .reshape (
203
+ ctx , target , source_ir , f"{ name } _reshape_trip_limit1" , trip_limit1 , ()
204
+ )
205
+
195
206
loop1 .add_trip_limit (trip_limit1 , trt .TripLimit .COUNT )
196
207
197
208
rec1_i_tensor = loop1 .add_recurrence (constant_1 )
@@ -210,9 +221,9 @@ def embedding_bag_with_ITensor_offsets(
210
221
211
222
###### Inner loop: traverse indices ######
212
223
loop2 = ctx .net .add_loop ()
213
- trip_limit2 = ctx . net . add_constant (
214
- shape = (), weights = trt . Weights ( np . array ([ len_embed ], dtype = np . int32 ) )
215
- ). get_output ( 0 )
224
+ trip_limit2 = impl . shuffle . reshape (
225
+ ctx , target , source_ir , f" { name } _reshape_trip_limit2" , len_embed , ( )
226
+ )
216
227
loop2 .add_trip_limit (trip_limit2 , trt .TripLimit .COUNT )
217
228
rec2_j_tensor = loop2 .add_recurrence (constant_0 )
218
229
set_layer_name (rec2_j_tensor , target , f"{ name } _rec2_j_tensor" , source_ir )
0 commit comments