@@ -760,6 +760,7 @@ def _get_eagle_module_inputs(
         position_ids: torch.Tensor,
         features: torch.Tensor | None = None,
         ttt_step: int = 1,
+        parallel_draft_step: int = 1,
     ):
         """Getting EAGLE module inputs."""
         b = hidden_states.shape[1]
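Note: the new `parallel_draft_step` argument selects how many parallel draft heads the inputs should cover for the TTT round in flight, while `self.eagle_config.parallel_draft_step` stays the per-round maximum. A hedged sketch of the call pattern this enables, mirroring the forward-pass changes later in this diff:

```python
# Illustrative call only (argument names taken from this diff): build inputs
# covering the first two draft heads of TTT round 2.
eagle_inputs = self._get_eagle_module_inputs(
    input_ids=input_ids,
    hidden_states=eagle_module_input_hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    features=eagle_hidden_states_0_pre_norm,
    ttt_step=2,
    parallel_draft_step=2,
)
```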
@@ -801,7 +802,7 @@ def _get_eagle_module_inputs(
         )

         for step in range(ttt_step):
-            for i in range(self.eagle_config.parallel_draft_step):
+            for i in range(parallel_draft_step):
                 eagle_inputs["input_ids"] = torch.cat(
                     (
                         eagle_inputs["input_ids"],
@@ -818,11 +819,7 @@ def _get_eagle_module_inputs(
                 )

             if step > 0:
-                feature = gathered_features[
-                    (step * self.eagle_config.parallel_draft_step - 1) * s : step
-                    * self.eagle_config.parallel_draft_step
-                    * s
-                ]
+                feature = gathered_features[-s:]
                 eagle_inputs["hidden_states"] = torch.cat(
                     (
                         eagle_inputs["hidden_states"],
@@ -857,7 +854,7 @@ def _get_eagle_module_inputs(
         )

         eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
-            attn_mask, ttt_step * self.eagle_config.parallel_draft_step
+            attn_mask, (ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step
        )

         eagle_inputs["embedding"] = self.embedding(
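The mask now spans only the draft positions materialized so far: all `self.eagle_config.parallel_draft_step` positions for each completed TTT round, plus `parallel_draft_step` for the round in flight. A minimal sketch of the arithmetic (`num_mask_steps` is a hypothetical helper, not part of this diff):

```python
def num_mask_steps(ttt_step: int, parallel_draft_step: int, config_parallel_draft_step: int) -> int:
    """Number of sequence-length blocks the multi-step attention mask must span."""
    return (ttt_step - 1) * config_parallel_draft_step + parallel_draft_step

# With config_parallel_draft_step = 4:
assert num_mask_steps(ttt_step=1, parallel_draft_step=1, config_parallel_draft_step=4) == 1
assert num_mask_steps(ttt_step=2, parallel_draft_step=3, config_parallel_draft_step=4) == 7
# The old call was equivalent to always covering the full final round:
assert num_mask_steps(2, 4, 4) == 2 * 4
```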
@@ -1072,21 +1069,28 @@ def forward(
         # In calibration mode, we want to make sure all weights have been exercised.
         # This makes sure all quantized weights have amax calibrated
         if inference_params is None or self.calibration_mode:
-            eagle_inputs_0 = self._get_eagle_module_inputs(
-                input_ids=input_ids,
-                hidden_states=eagle_module_input_hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-            )
+            eagle_logits_0 = []
+            for i in range(self.eagle_config.parallel_draft_step):
+                eagle_inputs_0 = self._get_eagle_module_inputs(
+                    input_ids=input_ids,
+                    hidden_states=eagle_module_input_hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    ttt_step=1,
+                    parallel_draft_step=i + 1,
+                )

-            _, eagle_logits_0, eagle_hidden_states_0_pre_norm = self._eagle_forward(
-                eagle_inputs_0,
-                output_weight,
-                inference_params=inference_params,
-                packed_seq_params=packed_seq_params,
-                inference_context=eagle_inference_context,
-                **(extra_block_kwargs or {}),
-            )
+                _, eagle_logits_, eagle_hidden_states_0_pre_norm = self._eagle_forward(
+                    eagle_inputs_0,
+                    output_weight,
+                    inference_params=inference_params,
+                    packed_seq_params=packed_seq_params,
+                    inference_context=eagle_inference_context,
+                    **(extra_block_kwargs or {}),
+                )
+
+                eagle_logits_0.append(eagle_logits_[-input_ids.shape[1] :])
+            eagle_logits_0 = torch.cat(eagle_logits_0, dim=0)

         # If labels are not provided, return the original logits. We only return after
         # all eagle weights have been exercised for quantization calibration purpose.
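Each draft head now gets its own EAGLE forward pass; only the trailing `input_ids.shape[1]` positions of each pass (the newest draft block) are kept, and the blocks are concatenated along the sequence axis. A self-contained sketch of that collection pattern, with a random stand-in for `self._eagle_forward` (shapes assumed sequence-first `[s, b, v]`, as in Megatron):

```python
import torch

s, b, v, num_heads = 5, 2, 11, 3  # assumed toy sizes

collected = []
for i in range(num_heads):
    # Stand-in for self._eagle_forward(...): pass i returns logits for the
    # base block plus i + 1 draft blocks, each of length s (assumed layout).
    eagle_logits_ = torch.randn((i + 2) * s, b, v)
    collected.append(eagle_logits_[-s:])  # keep only the newest block

eagle_logits_0 = torch.cat(collected, dim=0)
assert eagle_logits_0.shape == (num_heads * s, b, v)
```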
@@ -1109,9 +1113,8 @@ def forward(
             loss = self.compute_language_model_loss(labels, logits_sbh)
             loss = 0.0 * loss

-            eagle_logits_0 = eagle_logits_0[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                eagle_logits = eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                 loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
                 loss_ = loss_[:, i:]
                 loss[:, i + 1 :] += self.eagle_loss_decay_factor * loss_
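Since each pass now contributes exactly one sequence-length block, the explicit trim to the last `labels.shape[1] * parallel_draft_step` rows is gone, and blocks are indexed by `input_ids.shape[1]` (this assumes `labels` and `input_ids` share the same sequence length during training). The offset/decay bookkeeping is unchanged; a small shape-checked sketch of the round-1 accumulation (round t shifts the slices by t - 1 and scales by `eagle_loss_decay_factor**t`):

```python
import torch

b, seq, num_heads, decay = 2, 8, 2, 0.9  # assumed toy values

loss = torch.zeros(b, seq)
for i in range(num_heads):
    loss_ = torch.ones(b, seq - 1)     # stand-in for _compute_eagle_loss: one value per shifted target
    loss_ = loss_[:, i:]               # head i predicts token t + 1 + i; its first i columns have no target
    loss[:, i + 1 :] += decay * loss_  # both sides have seq - 1 - i columns
```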
@@ -1121,7 +1124,7 @@ def forward(
            with torch.no_grad():
                for i in range(self.eagle_config.parallel_draft_step):
                    gathered_logits = gather_from_tensor_model_parallel_region(
-                        eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                        eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                    )
                    gathered_logits = gathered_logits[i:-1]
                    eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
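`gather_from_tensor_model_parallel_region` (Megatron-Core) assembles the vocab-sharded logits along the last dimension before the argmax, so top-1 accuracy is measured over the full vocabulary. A non-parallel sketch of the per-head accuracy bookkeeping (the helper and the comparison slice are illustrative; labels layout assumed `[b, s]`):

```python
import torch

def draft_top1_match(eagle_logits: torch.Tensor, labels: torch.Tensor, i: int) -> torch.Tensor:
    """eagle_logits: [s, b, v] block for head i; labels: [b, s]."""
    gathered = eagle_logits[i:-1]                   # keep positions with a valid target
    top1 = gathered.transpose(0, 1).argmax(dim=-1)  # [b, s - 1 - i]
    return (top1 == labels[:, i + 1 :]).float().mean()
```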
@@ -1137,27 +1140,31 @@ def forward(
             )

             # Second round of EAGLE loss
-            eagle_inputs_1 = self._get_eagle_module_inputs(
-                input_ids=input_ids,
-                hidden_states=eagle_module_input_hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                features=eagle_hidden_states_0_pre_norm,
-                ttt_step=2,
-            )
+            eagle_logits_1 = []
+            for i in range(self.eagle_config.parallel_draft_step):
+                eagle_inputs_1 = self._get_eagle_module_inputs(
+                    input_ids=input_ids,
+                    hidden_states=eagle_module_input_hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    features=eagle_hidden_states_0_pre_norm,
+                    ttt_step=2,
+                    parallel_draft_step=i + 1,
+                )

-            _, eagle_logits_2x, eagle_hidden_states_2x_pre_norm = self._eagle_forward(
-                eagle_inputs_1,
-                output_weight,
-                inference_params=inference_params,
-                packed_seq_params=packed_seq_params,
-                inference_context=eagle_inference_context,
-                **(extra_block_kwargs or {}),
-            )
-            eagle_logits_1 = eagle_logits_2x[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
+                _, eagle_logits_, eagle_hidden_states_2x_pre_norm = self._eagle_forward(
+                    eagle_inputs_1,
+                    output_weight,
+                    inference_params=inference_params,
+                    packed_seq_params=packed_seq_params,
+                    inference_context=eagle_inference_context,
+                    **(extra_block_kwargs or {}),
+                )
+                eagle_logits_1.append(eagle_logits_[-input_ids.shape[1] :])
+            eagle_logits_1 = torch.cat(eagle_logits_1, dim=0)

             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_1[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                eagle_logits = eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                 loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
                 loss_ = loss_[:, i + 1 :]
                 loss[:, i + 2 :] += self.eagle_loss_decay_factor**2 * loss_
@@ -1167,7 +1174,7 @@ def forward(
            with torch.no_grad():
                for i in range(self.eagle_config.parallel_draft_step):
                    gathered_logits = gather_from_tensor_model_parallel_region(
-                        eagle_logits_1[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                        eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                    )
                    gathered_logits = gathered_logits[i + 1 : -1]
                    eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
@@ -1183,28 +1190,31 @@ def forward(
             )

             # Third EAGLE loss
-            eagle_inputs_2 = self._get_eagle_module_inputs(
-                input_ids=input_ids,
-                hidden_states=eagle_module_input_hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                features=eagle_hidden_states_2x_pre_norm,
-                ttt_step=3,
-            )
-
-            _, eagle_logits_3x, eagle_hidden_states_3x_pre_norm = self._eagle_forward(
-                eagle_inputs_2,
-                output_weight,
-                inference_params=inference_params,
-                packed_seq_params=packed_seq_params,
-                inference_context=eagle_inference_context,
-                **(extra_block_kwargs or {}),
-            )
+            eagle_logits_2 = []
+            for i in range(self.eagle_config.parallel_draft_step):
+                eagle_inputs_2 = self._get_eagle_module_inputs(
+                    input_ids=input_ids,
+                    hidden_states=eagle_module_input_hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    features=eagle_hidden_states_2x_pre_norm,
+                    ttt_step=3,
+                    parallel_draft_step=i + 1,
+                )

-            eagle_logits_2 = eagle_logits_3x[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
+                _, eagle_logits_, eagle_hidden_states_3x_pre_norm = self._eagle_forward(
+                    eagle_inputs_2,
+                    output_weight,
+                    inference_params=inference_params,
+                    packed_seq_params=packed_seq_params,
+                    inference_context=eagle_inference_context,
+                    **(extra_block_kwargs or {}),
+                )
+                eagle_logits_2.append(eagle_logits_[-input_ids.shape[1] :])
+            eagle_logits_2 = torch.cat(eagle_logits_2, dim=0)

             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_2[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                eagle_logits = eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                 loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
                 loss_ = loss_[:, i + 2 :]
                 loss[:, i + 3 :] += self.eagle_loss_decay_factor**3 * loss_
@@ -1214,7 +1224,7 @@ def forward(
            with torch.no_grad():
                for i in range(self.eagle_config.parallel_draft_step):
                    gathered_logits = gather_from_tensor_model_parallel_region(
-                        eagle_logits_2[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                        eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                    )
                    gathered_logits = gathered_logits[i + 2 : -1]
                    eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
@@ -1230,28 +1240,31 @@ def forward(
             )

             # Forth EAGLE loss
-            eagle_inputs_3 = self._get_eagle_module_inputs(
-                input_ids=input_ids,
-                hidden_states=eagle_module_input_hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                features=eagle_hidden_states_3x_pre_norm,
-                ttt_step=4,
-            )
-
-            _, eagle_logits_4x, eagle_hidden_states_4x_pre_norm = self._eagle_forward(
-                eagle_inputs_3,
-                output_weight,
-                inference_params=inference_params,
-                packed_seq_params=packed_seq_params,
-                inference_context=eagle_inference_context,
-                **(extra_block_kwargs or {}),
-            )
+            eagle_logits_3 = []
+            for i in range(self.eagle_config.parallel_draft_step):
+                eagle_inputs_3 = self._get_eagle_module_inputs(
+                    input_ids=input_ids,
+                    hidden_states=eagle_module_input_hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    features=eagle_hidden_states_3x_pre_norm,
+                    ttt_step=4,
+                    parallel_draft_step=i + 1,
+                )

-            eagle_logits_3 = eagle_logits_4x[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
+                _, eagle_logits_, eagle_hidden_states_4x_pre_norm = self._eagle_forward(
+                    eagle_inputs_3,
+                    output_weight,
+                    inference_params=inference_params,
+                    packed_seq_params=packed_seq_params,
+                    inference_context=eagle_inference_context,
+                    **(extra_block_kwargs or {}),
+                )
+                eagle_logits_3.append(eagle_logits_[-input_ids.shape[1] :])
+            eagle_logits_3 = torch.cat(eagle_logits_3, dim=0)

             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_3[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                eagle_logits = eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                 loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
                 loss_ = loss_[:, i + 3 :]
                 loss[:, i + 4 :] += self.eagle_loss_decay_factor**4 * loss_
@@ -1261,7 +1274,7 @@ def forward(
            with torch.no_grad():
                for i in range(self.eagle_config.parallel_draft_step):
                    gathered_logits = gather_from_tensor_model_parallel_region(
-                        eagle_logits_3[i * labels.shape[1] : (i + 1) * labels.shape[1]]
+                        eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
                    )
                    gathered_logits = gathered_logits[i + 3 : -1]
                    eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)