@@ -140,7 +140,7 @@ def dict_to_config(
 
 
 def mcore_version_higher_than(target_version: str):
-    """Check if megatron-core is least this version."""
+    """Check if megatron-core is greater than this version."""
     return Version(megatron.core.__version__) > Version(target_version)
 
 
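Note on the hunk above: the docstring fix matters because the check is strictly greater-than, so a build exactly at `target_version` returns False. A minimal sketch of that semantics, assuming `Version` here is `packaging.version.Version` (the import is not shown in this diff):

```python
from packaging.version import Version

# Strict comparison: equal versions are NOT "higher than".
assert Version("0.9.0") > Version("0.8.0")
assert not Version("0.8.0") > Version("0.8.0")
```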
@@ -239,13 +239,13 @@ def set_multi_step_attention_mask(attn_mask, step):
     =======================================================================================================================
     """  # noqa: E501
     s = attn_mask.shape[-1]
-    for iter in range(2, step + 1):
-        # iter starts from 2nd step
+    for step_idx in range(2, step + 1):
+        # step_idx starts from 2nd step
         mask_0 = attn_mask.clone().detach()
-        mask_0[:, :, iter - 2, :] = True
+        mask_0[:, :, step_idx - 2, :] = True
         mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
         mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(iter - 1, s - 1):
+        for i in range(step_idx - 1, s - 1):
             mask_1[:, :, i, i] = False
 
         attn_mask = torch.cat((mask_0, mask_1), dim=-1)
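Note on the hunk above: the rename from `iter` also avoids shadowing Python's builtin `iter`. As a sanity check of the behavior, below is a standalone copy of the patched loop (the wrapper name `multi_step_mask_demo` and the True-means-masked convention are assumptions for illustration, not this module's API). Each pass appends an `s`-wide block, so the mask widens from `[b, h, s, s]` to `[b, h, s, s * step]`:

```python
import torch

def multi_step_mask_demo(attn_mask: torch.Tensor, step: int) -> torch.Tensor:
    """Standalone copy of the loop above, for illustration only."""
    s = attn_mask.shape[-1]
    for step_idx in range(2, step + 1):
        # Shift the accumulated mask one column left, fully masking row step_idx - 2.
        mask_0 = attn_mask.clone().detach()
        mask_0[:, :, step_idx - 2, :] = True
        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
        # Fresh s x s block; only diagonal entries of rows step_idx-1 .. s-2 stay visible.
        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
        for i in range(step_idx - 1, s - 1):
            mask_1[:, :, i, i] = False
        attn_mask = torch.cat((mask_0, mask_1), dim=-1)
    return attn_mask

# Start from a 4-token causal mask (True = masked out) and expand it for 3 steps.
causal = torch.triu(torch.ones(1, 1, 4, 4, dtype=torch.bool), diagonal=1)
print(multi_step_mask_demo(causal, step=3).shape)  # torch.Size([1, 1, 4, 12])
```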