@@ -164,130 +164,42 @@ def scaled_dot_product_attention(
164
164
L = impl .shape .shape (ctx , target , source_ir , name + "_shape_0" , query , 2 )
165
165
if S < 0 :
166
166
S = impl .shape .shape (ctx , target , source_ir , name + "_shape_1" , key , 2 )
167
- if is_causal :
168
- # generate the mask tensor
169
- tril_tensor = tril (
170
- ctx , target , source_ir , name + "_tril" , L , S , sliding_window_size
171
- )
172
167
173
- temp_mask = impl .unary .logical_not (
174
- ctx , target , source_ir , name + "_logical_not" , tril_tensor
175
- )
176
-
177
- # This need_mask determines if we want to use the causal mask or not
178
- # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
179
- # So need_mask will be all False values in this case.
180
- # TODO: Implement more general case where L != 1 and S != L
181
- need_mask = impl .elementwise .eq (ctx , target , source_ir , name + "_eq" , L , S )
182
- temp_mask = impl .elementwise .logical_and (
183
- ctx , target , source_ir , name + "_logical_and" , need_mask , temp_mask
184
- )
185
- temp_mask_casted = cast_trt_tensor (
186
- ctx , temp_mask , query_dtype , name + "_casted_bool" , target , source_ir
187
- )
188
-
189
- one_minus_temp_mask = impl .elementwise .sub (
190
- ctx ,
191
- target ,
192
- source_ir ,
193
- name + "_one_minus_temp_mask" ,
194
- 1.0 ,
195
- temp_mask_casted ,
196
- )
197
- attn_bias = impl .unary .log (
198
- ctx , target , source_ir , name + "_log" , one_minus_temp_mask
199
- )
200
- scaled_add_attn_bias = impl .elementwise .add (
201
- ctx , target , source_ir , name + "_attn_bias_add" , mm , attn_bias
202
- )
203
- else :
204
- use_if_conditional = False
205
- if not use_if_conditional :
206
- # works in non cache scenario, but in kv cache, got the following error:
207
- # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: [ELEMENTWISE]-[aten_ops.scaled_dot_product_attention]-[model.layers.0.self_attn/scaled_dot_product_attention_attn_mask_add]: dimensions not compatible for elementwise. Broadcast has incompatible dimensions: 5 != 71 && 5 != 1 && 71 != 1.)
208
- scaled_add_attn_bias = impl .elementwise .add (
209
- ctx , target , source_ir , name + "_attn_mask_add" , mm , attn_mask
210
- )
211
- else :
212
- if_option = "if_conditional_subgraph" # if_conditional_subgraph or if_conditional or if_conditional_input
213
- if if_option == "if_conditional_subgraph" :
214
- # reference: https://gitlab-master.nvidia.com/TensorRT/TensorRT/-/blob/main/documentation/operators/examples/example_if.py#L46
215
- # if_conditional_subgraph is not working, got the following error:
216
- # Internal Error: MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block
217
- # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 1: Myelin ([myelin_graph.h:attachExceptionMsgToGraph:1139] MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block)
218
-
219
- need_mask = impl .elementwise .eq (
220
- ctx , target , source_ir , name + "_eq" , L , S
221
- )
222
- # if I do not squeeze, it will throw the error: condition must be a scalar tensor
223
- condition = impl .squeeze .squeeze (
224
- ctx , target , source_ir , name + "_unsqueeze" , need_mask , 0
225
- )
226
- if_layer = ctx .net .add_if_conditional ()
227
- if_layer .set_condition (condition )
228
- cond_input1 = if_layer .add_input (mm )
229
- cond_input2 = if_layer .add_input (attn_mask )
230
-
231
- true_input = impl .elementwise .add (
232
- ctx ,
233
- target ,
234
- source_ir ,
235
- name + "_attn_bias_add" ,
236
- cond_input1 .get_output (0 ),
237
- cond_input2 .get_output (0 ),
238
- )
239
- false_input = cond_input1 .get_output (0 )
240
- output_layer = if_layer .add_output (true_input , false_input )
241
- scaled_add_attn_bias = output_layer .get_output (0 )
242
- elif if_option == "if_conditional_input" :
243
- # reference: https://gitlab-master.nvidia.com/TensorRT/TensorRT/-/blob/main/documentation/operators/examples/example_if.py#L17
244
- # if_conditional_input is not working, got the following error:
245
- # Internal Error: MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block
246
- # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 1: Myelin ([myelin_graph.h:attachExceptionMsgToGraph:1139] MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block)
168
+ # generate the mask tensor
169
+ tril_tensor = tril (
170
+ ctx , target , source_ir , name + "_tril" , L , S , sliding_window_size
171
+ )
247
172
248
- need_mask = impl .elementwise .eq (
249
- ctx , target , source_ir , name + "_eq" , L , S
250
- )
251
- # if I do not squeeze, it will throw the error: condition must be a scalar tensor
252
- condition = impl .squeeze .squeeze (
253
- ctx , target , source_ir , name + "_unsqueeze" , need_mask , 0
254
- )
255
- if_layer = ctx .net .add_if_conditional ()
256
- if_layer .set_condition (condition )
257
- true_input = impl .elementwise .add (
258
- ctx , target , source_ir , name + "_attn_bias_add" , mm , attn_mask
259
- )
260
- false_input = mm
261
- true_cond_input = if_layer .add_input (true_input )
262
- false_cond_input = if_layer .add_input (false_input )
263
- output_layer = if_layer .add_output (
264
- true_cond_input .get_output (0 ), false_cond_input .get_output (0 )
265
- )
266
- scaled_add_attn_bias = output_layer .get_output (0 )
267
- elif if_option == "if_conditional" :
268
- # reference: https://github.com/pytorch/TensorRT/blob/535c6a8341a3258a9c311406a9af50eb3c68c5a6/examples/dynamo/llm/cache_utils.py#L15-L44
269
- # if_conditional is not working, got the following error:
270
- # Internal Error: MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block
271
- # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 1: Myelin ([myelin_graph.h:attachExceptionMsgToGraph:1139] MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block)
173
+ temp_mask = impl .unary .logical_not (
174
+ ctx , target , source_ir , name + "_logical_not" , tril_tensor
175
+ )
272
176
273
- need_mask = impl .elementwise .eq (
274
- ctx , target , source_ir , name + "_eq" , L , S
275
- )
276
- # if I do not squeeze, it will throw the error: condition must be a scalar tensor
277
- condition = impl .squeeze .squeeze (
278
- ctx , target , source_ir , name + "_unsqueeze" , need_mask , 0
279
- )
280
- if_layer = ctx .net .add_if_conditional ()
281
- if_layer .set_condition (condition )
282
- true_input = impl .elementwise .add (
283
- ctx , target , source_ir , name + "_attn_bias_add" , mm , attn_mask
284
- )
285
- false_input = mm
286
- output_layer = if_layer .add_output (
287
- true_input .get_output (0 ), false_input .get_output (0 )
288
- )
289
- scaled_add_attn_bias = output_layer .get_output (0 )
177
+ # This need_mask determines if we want to use the causal mask or not
178
+ # When KV caching is enabled, L = 1 and L != S. In this case, we shouldn't use the causal mask.
179
+ # So need_mask will be all False values in this case.
180
+ # TODO: Implement the more general case where L != 1 and S != L
181
+ need_mask = impl .elementwise .eq (ctx , target , source_ir , name + "_eq" , L , S )
182
+ temp_mask = impl .elementwise .logical_and (
183
+ ctx , target , source_ir , name + "_logical_and" , need_mask , temp_mask
184
+ )
185
+ temp_mask_casted = cast_trt_tensor (
186
+ ctx , temp_mask , query_dtype , name + "_casted_bool" , target , source_ir
187
+ )
290
188
189
+ one_minus_temp_mask = impl .elementwise .sub (
190
+ ctx ,
191
+ target ,
192
+ source_ir ,
193
+ name + "_one_minus_temp_mask" ,
194
+ 1.0 ,
195
+ temp_mask_casted ,
196
+ )
197
+ attn_bias = impl .unary .log (
198
+ ctx , target , source_ir , name + "_log" , one_minus_temp_mask
199
+ )
200
+ scaled_add_attn_bias = impl .elementwise .add (
201
+ ctx , target , source_ir , name + "_attn_bias_add" , mm , attn_bias
202
+ )
291
203
softmax = impl .normalization .softmax (
292
204
ctx , target , source_ir , name + "_softmax" , scaled_add_attn_bias , - 1 , False
293
205
)
0 commit comments