@@ -161,51 +161,77 @@ def scaled_dot_product_attention(
         L = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", query, 2)
     if S < 0:
         S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)
-
     # generate the mask tensor
     if is_causal:
         tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
     else:
-        # hard code the sliding window size to 512 for now
-        tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S, 512)
         # TODO: lan to figure out why attn_mask passed in from transformers is not working
-        # tried both 2d and 4d, but both are not working, hence the following code is commented out
-        # assert len(attn_mask.shape) in [2, 4], f"attn_mask must be 2D or 4D, but got {attn_mask.shape=}"
-        # if len(attn_mask.shape) == 4:
-        #     if attn_mask.shape[0] != 1:
-        #         attn_mask = impl.slice.slice_op(ctx, target, source_ir, name + "_slice", attn_mask, 0, 0, 1, 1)
-        #     if attn_mask.shape[1] != 1:
-        #         attn_mask = impl.slice.slice_op(ctx, target, source_ir, name + "_slice", attn_mask, 1, 0, 1, 1)
-        #     attn_mask = impl.squeeze.squeeze(ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1))
-        # tril_tensor = attn_mask
-
-    temp_mask = impl.unary.logical_not(
-        ctx, target, source_ir, name + "_logical_not", tril_tensor
-    )
+        # tried both 2d and 4d, but both are not working
+        assert len(attn_mask.shape) in [
+            2,
+            4,
+        ], f"attn_mask must be 2D or 4D, but got {attn_mask.shape=}"
+        if len(attn_mask.shape) == 4:
+            if attn_mask.shape[0] != 1:
+                attn_mask = impl.slice.slice_op(
+                    ctx, target, source_ir, name + "_slice", attn_mask, 0, 0, 1, 1
+                )
+            if attn_mask.shape[1] != 1:
+                attn_mask = impl.slice.slice_op(
+                    ctx, target, source_ir, name + "_slice", attn_mask, 1, 0, 1, 1
+                )
+            attn_mask = impl.squeeze.squeeze(
+                ctx, target, source_ir, name + "_squeeze", attn_mask, (0, 1)
+            )
+        tril_tensor = attn_mask
 
-    # This need_mask determines if we want to use the causal mask or not
-    # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
-    # So need_mask will be all False values in this case.
-    # TODO: Implement more general case where L != 1 and S != L
-    need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
-    temp_mask = impl.elementwise.logical_and(
-        ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
-    )
-    temp_mask_casted = cast_trt_tensor(
-        ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir
-    )
+    # generate attn_bias via where instead of (logical_and, sub, log) to see whether nan is related to this
+    attn_bias_via_where = True
+    if attn_bias_via_where:
+        attn_bias = impl.condition.where(
+            ctx,
+            target,
+            source_ir,
+            name + "_where",
+            torch.tensor(0.0, dtype=torch.float32).cuda(),
+            torch.tensor(-float("inf"), dtype=torch.float32).cuda(),
+            tril_tensor,
+        )
+    else:
+        temp_mask = impl.unary.logical_not(
+            ctx, target, source_ir, name + "_logical_not", tril_tensor
+        )
+        temp_mask = cast_trt_tensor(
+            ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir
+        )
+        temp_mask = impl.elementwise.mul(
+            ctx, target, source_ir, name + "_mul_-inf", temp_mask, float("-inf")
+        )
+        attn_bias = temp_mask
 
-    one_minus_temp_mask = impl.elementwise.sub(
-        ctx,
-        target,
-        source_ir,
-        name + "_one_minus_temp_mask",
-        1.0,
-        temp_mask_casted,
-    )
-    attn_bias = impl.unary.log(
-        ctx, target, source_ir, name + "_log", one_minus_temp_mask
-    )
+        # This need_mask determines if we want to use the causal mask or not
+        # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
+        # So need_mask will be all False values in this case.
+        # TODO: Implement more general case where L != 1 and S != L
+        need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
+        temp_mask = impl.elementwise.logical_and(
+            ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
+        )
+        temp_mask_casted = cast_trt_tensor(
+            ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir
+        )
+
+        one_minus_temp_mask = impl.elementwise.sub(
+            ctx,
+            target,
+            source_ir,
+            name + "_one_minus_temp_mask",
+            1.0,
+            temp_mask_casted,
+        )
+        attn_bias = impl.unary.log(
+            ctx, target, source_ir, name + "_log", one_minus_temp_mask
+        )
 
     scaled_add_attn_bias = impl.elementwise.add(
         ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
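
For context on the attn_bias_via_where change above: building the additive bias by multiplying a 0/1 mask with -inf places 0 * -inf = NaN at every position that is still allowed to attend, whereas selecting between 0 and -inf does not. A minimal standalone PyTorch sketch (plain torch calls, not the converter's impl ops) illustrating the difference:

import torch

L, S = 4, 4
# True = position may attend, analogous to tril_tensor in the converter
allowed = torch.ones(L, S, dtype=torch.bool).tril()

# mul-based bias: allowed positions become 0.0 * -inf = nan
bias_mul = (~allowed).float() * float("-inf")
print(torch.isnan(bias_mul).any())    # tensor(True)

# where-based bias: allowed -> 0.0, masked -> -inf, no nan anywhere
bias_where = torch.where(
    allowed, torch.tensor(0.0), torch.tensor(float("-inf"))
)
print(torch.isnan(bias_where).any())  # tensor(False)

A row that is entirely -inf would still produce NaN after softmax, but a causal tril mask keeps at least the diagonal entry in every query row, so that case does not arise on the is_causal path.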
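
The need_mask logic in the non-where branch encodes the KV-cache decode case described in its comments: during generation with a KV cache the query holds only the newest token (L = 1) while the keys hold the full history (S > 1), and that single token may attend to every cached position, so no causal bias should be applied. A rough plain-PyTorch sketch of the intended gating (illustration only, with a hypothetical helper causal_bias, not the converter code):

import torch

def causal_bias(L: int, S: int) -> torch.Tensor:
    # Prefill: L == S, so the causal mask applies.
    # Decode with KV cache: L == 1 and S > 1, the single new query token may
    # attend to every cached key, so nothing is masked.
    need_mask = L == S
    if need_mask:
        allowed = torch.ones(L, S, dtype=torch.bool).tril()
    else:
        allowed = torch.ones(L, S, dtype=torch.bool)
    return torch.where(allowed, torch.tensor(0.0), torch.tensor(float("-inf")))

print(causal_bias(4, 4))  # lower-triangular 0 / -inf pattern (prefill)
print(causal_bias(1, 4))  # all zeros: the new token attends to every cached key

As the TODO in the diff notes, this gating only covers the L == S and L == 1 cases; the general case where L != 1 and S != L is left unhandled.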