@@ -305,65 +305,25 @@ def set_multi_step_attention_mask(attn_mask, step):
     =======================================================================================================================
     """  # noqa: E501
     assert step > 1, "step should be larger than 1 in multi-step attention mask."
-    assert step <= 4, "Currently only a step of 4 or smaller is supported!"

     s = attn_mask.shape[-1]
-    zero_mask = torch.ones_like(attn_mask).bool()
-    mask_2_1 = attn_mask.clone().detach()
-    mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:]
-    mask_2_2 = torch.ones_like(attn_mask).bool()
-    for i in range(1, s - 1):
-        mask_2_2[:, :, i, i] = False
-
-    if step == 2:
-        attn_mask = torch.cat(
-            (
-                torch.cat((attn_mask, zero_mask), dim=-1),
-                torch.cat((mask_2_1, mask_2_2), dim=-1),
-            ),
-            dim=-2,
-        )
-        return attn_mask
-
-    mask_3_1 = mask_2_1.clone().detach()
-    mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:]
-    mask_3_2 = mask_2_2.clone().detach()
-    mask_3_2[:, :, :, :-1] = mask_3_2[:, :, :, 1:]
-    mask_3_2[:, :, 1, 0] = True
-    mask_3_3 = mask_2_2.clone().detach()
-    mask_3_3[:, :, 1, 1] = True
+    for iter in range(2, step + 1):
+        # iter starts from 2nd step
+        zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
+        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
+        mask_0[:, :, iter - 2] = True
+        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+        mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
+        for i in range(iter - 1, s - 1):
+            mask_1[:, :, i, i] = False

-    if step == 3:
         attn_mask = torch.cat(
             (
-                torch.cat((attn_mask, zero_mask, zero_mask), dim=-1),
-                torch.cat((mask_2_1, mask_2_2, zero_mask), dim=-1),
-                torch.cat((mask_3_1, mask_3_2, mask_3_3), dim=-1),
+                torch.cat((attn_mask, zero_mask), dim=-1),
+                torch.cat((mask_0, mask_1), dim=-1),
             ),
             dim=-2,
         )
-        return attn_mask
-
-    mask_4_1 = mask_3_1.clone().detach()
-    mask_4_1[:, :, :, :-1] = mask_4_1[:, :, :, 1:]
-    mask_4_2 = mask_3_2.clone().detach()
-    mask_4_2[:, :, :, :-1] = mask_4_2[:, :, :, 1:]
-    mask_4_2[:, :, 2, 0] = True
-    mask_4_3 = mask_3_3.clone().detach()
-    mask_4_3[:, :, :, :-1] = mask_4_3[:, :, :, 1:]
-    mask_4_3[:, :, 2, 1] = True
-    mask_4_4 = mask_3_3.clone().detach()
-    mask_4_4[:, :, 2, 2] = True
-
-    attn_mask = torch.cat(
-        (
-            torch.cat((attn_mask, zero_mask, zero_mask, zero_mask), dim=-1),
-            torch.cat((mask_2_1, mask_2_2, zero_mask, zero_mask), dim=-1),
-            torch.cat((mask_3_1, mask_3_2, mask_3_3, zero_mask), dim=-1),
-            torch.cat((mask_4_1, mask_4_2, mask_4_3, mask_4_4), dim=-1),
-        ),
-        dim=-2,
-    )

     return attn_mask
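For review context, a minimal usage sketch (not part of the commit): it assumes the patched set_multi_step_attention_mask is in scope, that masks use the (batch, heads, seq, seq) bool layout where True marks a masked-out position, and that the causal_mask helper and all concrete shapes below are illustrative assumptions only. Each loop iteration appends one s-by-s block row and block column, so a (b, h, s, s) input grows to (b, h, step * s, step * s); for step = 2, 3, 4 the loop is designed to rebuild what the deleted hardcoded branches produced, which is why the assert step <= 4 guard can be dropped.

    import torch

    def causal_mask(batch: int, seq: int) -> torch.Tensor:
        # Hypothetical helper for this sketch: upper-triangular True entries
        # mask out future positions.
        return torch.triu(torch.ones(batch, 1, seq, seq, dtype=torch.bool), diagonal=1)

    base = causal_mask(batch=2, seq=5)                    # (2, 1, 5, 5)
    # Assumes the patched function above is importable / in scope.
    multi = set_multi_step_attention_mask(base, step=3)   # runs two loop iterations
    print(multi.shape)                                    # torch.Size([2, 1, 15, 15])

Collapsing the per-step branches into a single loop leaves the two torch.cat calls as the only place the mask grows, so supporting larger step values becomes a matter of the loop bound rather than adding new mask_k_* blocks by hand.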