@@ -305,65 +305,25 @@ def set_multi_step_attention_mask(attn_mask, step):
     =======================================================================================================================
     """  # noqa: E501
     assert step > 1, "step should be larger than 1 in multi-step attention mask."
-    assert step <= 4, "Currently only a step of 4 or smaller is supported!"
 
     s = attn_mask.shape[-1]
-    zero_mask = torch.ones_like(attn_mask).bool()
-    mask_2_1 = attn_mask.clone().detach()
-    mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:]
-    mask_2_2 = torch.ones_like(attn_mask).bool()
-    for i in range(1, s - 1):
-        mask_2_2[:, :, i, i] = False
-
-    if step == 2:
-        attn_mask = torch.cat(
-            (
-                torch.cat((attn_mask, zero_mask), dim=-1),
-                torch.cat((mask_2_1, mask_2_2), dim=-1),
-            ),
-            dim=-2,
-        )
-        return attn_mask
-
-    mask_3_1 = mask_2_1.clone().detach()
-    mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:]
-    mask_3_2 = mask_2_2.clone().detach()
-    mask_3_2[:, :, :, :-1] = mask_3_2[:, :, :, 1:]
-    mask_3_2[:, :, 1, 0] = True
-    mask_3_3 = mask_2_2.clone().detach()
-    mask_3_3[:, :, 1, 1] = True
+    for iter in range(2, step + 1):
+        # iter starts from 2nd step
+        zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
+        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
+        mask_0[:, :, iter - 2] = True
+        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
+        mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
+        for i in range(iter - 1, s - 1):
+            mask_1[:, :, i, i] = False
 
-    if step == 3:
         attn_mask = torch.cat(
             (
-                torch.cat((attn_mask, zero_mask, zero_mask), dim=-1),
-                torch.cat((mask_2_1, mask_2_2, zero_mask), dim=-1),
-                torch.cat((mask_3_1, mask_3_2, mask_3_3), dim=-1),
+                torch.cat((attn_mask, zero_mask), dim=-1),
+                torch.cat((mask_0, mask_1), dim=-1),
             ),
             dim=-2,
         )
-        return attn_mask
-
-    mask_4_1 = mask_3_1.clone().detach()
-    mask_4_1[:, :, :, :-1] = mask_4_1[:, :, :, 1:]
-    mask_4_2 = mask_3_2.clone().detach()
-    mask_4_2[:, :, :, :-1] = mask_4_2[:, :, :, 1:]
-    mask_4_2[:, :, 2, 0] = True
-    mask_4_3 = mask_3_3.clone().detach()
-    mask_4_3[:, :, :, :-1] = mask_4_3[:, :, :, 1:]
-    mask_4_3[:, :, 2, 1] = True
-    mask_4_4 = mask_3_3.clone().detach()
-    mask_4_4[:, :, 2, 2] = True
-
-    attn_mask = torch.cat(
-        (
-            torch.cat((attn_mask, zero_mask, zero_mask, zero_mask), dim=-1),
-            torch.cat((mask_2_1, mask_2_2, zero_mask, zero_mask), dim=-1),
-            torch.cat((mask_3_1, mask_3_2, mask_3_3, zero_mask), dim=-1),
-            torch.cat((mask_4_1, mask_4_2, mask_4_3, mask_4_4), dim=-1),
-        ),
-        dim=-2,
-    )
     return attn_mask
 
 
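The loop above replaces the hardcoded step-2/3/4 branches with a single recurrence: each pass pads the running mask on the right with a fully masked block, then appends a row block built from the (shifted) last s rows plus a mostly masked diagonal block, so the mask grows from (s, s) to (step * s, step * s). Below is a minimal, self-contained sketch of that recurrence, not part of the commit, for sanity-checking on a toy causal mask. The name build_multi_step_mask, the loop variable step_idx, and the toy sizes are illustrative assumptions, and the sketch assumes the Megatron-style convention that True means "masked out".

# Minimal sketch (not part of this commit): a standalone copy of the
# generalized recurrence above, for sanity-checking on a toy mask.
# Assumes True == masked out; function name and sizes are illustrative.
import torch

def build_multi_step_mask(attn_mask: torch.Tensor, step: int) -> torch.Tensor:
    assert step > 1, "step should be larger than 1 in multi-step attention mask."
    s = attn_mask.shape[-1]
    for step_idx in range(2, step + 1):  # `iter` in the patch; renamed to avoid shadowing the builtin
        # Top-right block: fully masked, so existing rows never see the new columns.
        # (Allocated on the default device, as in the patch; pass device=attn_mask.device if needed.)
        zero_mask = torch.ones(attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s).bool()
        # Bottom-left block: the last s rows of the running mask, with row
        # step_idx - 2 fully masked, then shifted left by one column.
        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
        mask_0[:, :, step_idx - 2] = True
        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
        # Bottom-right block: everything masked except a shortened diagonal.
        mask_1 = torch.ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
        for i in range(step_idx - 1, s - 1):
            mask_1[:, :, i, i] = False
        # Grow the mask by one s-sized block row and block column per step.
        attn_mask = torch.cat(
            (
                torch.cat((attn_mask, zero_mask), dim=-1),
                torch.cat((mask_0, mask_1), dim=-1),
            ),
            dim=-2,
        )
    return attn_mask

# Toy check: batch 1, one head, sequence length 4, three steps.
causal = torch.triu(torch.ones(1, 1, 4, 4, dtype=torch.bool), diagonal=1)
out = build_multi_step_mask(causal, step=3)
print(out.shape)  # torch.Size([1, 1, 12, 12]), i.e. (step * s, step * s)

Because the recurrence no longer special-cases each step count, the removed `assert step <= 4` ceiling is unnecessary: the same call works for step = 5 and beyond.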