Skip to content

Commit 3c06000

Browse files
committed
minor edit based on coderabbit's suggestions
Signed-off-by: Ye Yu <[email protected]>
1 parent 4061627 commit 3c06000

File tree

2 files changed

+8
-6
lines changed

(summary as above: 2 files changed, +8 −6 lines)

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def dict_to_config(
140140

141141

142142
def mcore_version_higher_than(target_version: str):
143-
"""Check if megatron-core is least this version."""
143+
"""Check if megatron-core is greater than this version."""
144144
return Version(megatron.core.__version__) > Version(target_version)
145145

146146

@@ -239,13 +239,13 @@ def set_multi_step_attention_mask(attn_mask, step):
239239
=======================================================================================================================
240240
""" # noqa: E501
241241
s = attn_mask.shape[-1]
242-
for iter in range(2, step + 1):
243-
# iter starts from 2nd step
242+
for step_idx in range(2, step + 1):
243+
# step_idx starts from 2nd step
244244
mask_0 = attn_mask.clone().detach()
245-
mask_0[:, :, iter - 2, :] = True
245+
mask_0[:, :, step_idx - 2, :] = True
246246
mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
247247
mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
248-
for i in range(iter - 1, s - 1):
248+
for i in range(step_idx - 1, s - 1):
249249
mask_1[:, :, i, i] = False
250250

251251
attn_mask = torch.cat((mask_0, mask_1), dim=-1)

modelopt/torch/speculative/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,9 @@ def check_data_consistency_across_ranks(self, data, group=None, fail_when_mismat
293293
"""This function checks the data consistency across all ranks in the group.
294294
295295
Use rank 0 data as the golden set to broadcast to all ranks.
296-
Each rank will then compare to this data and through error if different.
296+
Each rank compares its data against this golden set and either raises
297+
(when fail_when_mismatch=True) or emits a warning while forcing every
298+
rank to adopt rank 0's data.
297299
"""
298300
if not torch.distributed.is_initialized():
299301
return data

0 commit comments

Comments (0)