@@ -760,6 +760,7 @@ def _get_eagle_module_inputs(
         position_ids: torch.Tensor,
         features: torch.Tensor | None = None,
         ttt_step: int = 1,
+        parallel_draft_step: int = 1,
     ):
         """Getting EAGLE module inputs."""
         b = hidden_states.shape[1]
@@ -801,7 +802,7 @@ def _get_eagle_module_inputs(
         )

         for step in range(ttt_step):
-            for i in range(self.eagle_config.parallel_draft_step):
+            for i in range(parallel_draft_step):
                 eagle_inputs["input_ids"] = torch.cat(
                     (
                         eagle_inputs["input_ids"],
@@ -818,11 +819,7 @@ def _get_eagle_module_inputs(
                 )

                 if step > 0:
-                    feature = gathered_features[
-                        (step * self.eagle_config.parallel_draft_step - 1) * s : step
-                        * self.eagle_config.parallel_draft_step
-                        * s
-                    ]
+                    feature = gathered_features[-s:]
                 eagle_inputs["hidden_states"] = torch.cat(
                     (
                         eagle_inputs["hidden_states"],
@@ -857,7 +854,7 @@ def _get_eagle_module_inputs(
                 )

         eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
-            attn_mask, ttt_step * self.eagle_config.parallel_draft_step
+            attn_mask, (ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step
         )

         eagle_inputs["embedding"] = self.embedding(
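Reviewer note, not part of the diff: the second argument to `set_multi_step_attention_mask` now counts only the draft positions produced so far, instead of always covering every draft of every TTT step. A minimal sketch of the two length formulas, with a hypothetical `k` standing in for `self.eagle_config.parallel_draft_step`:

```python
# Illustration only: compare the old and new mask-length arguments.
k = 2  # stand-in for self.eagle_config.parallel_draft_step

def old_mask_len(ttt_step: int) -> int:
    # Before: the mask always spanned every parallel draft of every TTT step.
    return ttt_step * k

def new_mask_len(ttt_step: int, parallel_draft_step: int) -> int:
    # After: each per-draft call only spans the drafts generated so far.
    return (ttt_step - 1) * k + parallel_draft_step

for ttt_step in (1, 2):
    for i in range(k):
        print(ttt_step, i + 1, old_mask_len(ttt_step), new_mask_len(ttt_step, i + 1))
```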
@@ -1072,21 +1069,28 @@ def forward(
         # In calibration mode, we want to make sure all weights have been exercised.
         # This makes sure all quantized weights have amax calibrated
         if inference_params is None or self.calibration_mode:
-            eagle_inputs_0 = self._get_eagle_module_inputs(
-                input_ids=input_ids,
-                hidden_states=eagle_module_input_hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-            )
+            eagle_logits_0 = []
+            for i in range(self.eagle_config.parallel_draft_step):
+                eagle_inputs_0 = self._get_eagle_module_inputs(
+                    input_ids=input_ids,
+                    hidden_states=eagle_module_input_hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    ttt_step=1,
+                    parallel_draft_step=i + 1,
+                )

-            _, eagle_logits_0, eagle_hidden_states_0_pre_norm = self._eagle_forward(
-                eagle_inputs_0,
-                output_weight,
-                inference_params=inference_params,
-                packed_seq_params=packed_seq_params,
-                inference_context=eagle_inference_context,
-                **(extra_block_kwargs or {}),
-            )
+                _, eagle_logits_, eagle_hidden_states_0_pre_norm = self._eagle_forward(
+                    eagle_inputs_0,
+                    output_weight,
+                    inference_params=inference_params,
+                    packed_seq_params=packed_seq_params,
+                    inference_context=eagle_inference_context,
+                    **(extra_block_kwargs or {}),
+                )
+
+                eagle_logits_0.append(eagle_logits_[-input_ids.shape[1] :])
+            eagle_logits_0 = torch.cat(eagle_logits_0, dim=0)

         # If labels are not provided, return the original logits. We only return after
         # all eagle weights have been exercised for quantization calibration purpose.
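For context, a minimal shape sketch (assumed `[positions, batch, vocab]` layout and made-up sizes, not the module's real tensors) of what the new collection loop does: each per-draft `_eagle_forward` call returns logits whose last `input_ids.shape[1]` rows belong to the draft currently being exercised, and those slices are concatenated along dim 0:

```python
import torch

# Hypothetical sizes for illustration.
s, b, vocab, parallel_draft_step = 4, 2, 11, 3

eagle_logits_0 = []
for i in range(parallel_draft_step):
    num_positions = (i + 1) * s                           # grows with parallel_draft_step = i + 1
    eagle_logits_ = torch.randn(num_positions, b, vocab)  # stand-in for the _eagle_forward output
    eagle_logits_0.append(eagle_logits_[-s:])             # keep only the newest draft's logits
eagle_logits_0 = torch.cat(eagle_logits_0, dim=0)
assert eagle_logits_0.shape == (parallel_draft_step * s, b, vocab)
```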
@@ -1109,9 +1113,8 @@ def forward(
         loss = self.compute_language_model_loss(labels, logits_sbh)
         loss = 0.0 * loss

-        eagle_logits_0 = eagle_logits_0[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
         for i in range(self.eagle_config.parallel_draft_step):
-            eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+            eagle_logits = eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
             loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
             loss_ = loss_[:, i:]
             loss[:, i + 1 :] += self.eagle_loss_decay_factor * loss_
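A small sketch of the accumulation pattern this loop implements (shapes assumed; `_compute_eagle_loss` is stood in by a random tensor with one value per next-token position): draft head `i` predicts the token `i + 1` positions ahead, so its loss is shifted by `i` and scaled by `eagle_loss_decay_factor` before being added:

```python
import torch

# Hypothetical sizes; loss_ stands in for _compute_eagle_loss(logits_sbh, labels, eagle_logits).
b, seq_len, parallel_draft_step = 2, 8, 3
eagle_loss_decay_factor = 0.9

loss = torch.zeros(b, seq_len)
for i in range(parallel_draft_step):
    loss_ = torch.rand(b, seq_len - 1)   # one loss value per next-token position (assumed shape)
    loss_ = loss_[:, i:]                 # drop the first i positions this head cannot predict yet
    loss[:, i + 1 :] += eagle_loss_decay_factor * loss_
```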
@@ -1121,7 +1124,7 @@ def forward(
             with torch.no_grad():
                 for i in range(self.eagle_config.parallel_draft_step):
                     gathered_logits = gather_from_tensor_model_parallel_region(
-                        eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                        eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                     )
                     gathered_logits = gathered_logits[i:-1]
                     eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
@@ -1137,27 +1140,31 @@ def forward(
                 )

         # Second round of EAGLE loss
-        eagle_inputs_1 = self._get_eagle_module_inputs(
-            input_ids=input_ids,
-            hidden_states=eagle_module_input_hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            features=eagle_hidden_states_0_pre_norm,
-            ttt_step=2,
-        )
+        eagle_logits_1 = []
+        for i in range(self.eagle_config.parallel_draft_step):
+            eagle_inputs_1 = self._get_eagle_module_inputs(
+                input_ids=input_ids,
+                hidden_states=eagle_module_input_hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                features=eagle_hidden_states_0_pre_norm,
+                ttt_step=2,
+                parallel_draft_step=i + 1,
+            )

-        _, eagle_logits_2x, eagle_hidden_states_2x_pre_norm = self._eagle_forward(
-            eagle_inputs_1,
-            output_weight,
-            inference_params=inference_params,
-            packed_seq_params=packed_seq_params,
-            inference_context=eagle_inference_context,
-            **(extra_block_kwargs or {}),
-        )
-        eagle_logits_1 = eagle_logits_2x[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
+            _, eagle_logits_, eagle_hidden_states_2x_pre_norm = self._eagle_forward(
+                eagle_inputs_1,
+                output_weight,
+                inference_params=inference_params,
+                packed_seq_params=packed_seq_params,
+                inference_context=eagle_inference_context,
+                **(extra_block_kwargs or {}),
+            )
+            eagle_logits_1.append(eagle_logits_[-input_ids.shape[1] :])
+        eagle_logits_1 = torch.cat(eagle_logits_1, dim=0)

         for i in range(self.eagle_config.parallel_draft_step):
-            eagle_logits = eagle_logits_1[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+            eagle_logits = eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
             loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
             loss_ = loss_[:, i + 1 :]
             loss[:, i + 2 :] += self.eagle_loss_decay_factor**2 * loss_
@@ -1167,7 +1174,7 @@ def forward(
             with torch.no_grad():
                 for i in range(self.eagle_config.parallel_draft_step):
                     gathered_logits = gather_from_tensor_model_parallel_region(
-                        eagle_logits_1[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                        eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                     )
                     gathered_logits = gathered_logits[i + 1 : -1]
                     eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
@@ -1183,28 +1190,31 @@ def forward(
                 )

         # Third EAGLE loss
-        eagle_inputs_2 = self._get_eagle_module_inputs(
-            input_ids=input_ids,
-            hidden_states=eagle_module_input_hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            features=eagle_hidden_states_2x_pre_norm,
-            ttt_step=3,
-        )
-
-        _, eagle_logits_3x, eagle_hidden_states_3x_pre_norm = self._eagle_forward(
-            eagle_inputs_2,
-            output_weight,
-            inference_params=inference_params,
-            packed_seq_params=packed_seq_params,
-            inference_context=eagle_inference_context,
-            **(extra_block_kwargs or {}),
-        )
+        eagle_logits_2 = []
+        for i in range(self.eagle_config.parallel_draft_step):
+            eagle_inputs_2 = self._get_eagle_module_inputs(
+                input_ids=input_ids,
+                hidden_states=eagle_module_input_hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                features=eagle_hidden_states_2x_pre_norm,
+                ttt_step=3,
+                parallel_draft_step=i + 1,
+            )

-        eagle_logits_2 = eagle_logits_3x[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
+            _, eagle_logits_, eagle_hidden_states_3x_pre_norm = self._eagle_forward(
+                eagle_inputs_2,
+                output_weight,
+                inference_params=inference_params,
+                packed_seq_params=packed_seq_params,
+                inference_context=eagle_inference_context,
+                **(extra_block_kwargs or {}),
+            )
+            eagle_logits_2.append(eagle_logits_[-input_ids.shape[1] :])
+        eagle_logits_2 = torch.cat(eagle_logits_2, dim=0)

         for i in range(self.eagle_config.parallel_draft_step):
-            eagle_logits = eagle_logits_2[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+            eagle_logits = eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
             loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
             loss_ = loss_[:, i + 2 :]
             loss[:, i + 3 :] += self.eagle_loss_decay_factor**3 * loss_
@@ -1214,7 +1224,7 @@ def forward(
             with torch.no_grad():
                 for i in range(self.eagle_config.parallel_draft_step):
                     gathered_logits = gather_from_tensor_model_parallel_region(
-                        eagle_logits_2[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                        eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                     )
                     gathered_logits = gathered_logits[i + 2 : -1]
                     eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
@@ -1230,28 +1240,31 @@ def forward(
                 )

         # Forth EAGLE loss
-        eagle_inputs_3 = self._get_eagle_module_inputs(
-            input_ids=input_ids,
-            hidden_states=eagle_module_input_hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            features=eagle_hidden_states_3x_pre_norm,
-            ttt_step=4,
-        )
-
-        _, eagle_logits_4x, eagle_hidden_states_4x_pre_norm = self._eagle_forward(
-            eagle_inputs_3,
-            output_weight,
-            inference_params=inference_params,
-            packed_seq_params=packed_seq_params,
-            inference_context=eagle_inference_context,
-            **(extra_block_kwargs or {}),
-        )
+        eagle_logits_3 = []
+        for i in range(self.eagle_config.parallel_draft_step):
+            eagle_inputs_3 = self._get_eagle_module_inputs(
+                input_ids=input_ids,
+                hidden_states=eagle_module_input_hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                features=eagle_hidden_states_3x_pre_norm,
+                ttt_step=4,
+                parallel_draft_step=i + 1,
+            )

-        eagle_logits_3 = eagle_logits_4x[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
+            _, eagle_logits_, eagle_hidden_states_4x_pre_norm = self._eagle_forward(
+                eagle_inputs_3,
+                output_weight,
+                inference_params=inference_params,
+                packed_seq_params=packed_seq_params,
+                inference_context=eagle_inference_context,
+                **(extra_block_kwargs or {}),
+            )
+            eagle_logits_3.append(eagle_logits_[-input_ids.shape[1] :])
+        eagle_logits_3 = torch.cat(eagle_logits_3, dim=0)

         for i in range(self.eagle_config.parallel_draft_step):
-            eagle_logits = eagle_logits_3[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+            eagle_logits = eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
             loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
             loss_ = loss_[:, i + 3 :]
             loss[:, i + 4 :] += self.eagle_loss_decay_factor**4 * loss_
@@ -1261,7 +1274,7 @@ def forward(
             with torch.no_grad():
                 for i in range(self.eagle_config.parallel_draft_step):
                     gathered_logits = gather_from_tensor_model_parallel_region(
-                        eagle_logits_3[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                        eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                    )
                     gathered_logits = gathered_logits[i + 3 : -1]
                     eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)