@@ -140,7 +140,7 @@ def dict_to_config(
 
 
 def mcore_version_higher_than(target_version: str):
-    """Check if megatron-core is least this version."""
+    """Check if megatron-core is greater than this version."""
     return Version(megatron.core.__version__) > Version(target_version)
 
 
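For context, a minimal usage sketch of the corrected helper, assuming `packaging` is installed and `megatron.core` is importable; the target version string below is an arbitrary example, not taken from the PR:

```python
from packaging.version import Version

import megatron.core


def mcore_version_higher_than(target_version: str):
    """Check if megatron-core is greater than this version."""
    return Version(megatron.core.__version__) > Version(target_version)


# Guard a code path that only exists after a given release (example version).
if mcore_version_higher_than("0.8"):
    pass  # use the newer megatron-core API here
```

Note the comparison is strict (`>`), which is what the corrected docstring now states.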
@@ -239,13 +239,13 @@ def set_multi_step_attention_mask(attn_mask, step):
     =======================================================================================================================
     """  # noqa: E501
     s = attn_mask.shape[-1]
-    for iter in range(2, step + 1):
-        # iter starts from 2nd step
+    for step_idx in range(2, step + 1):
+        # step_idx starts from 2nd step
         mask_0 = attn_mask.clone().detach()
-        mask_0[:, :, iter - 2, :] = True
+        mask_0[:, :, step_idx - 2, :] = True
         mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
         mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(iter - 1, s - 1):
+        for i in range(step_idx - 1, s - 1):
             mask_1[:, :, i, i] = False
 
         attn_mask = torch.cat((mask_0, mask_1), dim=-1)
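To see what the renamed loop computes: each pass through `step_idx` concatenates one more `(s, s)` block onto the last dimension, so a `(b, h, s, s)` input grows to `(b, h, s, step * s)`. A small hedged sketch follows; the causal base mask and the True-means-masked convention are assumptions for illustration, not taken from the PR:

```python
import torch

# Assumed convention: True = position is masked out.
s = 4
base = torch.ones(1, 1, s, s, dtype=torch.bool).triu(diagonal=1)  # causal mask

# set_multi_step_attention_mask as patched above, with step=3:
# step_idx=2 appends one (s, s) block -> width 2*s,
# step_idx=3 appends another         -> width 3*s.
out = set_multi_step_attention_mask(base.clone(), step=3)
print(out.shape)  # torch.Size([1, 1, 4, 12])
```

Renaming `iter` to `step_idx` also stops shadowing Python's built-in `iter()` inside the loop body.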