Commit 651cbf4

debug

Signed-off-by: Ye Yu <[email protected]>
1 parent: 292ec59

1 file changed: modelopt/torch/speculative/plugins/megatron_eagle.py (+96 −83 lines)
@@ -760,6 +760,7 @@ def _get_eagle_module_inputs(
         position_ids: torch.Tensor,
         features: torch.Tensor | None = None,
         ttt_step: int = 1,
+        parallel_draft_step: int = 1,
     ):
         """Getting EAGLE module inputs."""
         b = hidden_states.shape[1]
@@ -801,7 +802,7 @@ def _get_eagle_module_inputs(
             )

         for step in range(ttt_step):
-            for i in range(self.eagle_config.parallel_draft_step):
+            for i in range(parallel_draft_step):
                 eagle_inputs["input_ids"] = torch.cat(
                     (
                         eagle_inputs["input_ids"],
@@ -818,11 +819,7 @@ def _get_eagle_module_inputs(
                 )

                 if step > 0:
-                    feature = gathered_features[
-                        (step * self.eagle_config.parallel_draft_step - 1) * s : step
-                        * self.eagle_config.parallel_draft_step
-                        * s
-                    ]
+                    feature = gathered_features[-s:]
                 eagle_inputs["hidden_states"] = torch.cat(
                     (
                         eagle_inputs["hidden_states"],
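Reading the simplification (my interpretation, not stated in the commit): once each call drafts one additional step at a time, the most recently drafted features always sit at the tail of gathered_features, so the explicit offset arithmetic collapses to a tail slice. A toy check that the two indexings agree when the buffer holds exactly step * P chunks of length s (all values illustrative):

# Toy check: old offset arithmetic vs. new tail slice (illustrative sizes).
s, P, step = 4, 3, 2                                 # chunk length, draft steps, TTT step
old_slice = slice((step * P - 1) * s, step * P * s)  # removed indexing
new_slice = slice(-s, None)                          # this commit: always the tail
buf = list(range(step * P * s))                      # stand-in for gathered_features
assert buf[old_slice] == buf[new_slice]              # identical when the tail is the last chunk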
@@ -857,7 +854,7 @@ def _get_eagle_module_inputs(
             )

         eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
-            attn_mask, ttt_step * self.eagle_config.parallel_draft_step
+            attn_mask, (ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step
         )

         eagle_inputs["embedding"] = self.embedding(
@@ -1072,21 +1069,28 @@ def forward(
         # In calibration mode, we want to make sure all weights have been exercised.
         # This makes sure all quantized weights have amax calibrated
         if inference_params is None or self.calibration_mode:
-            eagle_inputs_0 = self._get_eagle_module_inputs(
-                input_ids=input_ids,
-                hidden_states=eagle_module_input_hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-            )
+            eagle_logits_0 = []
+            for i in range(self.eagle_config.parallel_draft_step):
+                eagle_inputs_0 = self._get_eagle_module_inputs(
+                    input_ids=input_ids,
+                    hidden_states=eagle_module_input_hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    ttt_step=1,
+                    parallel_draft_step=i + 1,
+                )

-            _, eagle_logits_0, eagle_hidden_states_0_pre_norm = self._eagle_forward(
-                eagle_inputs_0,
-                output_weight,
-                inference_params=inference_params,
-                packed_seq_params=packed_seq_params,
-                inference_context=eagle_inference_context,
-                **(extra_block_kwargs or {}),
-            )
+                _, eagle_logits_, eagle_hidden_states_0_pre_norm = self._eagle_forward(
+                    eagle_inputs_0,
+                    output_weight,
+                    inference_params=inference_params,
+                    packed_seq_params=packed_seq_params,
+                    inference_context=eagle_inference_context,
+                    **(extra_block_kwargs or {}),
+                )
+
+                eagle_logits_0.append(eagle_logits_[-input_ids.shape[1] :])
+            eagle_logits_0 = torch.cat(eagle_logits_0, dim=0)

         # If labels are not provided, return the original logits. We only return after
         # all eagle weights have been exercised for quantization calibration purpose.
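The collect-and-concat pattern above is repeated for all four rounds below: each call returns logits for everything drafted so far, the last input_ids.shape[1] positions are kept, and the per-draft slices are stacked along the sequence axis. A minimal toy sketch of the resulting layout (names and sizes are illustrative, not from the source):

import torch

s, b, v = 8, 2, 16          # toy sequence length, batch, vocab (illustrative)
P = 3                       # stands in for self.eagle_config.parallel_draft_step

collected = []
for i in range(P):
    # call i drafts i+1 steps, so its logits cover (i + 1) * s positions
    logits = torch.randn((i + 1) * s, b, v)
    collected.append(logits[-s:])        # keep only the newest draft's slice

stacked = torch.cat(collected, dim=0)    # shape: [P * s, b, v]
assert stacked.shape == (P * s, b, v)

# Downstream, draft i's logits are recovered as stacked[i * s : (i + 1) * s],
# which is exactly the slicing used in the loss and accuracy loops below.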
@@ -1109,9 +1113,8 @@ def forward(
             loss = self.compute_language_model_loss(labels, logits_sbh)
             loss = 0.0 * loss

-            eagle_logits_0 = eagle_logits_0[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                eagle_logits = eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                 loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
                 loss_ = loss_[:, i:]
                 loss[:, i + 1 :] += self.eagle_loss_decay_factor * loss_
@@ -1121,7 +1124,7 @@ def forward(
             with torch.no_grad():
                 for i in range(self.eagle_config.parallel_draft_step):
                     gathered_logits = gather_from_tensor_model_parallel_region(
-                        eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                        eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                     )
                     gathered_logits = gathered_logits[i:-1]
                     eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
@@ -1137,27 +1140,31 @@ def forward(
             )

             # Second round of EAGLE loss
-            eagle_inputs_1 = self._get_eagle_module_inputs(
-                input_ids=input_ids,
-                hidden_states=eagle_module_input_hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                features=eagle_hidden_states_0_pre_norm,
-                ttt_step=2,
-            )
+            eagle_logits_1 = []
+            for i in range(self.eagle_config.parallel_draft_step):
+                eagle_inputs_1 = self._get_eagle_module_inputs(
+                    input_ids=input_ids,
+                    hidden_states=eagle_module_input_hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    features=eagle_hidden_states_0_pre_norm,
+                    ttt_step=2,
+                    parallel_draft_step=i + 1,
+                )

-            _, eagle_logits_2x, eagle_hidden_states_2x_pre_norm = self._eagle_forward(
-                eagle_inputs_1,
-                output_weight,
-                inference_params=inference_params,
-                packed_seq_params=packed_seq_params,
-                inference_context=eagle_inference_context,
-                **(extra_block_kwargs or {}),
-            )
-            eagle_logits_1 = eagle_logits_2x[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
+                _, eagle_logits_, eagle_hidden_states_2x_pre_norm = self._eagle_forward(
+                    eagle_inputs_1,
+                    output_weight,
+                    inference_params=inference_params,
+                    packed_seq_params=packed_seq_params,
+                    inference_context=eagle_inference_context,
+                    **(extra_block_kwargs or {}),
+                )
+                eagle_logits_1.append(eagle_logits_[-input_ids.shape[1] :])
+            eagle_logits_1 = torch.cat(eagle_logits_1, dim=0)

             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_1[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                eagle_logits = eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                 loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
                 loss_ = loss_[:, i + 1 :]
                 loss[:, i + 2 :] += self.eagle_loss_decay_factor**2 * loss_
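Across the four rounds the loss bookkeeping follows one pattern: round t trims the per-draft loss to loss_[:, i + t - 1 :] and accumulates it at offset i + t with weight eagle_loss_decay_factor**t. A toy sketch of the shape arithmetic, assuming _compute_eagle_loss returns a token-shifted loss one position narrower than the labels (the decay value and sizes are illustrative):

import torch

b, seq = 2, 10            # toy batch size and sequence length (illustrative)
P, decay = 3, 0.8         # draft steps and eagle_loss_decay_factor (illustrative)
loss = torch.zeros(b, seq)

for t in range(1, 5):     # the four TTT rounds in this forward pass
    for i in range(P):
        # stands in for _compute_eagle_loss, whose token-shifted output is one
        # position shorter than the labels
        loss_ = torch.ones(b, seq - 1)
        loss_ = loss_[:, i + t - 1 :]          # drop positions this draft cannot predict yet
        loss[:, i + t :] += decay**t * loss_   # both sides are seq - i - t positions wide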
@@ -1167,7 +1174,7 @@ def forward(
             with torch.no_grad():
                 for i in range(self.eagle_config.parallel_draft_step):
                     gathered_logits = gather_from_tensor_model_parallel_region(
-                        eagle_logits_1[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                        eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                     )
                     gathered_logits = gathered_logits[i + 1 : -1]
                     eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
@@ -1183,28 +1190,31 @@ def forward(
             )

             # Third EAGLE loss
-            eagle_inputs_2 = self._get_eagle_module_inputs(
-                input_ids=input_ids,
-                hidden_states=eagle_module_input_hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                features=eagle_hidden_states_2x_pre_norm,
-                ttt_step=3,
-            )
-
-            _, eagle_logits_3x, eagle_hidden_states_3x_pre_norm = self._eagle_forward(
-                eagle_inputs_2,
-                output_weight,
-                inference_params=inference_params,
-                packed_seq_params=packed_seq_params,
-                inference_context=eagle_inference_context,
-                **(extra_block_kwargs or {}),
-            )
+            eagle_logits_2 = []
+            for i in range(self.eagle_config.parallel_draft_step):
+                eagle_inputs_2 = self._get_eagle_module_inputs(
+                    input_ids=input_ids,
+                    hidden_states=eagle_module_input_hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    features=eagle_hidden_states_2x_pre_norm,
+                    ttt_step=3,
+                    parallel_draft_step=i + 1,
+                )

-            eagle_logits_2 = eagle_logits_3x[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
+                _, eagle_logits_, eagle_hidden_states_3x_pre_norm = self._eagle_forward(
+                    eagle_inputs_2,
+                    output_weight,
+                    inference_params=inference_params,
+                    packed_seq_params=packed_seq_params,
+                    inference_context=eagle_inference_context,
+                    **(extra_block_kwargs or {}),
+                )
+                eagle_logits_2.append(eagle_logits_[-input_ids.shape[1] :])
+            eagle_logits_2 = torch.cat(eagle_logits_2, dim=0)

             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_2[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                eagle_logits = eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                 loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
                 loss_ = loss_[:, i + 2 :]
                 loss[:, i + 3 :] += self.eagle_loss_decay_factor**3 * loss_
@@ -1214,7 +1224,7 @@ def forward(
             with torch.no_grad():
                 for i in range(self.eagle_config.parallel_draft_step):
                     gathered_logits = gather_from_tensor_model_parallel_region(
-                        eagle_logits_2[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                        eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                     )
                     gathered_logits = gathered_logits[i + 2 : -1]
                     eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
@@ -1230,28 +1240,31 @@ def forward(
             )

             # Forth EAGLE loss
-            eagle_inputs_3 = self._get_eagle_module_inputs(
-                input_ids=input_ids,
-                hidden_states=eagle_module_input_hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                features=eagle_hidden_states_3x_pre_norm,
-                ttt_step=4,
-            )
-
-            _, eagle_logits_4x, eagle_hidden_states_4x_pre_norm = self._eagle_forward(
-                eagle_inputs_3,
-                output_weight,
-                inference_params=inference_params,
-                packed_seq_params=packed_seq_params,
-                inference_context=eagle_inference_context,
-                **(extra_block_kwargs or {}),
-            )
+            eagle_logits_3 = []
+            for i in range(self.eagle_config.parallel_draft_step):
+                eagle_inputs_3 = self._get_eagle_module_inputs(
+                    input_ids=input_ids,
+                    hidden_states=eagle_module_input_hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    features=eagle_hidden_states_3x_pre_norm,
+                    ttt_step=4,
+                    parallel_draft_step=i + 1,
+                )

-            eagle_logits_3 = eagle_logits_4x[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
+                _, eagle_logits_, eagle_hidden_states_4x_pre_norm = self._eagle_forward(
+                    eagle_inputs_3,
+                    output_weight,
+                    inference_params=inference_params,
+                    packed_seq_params=packed_seq_params,
+                    inference_context=eagle_inference_context,
+                    **(extra_block_kwargs or {}),
+                )
+                eagle_logits_3.append(eagle_logits_[-input_ids.shape[1] :])
+            eagle_logits_3 = torch.cat(eagle_logits_3, dim=0)

             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_3[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                eagle_logits = eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                 loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
                 loss_ = loss_[:, i + 3 :]
                 loss[:, i + 4 :] += self.eagle_loss_decay_factor**4 * loss_
@@ -1261,7 +1274,7 @@ def forward(
             with torch.no_grad():
                 for i in range(self.eagle_config.parallel_draft_step):
                     gathered_logits = gather_from_tensor_model_parallel_region(
-                        eagle_logits_3[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                        eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                     )
                     gathered_logits = gathered_logits[i + 3 : -1]
                     eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
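For reference, the accuracy bookkeeping shared by all four torch.no_grad() blocks reduces to slicing out each draft's logits and taking a top-1 argmax. A single-process sketch (the tensor-parallel gather is omitted, and all shapes and the round index are illustrative):

import torch

s, b, v = 8, 2, 16      # toy sequence length, batch, vocab (illustrative)
P, t = 3, 1             # parallel draft steps and TTT round (illustrative)
eagle_logits = torch.randn(P * s, b, v)   # concatenated per-draft logits, [P*s, b, v]

for i in range(P):
    gathered = eagle_logits[i * s : (i + 1) * s]    # draft i's slice
    gathered = gathered[i + t - 1 : -1]             # drop warm-up and final positions
    top1 = gathered.transpose(0, 1).argmax(dim=-1)  # [b, remaining] predicted token ids
    # the real code compares top1 against the shifted labels to report draft accuracy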
