Skip to content

Commit 44e0eb1

Browse files
committed
consolidate acc printout
Signed-off-by: Ye Yu <[email protected]>
1 parent b528ccc commit 44e0eb1

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,6 +1085,7 @@ def forward(
10851085
if self.eagle_freeze_base_model:
10861086
loss = 0.0 * loss
10871087

1088+
acc = []
10881089
eagle_hidden_states_pre_norm = None
10891090
for ttt_step in range(ttt_steps):
10901091
eagle_logits = []
@@ -1136,7 +1137,6 @@ def forward(
11361137
)
11371138

11381139
if self.eagle_report_acc and not self.training:
1139-
acc = []
11401140
with torch.no_grad():
11411141
for i in range(self.eagle_config.parallel_draft_step):
11421142
gathered_logits = gather_from_tensor_model_parallel_region(
@@ -1152,12 +1152,12 @@ def forward(
11521152
)
11531153
acc.append(top1_p)
11541154

1155-
if get_tensor_model_parallel_rank() == 0:
1156-
print(
1157-
f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3}"
1158-
f"EAGLE 1st Top-1: {acc}",
1159-
flush=True,
1160-
)
1155+
if self.eagle_report_acc and not self.training and get_tensor_model_parallel_rank() == 0:
1156+
print(
1157+
f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3}"
1158+
f"EAGLE Top-1: {acc}",
1159+
flush=True,
1160+
)
11611161

11621162
return loss
11631163

0 commit comments

Comments
 (0)