@@ -162,11 +162,7 @@ def scaled_dot_product_attention(
     if S < 0:
         S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)
     # generate the mask tensor
-    if is_causal:
-        tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
-    else:
-        # TODO: lan to figure out why attn_mask passed in from transformers is not working
-        # tried both 2d and 4d, but both are not working
+    if not is_causal:
         assert len(attn_mask.shape) in [
             2,
             4,
@@ -183,48 +179,56 @@ def scaled_dot_product_attention(
         attn_mask = impl.squeeze.squeeze(
             ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1)
         )
-        tril_tensor = attn_mask
-
-    # generate attn_bias via where instead of (logical_and, sub, log) to see whether nan is related to this
-    attn_bias_via_where = True
-    if attn_bias_via_where:
-        attn_bias = impl.condition.where(
-            ctx,
-            target,
-            source_ir,
-            name + "_where",
-            torch.tensor(0.0, dtype=torch.float32).cuda(),
-            torch.tensor(-float("inf"), dtype=torch.float32).cuda(),
-            tril_tensor,
-        )
+        attn_bias = attn_mask
     else:
-        temp_mask = impl.unary.logical_not(
-            ctx, target, source_ir, name + "_logical_not", tril_tensor
-        )
+        tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
+        # generate attn_bias via where instead of (logical_and, sub, log) to see whether nan is related to this
+        attn_bias_via_where = True
+        if attn_bias_via_where:
+            attn_bias = impl.condition.where(
+                ctx,
+                target,
+                source_ir,
+                name + "_where",
+                torch.tensor(0.0, dtype=torch.float32).cuda(),
+                torch.tensor(-float("inf"), dtype=torch.float32).cuda(),
+                tril_tensor,
+            )
+        else:
+            temp_mask = impl.unary.logical_not(
+                ctx, target, source_ir, name + "_logical_not", tril_tensor
+            )

-        # This need_mask determines if we want to use the causal mask or not
-        # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
-        # So need_mask will be all False values in this case.
-        # TODO: Implement more general case where L != 1 and S != L
-        need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
-        temp_mask = impl.elementwise.logical_and(
-            ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
-        )
-        temp_mask_casted = cast_trt_tensor(
-            ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir
-        )
+            # This need_mask determines if we want to use the causal mask or not
+            # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
+            # So need_mask will be all False values in this case.
+            # TODO: Implement more general case where L != 1 and S != L
+            need_mask = impl.elementwise.eq(
+                ctx, target, source_ir, name + "_eq", L, S
+            )
+            temp_mask = impl.elementwise.logical_and(
+                ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
+            )
+            temp_mask_casted = cast_trt_tensor(
+                ctx,
+                temp_mask,
+                query_dtype,
+                name + "_casted_bool",
+                target,
+                source_ir,
+            )

-        one_minus_temp_mask = impl.elementwise.sub(
-            ctx,
-            target,
-            source_ir,
-            name + "_one_minus_temp_mask",
-            1.0,
-            temp_mask_casted,
-        )
-        attn_bias = impl.unary.log(
-            ctx, target, source_ir, name + "_log", one_minus_temp_mask
-        )
+            one_minus_temp_mask = impl.elementwise.sub(
+                ctx,
+                target,
+                source_ir,
+                name + "_one_minus_temp_mask",
+                1.0,
+                temp_mask_casted,
+            )
+            attn_bias = impl.unary.log(
+                ctx, target, source_ir, name + "_log", one_minus_temp_mask
+            )

     scaled_add_attn_bias = impl.elementwise.add(
         ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
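For reference, the masking semantics this diff rearranges can be sketched in plain PyTorch. This is only an illustrative sketch, not the Torch-TensorRT converter API; the helper name causal_attn_bias_sketch and its parameters are made up for this example, and it mirrors both the where()-based path and the (logical_not, eq, logical_and, sub, log) fallback shown above.

import torch

def causal_attn_bias_sketch(L: int, S: int, via_where: bool = True, dtype=torch.float32):
    # Lower-triangular mask: True where query position i may attend to key position j (j <= i).
    tril_tensor = torch.ones(L, S).tril().bool()
    if via_where:
        # where() path: 0.0 where attention is allowed, -inf where it is masked out.
        attn_bias = torch.where(
            tril_tensor,
            torch.tensor(0.0, dtype=dtype),
            torch.tensor(-float("inf"), dtype=dtype),
        )
    else:
        # Fallback path: log(1 - mask) is 0 for kept positions and -inf for masked ones.
        # need_mask disables the causal mask when L != S (e.g. KV-cache decode with L == 1).
        temp_mask = torch.logical_not(tril_tensor)
        need_mask = torch.tensor(L == S)
        temp_mask = torch.logical_and(need_mask, temp_mask)
        attn_bias = torch.log(1.0 - temp_mask.to(dtype))
    return attn_bias

# The bias is added to the scaled Q @ K^T scores before the softmax, e.g.:
# scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + attn_bias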