Commit 8ad781a

make ttt_step configurable in forward
Signed-off-by: Ye Yu <[email protected]>
1 parent 58f8710 commit 8ad781a

File tree

1 file changed (+77, -215 lines)


modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 77 additions & 215 deletions
@@ -197,6 +197,7 @@ def set_multi_step_attention_mask(attn_mask, step):
 
     ttt_step=2
     parallel_draft_step=2
+    ->step=3
 
           | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- |
     (out) | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | -- -- G2 G3 G4 G5 G6 G7 | -- -- G2 G3 G4 G5 G6 G7 |
@@ -239,13 +240,12 @@ def set_multi_step_attention_mask(attn_mask, step):
     =======================================================================================================================
     """  # noqa: E501
     s = attn_mask.shape[-1]
-    for step_idx in range(2, step + 1):
-        # step_idx starts from 2nd step
+    for step_idx in range(step):
         mask_0 = attn_mask.clone().detach()
-        mask_0[:, :, step_idx - 2, :] = True
+        mask_0[:, :, step_idx, :] = True
         mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
         mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(step_idx - 1, s - 1):
+        for i in range(step_idx + 1, s - 1):
             mask_1[:, :, i, i] = False
 
         attn_mask = torch.cat((mask_0, mask_1), dim=-1)
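
The 0-based loop is easy to sanity-check in isolation. Below is a minimal, self-contained sketch (the helper name and the toy causal base mask are hypothetical; the real function receives Megatron's boolean mask, where True means masked out): each iteration appends one s-wide block, so a 0-based step yields a mask of width (step + 1) * s.

import torch

def multi_step_mask_sketch(s: int, step: int) -> torch.Tensor:
    # Toy stand-in for the base causal mask (True = masked out).
    attn_mask = torch.triu(torch.ones(1, 1, s, s), diagonal=1).bool()
    for step_idx in range(step):  # 0-based, as in the updated loop
        mask_0 = attn_mask.clone().detach()
        mask_0[:, :, step_idx, :] = True            # fully mask row step_idx
        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]  # shift columns left by one
        mask_1 = attn_mask.new_ones(1, 1, s, s).bool()
        for i in range(step_idx + 1, s - 1):
            mask_1[:, :, i, i] = False              # each draft token attends to itself
        attn_mask = torch.cat((mask_0, mask_1), dim=-1)
    return attn_mask

print(multi_step_mask_sketch(s=8, step=3).shape)  # torch.Size([1, 1, 8, 32])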
@@ -759,8 +759,8 @@ def _get_eagle_module_inputs(
         attention_mask: torch.Tensor,
         position_ids: torch.Tensor,
         features: torch.Tensor | None = None,
-        ttt_step: int = 1,
-        parallel_draft_step: int = 1,
+        ttt_step: int = 0,
+        parallel_draft_step: int = 0,
     ):
         """Getting EAGLE module inputs."""
         b = hidden_states.shape[1]
@@ -784,10 +784,10 @@ def _get_eagle_module_inputs(
 
         eagle_inputs["input_ids"] = (
             padded_input_ids
-            if parallel_draft_step == 1
+            if parallel_draft_step == 0
             else torch.full(
                 padded_input_ids.shape,
-                getattr(self, f"mask_token_{parallel_draft_step - 2}"),
+                getattr(self, f"mask_token_{parallel_draft_step - 1}"),
                 device=padded_input_ids.device,
                 dtype=padded_input_ids.dtype,
             )
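
Under the new 0-based convention, draft position 0 feeds the real (padded) token ids and draft position k >= 1 feeds the learned token mask_token_{k-1}. A quick hypothetical spot check of the indexing:

for parallel_draft_step in range(3):
    source = (
        "padded_input_ids"
        if parallel_draft_step == 0
        else f"mask_token_{parallel_draft_step - 1}"
    )
    print(parallel_draft_step, "->", source)
# 0 -> padded_input_ids
# 1 -> mask_token_0
# 2 -> mask_token_1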
@@ -805,7 +805,7 @@ def _get_eagle_module_inputs(
             feature = gathered_features[-s:]
         eagle_inputs["hidden_states"] = (
             gathered_hidden_states
-            if ttt_step == 1
+            if ttt_step == 0
             else torch.cat(
                 (
                     torch.zeros(
@@ -824,12 +824,12 @@ def _get_eagle_module_inputs(
         )
 
         eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
-            attn_mask, (ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step
+            attn_mask, ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_step
         )
 
         eagle_inputs["rotary_pos_emb"] = torch.cat(
             [rotary_pos_emb]
-            * ((ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step),
+            * (ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_step + 1),
             dim=0,
         )
 
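With both counters 0-based, the combined step stays consistent with the rotary tiling: a mask built by set_multi_step_attention_mask for step n spans n + 1 blocks of length s, which is exactly how many times rotary_pos_emb is repeated. A small hypothetical check (parallel_draft_steps stands in for self.eagle_config.parallel_draft_step):

parallel_draft_steps = 2
for ttt_step in range(2):
    for draft in range(parallel_draft_steps):
        step = ttt_step * parallel_draft_steps + draft  # argument to the mask builder
        print(f"ttt_step={ttt_step} draft={draft}: mask step={step}, rope tiles={step + 1}")
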
@@ -970,6 +970,7 @@ def forward(
         packed_seq_params: PackedSeqParams = None,
         extra_block_kwargs: dict | None = None,
         return_eagle_inputs: bool = False,
+        ttt_steps=4,
         **kwargs,
     ) -> torch.Tensor:
         if position_ids is None or attention_mask is None:
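
forward now exposes the number of test-time-training rounds as a keyword argument, defaulting to 4 to match the previously hard-coded behavior. A hypothetical training-mode call (model and the batch tensors are assumed to be set up elsewhere):

loss = model(
    input_ids=input_ids,
    position_ids=position_ids,
    attention_mask=attention_mask,
    labels=labels,
    ttt_steps=2,  # previously fixed at 4 rounds
)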
@@ -1013,7 +1014,8 @@ def forward(
 
         # EAGLE kv cache
         eagle_inference_context = StaticInferenceContext(
-            input_ids.shape[0], input_ids.shape[1] * self.eagle_config.parallel_draft_step * 4
+            input_ids.shape[0],
+            input_ids.shape[1] * self.eagle_config.parallel_draft_step * ttt_steps,
         )
 
         if self.eagle_offline:
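
The EAGLE KV-cache capacity now scales with the requested rounds rather than the literal 4; each of the ttt_steps * parallel_draft_step EAGLE passes appends one sequence-length block of KV entries. With hypothetical sizes:

seq_len, parallel_draft_step, ttt_steps = 128, 2, 3
kv_capacity = seq_len * parallel_draft_step * ttt_steps  # 768 slots instead of the old 128 * 2 * 4 = 1024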
@@ -1043,228 +1045,88 @@ def forward(
                 hidden_states, apply_fc=True
             )
 
-        # In calibration mode, we want to make sure all weights have been exercised.
-        # This makes sure all quantized weights have amax calibrated
-        if inference_params is None or self.calibration_mode:
-            eagle_logits_0 = []
+        if labels is not None:
+            if labels.shape[1] == input_ids.shape[1] - 1:
+                # For offline training, labels may be 1 token shorter than input_ids.
+                # We will just pad a 0 to the labels to make the seq_len the same as
+                # input_ids. This will introduce a small error in training if logit_distillation
+                # is False, and testing accuracy is wrong for the last token.
+                right_token_pad = torch.zeros(
+                    (labels.shape[0], 1),
+                    dtype=labels.dtype,
+                    device=labels.device,
+                )
+                labels = torch.cat((labels, right_token_pad), dim=-1)
+
+            # If eagle_freeze_base_model is set to True,
+            # the base model is frozen .
+            loss = self.compute_language_model_loss(labels, logits_sbh)
+            if self.eagle_freeze_base_model:
+                loss = 0.0 * loss
+
+        eagle_hidden_states_pre_norm = None
+        for ttt_step in range(ttt_steps):
+            eagle_logits = []
             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_inputs_0 = self._get_eagle_module_inputs(
+                eagle_inputs = self._get_eagle_module_inputs(
                     input_ids=input_ids,
                     hidden_states=eagle_module_input_hidden_states,
                     attention_mask=attention_mask,
                     position_ids=position_ids,
-                    ttt_step=1,
-                    parallel_draft_step=i + 1,
+                    features=eagle_hidden_states_pre_norm,
+                    ttt_step=ttt_step,
+                    parallel_draft_step=i,
                 )
 
-                _, eagle_logits_, eagle_hidden_states_0_pre_norm = self._eagle_forward(
-                    eagle_inputs_0,
+                _, eagle_logits_, eagle_hidden_states_pre_norm_ = self._eagle_forward(
+                    eagle_inputs,
                     output_weight,
                     inference_params=inference_params,
                     packed_seq_params=packed_seq_params,
                     inference_context=eagle_inference_context,
                     **(extra_block_kwargs or {}),
                 )
 
-                eagle_logits_0.append(eagle_logits_)
-            eagle_logits_0 = torch.cat(eagle_logits_0, dim=0)
-
-            # If labels are not provided, return the original logits. We only return after
-            # all eagle weights have been exercised for quantization calibration purpose.
-            if labels is None:
-                return logits_sbh.transpose(0, 1).contiguous()
-            elif labels.shape[1] == input_ids.shape[1] - 1:
-                # For offline training, labels may be 1 token shorter than input_ids.
-                # We will just pad a 0 to the labels to make the seq_len the same as
-                # input_ids. This will introduce a small error in training if logit_distillation
-                # is False, and testing accuracy is wrong for the last token.
-                right_token_pad = torch.zeros(
-                    (labels.shape[0], 1),
-                    dtype=labels.dtype,
-                    device=labels.device,
-                )
-                labels = torch.cat((labels, right_token_pad), dim=-1)
-
-            # If eagle_freeze_base_model is set to True,
-            # the base model is frozen .
-            loss = self.compute_language_model_loss(labels, logits_sbh)
-            loss = 0.0 * loss
-
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
-                loss_ = loss_[:, i:]
-                loss[:, i + 1 :] += self.eagle_loss_decay_factor * loss_
-
-            if self.eagle_report_acc and not self.training:
-                acc = []
-                with torch.no_grad():
-                    for i in range(self.eagle_config.parallel_draft_step):
-                        gathered_logits = gather_from_tensor_model_parallel_region(
-                            eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                        )
-                        gathered_logits = gathered_logits[i:-1]
-                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                        top1_p = torch.eq(labels[:, i + 1 :], eagle_top1).sum() / eagle_top1.numel()
-                        acc.append(top1_p)
-
-                if get_tensor_model_parallel_rank() == 0:
-                    print(
-                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 1st Top-1: {acc}",
-                        flush=True,
-                    )
-
-            # Second round of EAGLE loss
-            eagle_logits_1 = []
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_inputs_1 = self._get_eagle_module_inputs(
-                    input_ids=input_ids,
-                    hidden_states=eagle_module_input_hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    features=eagle_hidden_states_0_pre_norm,
-                    ttt_step=2,
-                    parallel_draft_step=i + 1,
-                )
+                eagle_logits.append(eagle_logits_)
+            eagle_logits = torch.cat(eagle_logits, dim=0)
+            eagle_hidden_states_pre_norm = eagle_hidden_states_pre_norm_
 
-                _, eagle_logits_, eagle_hidden_states_2x_pre_norm = self._eagle_forward(
-                    eagle_inputs_1,
-                    output_weight,
-                    inference_params=inference_params,
-                    packed_seq_params=packed_seq_params,
-                    inference_context=eagle_inference_context,
-                    **(extra_block_kwargs or {}),
-                )
-                eagle_logits_1.append(eagle_logits_)
-            eagle_logits_1 = torch.cat(eagle_logits_1, dim=0)
-
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
-                loss_ = loss_[:, i + 1 :]
-                loss[:, i + 2 :] += self.eagle_loss_decay_factor**2 * loss_
-
-            if self.eagle_report_acc and not self.training:
-                acc = []
-                with torch.no_grad():
-                    for i in range(self.eagle_config.parallel_draft_step):
-                        gathered_logits = gather_from_tensor_model_parallel_region(
-                            eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                        )
-                        gathered_logits = gathered_logits[i + 1 : -1]
-                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                        top1_p = torch.eq(labels[:, i + 2 :], eagle_top1).sum() / eagle_top1.numel()
-                        acc.append(top1_p)
-
-                if get_tensor_model_parallel_rank() == 0:
-                    print(
-                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 2nd Top-1: {acc}",
-                        flush=True,
-                    )
+            # If labels are not provided, return the original logits. We only return after
+            # all eagle weights have been exercised for quantization calibration purpose.
+            if labels is None:
+                return logits_sbh.transpose(0, 1).contiguous()
 
-            # Third EAGLE loss
-            eagle_logits_2 = []
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_inputs_2 = self._get_eagle_module_inputs(
-                    input_ids=input_ids,
-                    hidden_states=eagle_module_input_hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    features=eagle_hidden_states_2x_pre_norm,
-                    ttt_step=3,
-                    parallel_draft_step=i + 1,
-                )
-
-                _, eagle_logits_, eagle_hidden_states_3x_pre_norm = self._eagle_forward(
-                    eagle_inputs_2,
-                    output_weight,
-                    inference_params=inference_params,
-                    packed_seq_params=packed_seq_params,
-                    inference_context=eagle_inference_context,
-                    **(extra_block_kwargs or {}),
-                )
-                eagle_logits_2.append(eagle_logits_)
-            eagle_logits_2 = torch.cat(eagle_logits_2, dim=0)
-
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
-                loss_ = loss_[:, i + 2 :]
-                loss[:, i + 3 :] += self.eagle_loss_decay_factor**3 * loss_
-
-            if self.eagle_report_acc and not self.training:
-                acc = []
-                with torch.no_grad():
-                    for i in range(self.eagle_config.parallel_draft_step):
-                        gathered_logits = gather_from_tensor_model_parallel_region(
-                            eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                        )
-                        gathered_logits = gathered_logits[i + 2 : -1]
-                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                        top1_p = torch.eq(labels[:, i + 3 :], eagle_top1).sum() / eagle_top1.numel()
-                        acc.append(top1_p)
-
-                if get_tensor_model_parallel_rank() == 0:
-                    print(
-                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 3rd Top-1: {acc}",
-                        flush=True,
+            for i in range(self.eagle_config.parallel_draft_step):
+                eagle_logit = eagle_logits[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
+                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logit)
+                loss_ = loss_[:, i + ttt_step :]
+                loss[:, i + ttt_step + 1 :] += (
+                    self.eagle_loss_decay_factor ** (ttt_step + i) * loss_
                 )
 
-            # Forth EAGLE loss
-            eagle_logits_3 = []
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_inputs_3 = self._get_eagle_module_inputs(
-                    input_ids=input_ids,
-                    hidden_states=eagle_module_input_hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    features=eagle_hidden_states_3x_pre_norm,
-                    ttt_step=4,
-                    parallel_draft_step=i + 1,
-                )
-
-                _, eagle_logits_, eagle_hidden_states_4x_pre_norm = self._eagle_forward(
-                    eagle_inputs_3,
-                    output_weight,
-                    inference_params=inference_params,
-                    packed_seq_params=packed_seq_params,
-                    inference_context=eagle_inference_context,
-                    **(extra_block_kwargs or {}),
-                )
-                eagle_logits_3.append(eagle_logits_)
-            eagle_logits_3 = torch.cat(eagle_logits_3, dim=0)
-
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
-                loss_ = loss_[:, i + 3 :]
-                loss[:, i + 4 :] += self.eagle_loss_decay_factor**4 * loss_
-
-            if self.eagle_report_acc and not self.training:
-                acc = []
-                with torch.no_grad():
-                    for i in range(self.eagle_config.parallel_draft_step):
-                        gathered_logits = gather_from_tensor_model_parallel_region(
-                            eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
+            if self.eagle_report_acc and not self.training:
+                acc = []
+                with torch.no_grad():
+                    for i in range(self.eagle_config.parallel_draft_step):
+                        gathered_logits = gather_from_tensor_model_parallel_region(
+                            eagle_logits[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
+                        )
+                        gathered_logits = gathered_logits[i + ttt_step : -1]
+                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
+                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
+                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
+                        top1_p = (
+                            torch.eq(labels[:, i + ttt_step + 1 :], eagle_top1).sum()
+                            / eagle_top1.numel()
+                        )
+                        acc.append(top1_p)
+
+                if get_tensor_model_parallel_rank() == 0:
+                    print(
+                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3}"
+                        f"EAGLE 1st Top-1: {acc}",
+                        flush=True,
                     )
-                        gathered_logits = gathered_logits[i + 3 : -1]
-                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                        top1_p = torch.eq(labels[:, i + 4 :], eagle_top1).sum() / eagle_top1.numel()
-                        acc.append(top1_p)
-
-                if get_tensor_model_parallel_rank() == 0:
-                    print(
-                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 4th Top-1: {acc}",
-                        flush=True,
-                    )
 
         return loss
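The four hand-unrolled rounds collapse into one loop whose slicing and discounting are derived from the loop indices: at TTT step t, draft head i predicts t + i + 1 positions ahead, so its loss is sliced from offset t + i, accumulated from offset t + i + 1, and weighted by eagle_loss_decay_factor ** (t + i). A hypothetical table of the resulting weights:

decay = 0.9  # stand-in for self.eagle_loss_decay_factor
ttt_steps, parallel_draft_step = 4, 2
for t in range(ttt_steps):
    for i in range(parallel_draft_step):
        print(
            f"ttt_step={t} draft={i}: slice loss_[:, {t + i}:], "
            f"add to loss[:, {t + i + 1}:], weight={decay ** (t + i):.3f}"
        )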