@@ -197,6 +197,7 @@ def set_multi_step_attention_mask(attn_mask, step):
 
     ttt_step=2
     parallel_draft_step=2
+    ->step=3
 
     | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- |
     (out) | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | -- -- G2 G3 G4 G5 G6 G7 | -- -- G2 G3 G4 G5 G6 G7 |
@@ -239,13 +240,12 @@ def set_multi_step_attention_mask(attn_mask, step):
     =======================================================================================================================
     """  # noqa: E501
     s = attn_mask.shape[-1]
-    for step_idx in range(2, step + 1):
-        # step_idx starts from 2nd step
+    for step_idx in range(step):
         mask_0 = attn_mask.clone().detach()
-        mask_0[:, :, step_idx - 2, :] = True
+        mask_0[:, :, step_idx, :] = True
         mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
         mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(step_idx - 1, s - 1):
+        for i in range(step_idx + 1, s - 1):
             mask_1[:, :, i, i] = False
 
         attn_mask = torch.cat((mask_0, mask_1), dim=-1)
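
The mask loop above is re-indexed from 1-based to 0-based. A minimal sketch (hypothetical helper names, not part of the change) confirming that the per-iteration row and diagonal offsets are unchanged when the caller passes a step value one smaller than before:

```python
# Hypothetical helpers, for illustration only: the (row, diagonal-start)
# offsets produced by the old and new loop headers.
def old_offsets(step):
    return [(step_idx - 2, step_idx - 1) for step_idx in range(2, step + 1)]

def new_offsets(step):
    return [(step_idx, step_idx + 1) for step_idx in range(step)]

# New callers pass a step value one smaller than before (see the caller change
# further down in this diff), so the offsets line up exactly.
assert old_offsets(4) == new_offsets(3) == [(0, 1), (1, 2), (2, 3)]
```
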
@@ -759,8 +759,8 @@ def _get_eagle_module_inputs(
         attention_mask: torch.Tensor,
         position_ids: torch.Tensor,
         features: torch.Tensor | None = None,
-        ttt_step: int = 1,
-        parallel_draft_step: int = 1,
+        ttt_step: int = 0,
+        parallel_draft_step: int = 0,
     ):
         """Getting EAGLE module inputs."""
         b = hidden_states.shape[1]
@@ -784,10 +784,10 @@ def _get_eagle_module_inputs(
 
         eagle_inputs["input_ids"] = (
             padded_input_ids
-            if parallel_draft_step == 1
+            if parallel_draft_step == 0
             else torch.full(
                 padded_input_ids.shape,
-                getattr(self, f"mask_token_{parallel_draft_step - 2}"),
+                getattr(self, f"mask_token_{parallel_draft_step - 1}"),
                 device=padded_input_ids.device,
                 dtype=padded_input_ids.dtype,
             )
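
With `parallel_draft_step` now 0-based, position 0 keeps the real input ids and later positions resolve to the same `mask_token_*` attributes as before. A small illustrative sketch (hypothetical loop, example only):

```python
# parallel_draft_step 0 -> padded_input_ids, 1 -> mask_token_0, 2 -> mask_token_1, ...
for i in range(3):
    source = "padded_input_ids" if i == 0 else f"mask_token_{i - 1}"
    print(f"parallel_draft_step={i} -> {source}")
```
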
@@ -805,7 +805,7 @@ def _get_eagle_module_inputs(
             feature = gathered_features[-s:]
         eagle_inputs["hidden_states"] = (
             gathered_hidden_states
-            if ttt_step == 1
+            if ttt_step == 0
             else torch.cat(
                 (
                     torch.zeros(
@@ -824,12 +824,12 @@ def _get_eagle_module_inputs(
         )
 
         eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
-            attn_mask, (ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step
+            attn_mask, ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_step
         )
 
         eagle_inputs["rotary_pos_emb"] = torch.cat(
             [rotary_pos_emb]
-            * ((ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step),
+            * (ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_step + 1),
             dim=0,
         )
 
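
Both the attention-mask step and the number of repeated `rotary_pos_emb` copies are preserved under the new 0-based convention. A quick arithmetic check (example value for the parallel draft count, hypothetical helper names):

```python
P = 2  # example value standing in for eagle_config.parallel_draft_step

def old_step(ttt_step, parallel_draft_step):  # 1-based arguments (old code)
    return (ttt_step - 1) * P + parallel_draft_step

def new_step(ttt_step, parallel_draft_step):  # 0-based arguments (new code)
    return ttt_step * P + parallel_draft_step

for t in range(1, 5):
    for p in range(1, P + 1):
        # The mask step shrinks by exactly one, matching the 0-based mask loop.
        assert new_step(t - 1, p - 1) == old_step(t, p) - 1
        # The rotary_pos_emb repetition count (new step + 1) is unchanged.
        assert new_step(t - 1, p - 1) + 1 == old_step(t, p)
```
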
@@ -970,6 +970,7 @@ def forward(
         packed_seq_params: PackedSeqParams = None,
         extra_block_kwargs: dict | None = None,
         return_eagle_inputs: bool = False,
+        ttt_steps=4,
         **kwargs,
     ) -> torch.Tensor:
         if position_ids is None or attention_mask is None:
@@ -1013,7 +1014,8 @@ def forward(
 
         # EAGLE kv cache
         eagle_inference_context = StaticInferenceContext(
-            input_ids.shape[0], input_ids.shape[1] * self.eagle_config.parallel_draft_step * 4
+            input_ids.shape[0],
+            input_ids.shape[1] * self.eagle_config.parallel_draft_step * ttt_steps,
         )
 
         if self.eagle_offline:
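
The EAGLE KV-cache length now scales with the new `ttt_steps` argument rather than a hard-coded 4; with the default `ttt_steps=4` the allocated length is unchanged. A sketch with example values only:

```python
seq_len, parallel_draft_step, ttt_steps = 1024, 2, 4  # example values
old_len = seq_len * parallel_draft_step * 4
new_len = seq_len * parallel_draft_step * ttt_steps
assert old_len == new_len == 8192
```
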
@@ -1043,228 +1045,88 @@ def forward(
             hidden_states, apply_fc=True
         )
 
-        # In calibration mode, we want to make sure all weights have been exercised.
-        # This makes sure all quantized weights have amax calibrated
-        if inference_params is None or self.calibration_mode:
-            eagle_logits_0 = []
+        if labels is not None:
+            if labels.shape[1] == input_ids.shape[1] - 1:
+                # For offline training, labels may be 1 token shorter than input_ids.
+                # We will just pad a 0 to the labels to make the seq_len the same as
+                # input_ids. This will introduce a small error in training if logit_distillation
+                # is False, and testing accuracy is wrong for the last token.
+                right_token_pad = torch.zeros(
+                    (labels.shape[0], 1),
+                    dtype=labels.dtype,
+                    device=labels.device,
+                )
+                labels = torch.cat((labels, right_token_pad), dim=-1)
+
+            # If eagle_freeze_base_model is set to True,
+            # the base model is frozen.
+            loss = self.compute_language_model_loss(labels, logits_sbh)
+            if self.eagle_freeze_base_model:
+                loss = 0.0 * loss
+
+        eagle_hidden_states_pre_norm = None
+        for ttt_step in range(ttt_steps):
+            eagle_logits = []
             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_inputs_0 = self._get_eagle_module_inputs(
+                eagle_inputs = self._get_eagle_module_inputs(
                     input_ids=input_ids,
                     hidden_states=eagle_module_input_hidden_states,
                     attention_mask=attention_mask,
                     position_ids=position_ids,
-                    ttt_step=1,
-                    parallel_draft_step=i + 1,
+                    features=eagle_hidden_states_pre_norm,
+                    ttt_step=ttt_step,
+                    parallel_draft_step=i,
                 )
 
-                _, eagle_logits_, eagle_hidden_states_0_pre_norm = self._eagle_forward(
-                    eagle_inputs_0,
+                _, eagle_logits_, eagle_hidden_states_pre_norm_ = self._eagle_forward(
+                    eagle_inputs,
                     output_weight,
                     inference_params=inference_params,
                     packed_seq_params=packed_seq_params,
                     inference_context=eagle_inference_context,
                     **(extra_block_kwargs or {}),
                 )
 
-                eagle_logits_0.append(eagle_logits_)
-            eagle_logits_0 = torch.cat(eagle_logits_0, dim=0)
-
-            # If labels are not provided, return the original logits. We only return after
-            # all eagle weights have been exercised for quantization calibration purpose.
-            if labels is None:
-                return logits_sbh.transpose(0, 1).contiguous()
-            elif labels.shape[1] == input_ids.shape[1] - 1:
-                # For offline training, labels may be 1 token shorter than input_ids.
-                # We will just pad a 0 to the labels to make the seq_len the same as
-                # input_ids. This will introduce a small error in training if logit_distillation
-                # is False, and testing accuracy is wrong for the last token.
-                right_token_pad = torch.zeros(
-                    (labels.shape[0], 1),
-                    dtype=labels.dtype,
-                    device=labels.device,
-                )
-                labels = torch.cat((labels, right_token_pad), dim=-1)
-
-            # If eagle_freeze_base_model is set to True,
-            # the base model is frozen.
-            loss = self.compute_language_model_loss(labels, logits_sbh)
-            loss = 0.0 * loss
-
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
-                loss_ = loss_[:, i:]
-                loss[:, i + 1 :] += self.eagle_loss_decay_factor * loss_
-
-            if self.eagle_report_acc and not self.training:
-                acc = []
-                with torch.no_grad():
-                    for i in range(self.eagle_config.parallel_draft_step):
-                        gathered_logits = gather_from_tensor_model_parallel_region(
-                            eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                        )
-                        gathered_logits = gathered_logits[i:-1]
-                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                        top1_p = torch.eq(labels[:, i + 1 :], eagle_top1).sum() / eagle_top1.numel()
-                        acc.append(top1_p)
-
-                if get_tensor_model_parallel_rank() == 0:
-                    print(
-                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 1st Top-1: {acc}",
-                        flush=True,
-                    )
-
-            # Second round of EAGLE loss
-            eagle_logits_1 = []
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_inputs_1 = self._get_eagle_module_inputs(
-                    input_ids=input_ids,
-                    hidden_states=eagle_module_input_hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    features=eagle_hidden_states_0_pre_norm,
-                    ttt_step=2,
-                    parallel_draft_step=i + 1,
-                )
+                eagle_logits.append(eagle_logits_)
+            eagle_logits = torch.cat(eagle_logits, dim=0)
+            eagle_hidden_states_pre_norm = eagle_hidden_states_pre_norm_
 
-                _, eagle_logits_, eagle_hidden_states_2x_pre_norm = self._eagle_forward(
-                    eagle_inputs_1,
-                    output_weight,
-                    inference_params=inference_params,
-                    packed_seq_params=packed_seq_params,
-                    inference_context=eagle_inference_context,
-                    **(extra_block_kwargs or {}),
-                )
-                eagle_logits_1.append(eagle_logits_)
-            eagle_logits_1 = torch.cat(eagle_logits_1, dim=0)
-
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
-                loss_ = loss_[:, i + 1 :]
-                loss[:, i + 2 :] += self.eagle_loss_decay_factor**2 * loss_
-
-            if self.eagle_report_acc and not self.training:
-                acc = []
-                with torch.no_grad():
-                    for i in range(self.eagle_config.parallel_draft_step):
-                        gathered_logits = gather_from_tensor_model_parallel_region(
-                            eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                        )
-                        gathered_logits = gathered_logits[i + 1 : -1]
-                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                        top1_p = torch.eq(labels[:, i + 2 :], eagle_top1).sum() / eagle_top1.numel()
-                        acc.append(top1_p)
-
-                if get_tensor_model_parallel_rank() == 0:
-                    print(
-                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 2nd Top-1: {acc}",
-                        flush=True,
-                    )
+            # If labels are not provided, return the original logits. We only return after
+            # all eagle weights have been exercised for quantization calibration purpose.
+            if labels is None:
+                return logits_sbh.transpose(0, 1).contiguous()
 
-            # Third EAGLE loss
-            eagle_logits_2 = []
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_inputs_2 = self._get_eagle_module_inputs(
-                    input_ids=input_ids,
-                    hidden_states=eagle_module_input_hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    features=eagle_hidden_states_2x_pre_norm,
-                    ttt_step=3,
-                    parallel_draft_step=i + 1,
-                )
-
-                _, eagle_logits_, eagle_hidden_states_3x_pre_norm = self._eagle_forward(
-                    eagle_inputs_2,
-                    output_weight,
-                    inference_params=inference_params,
-                    packed_seq_params=packed_seq_params,
-                    inference_context=eagle_inference_context,
-                    **(extra_block_kwargs or {}),
-                )
-                eagle_logits_2.append(eagle_logits_)
-            eagle_logits_2 = torch.cat(eagle_logits_2, dim=0)
-
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
-                loss_ = loss_[:, i + 2 :]
-                loss[:, i + 3 :] += self.eagle_loss_decay_factor**3 * loss_
-
-            if self.eagle_report_acc and not self.training:
-                acc = []
-                with torch.no_grad():
-                    for i in range(self.eagle_config.parallel_draft_step):
-                        gathered_logits = gather_from_tensor_model_parallel_region(
-                            eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                        )
-                        gathered_logits = gathered_logits[i + 2 : -1]
-                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                        top1_p = torch.eq(labels[:, i + 3 :], eagle_top1).sum() / eagle_top1.numel()
-                        acc.append(top1_p)
-
-                if get_tensor_model_parallel_rank() == 0:
-                    print(
-                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 3rd Top-1: {acc}",
-                        flush=True,
+            for i in range(self.eagle_config.parallel_draft_step):
+                eagle_logit = eagle_logits[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
+                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logit)
+                loss_ = loss_[:, i + ttt_step :]
+                loss[:, i + ttt_step + 1 :] += (
+                    self.eagle_loss_decay_factor ** (ttt_step + i) * loss_
                 )
 
-            # Forth EAGLE loss
-            eagle_logits_3 = []
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_inputs_3 = self._get_eagle_module_inputs(
-                    input_ids=input_ids,
-                    hidden_states=eagle_module_input_hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    features=eagle_hidden_states_3x_pre_norm,
-                    ttt_step=4,
-                    parallel_draft_step=i + 1,
-                )
-
-                _, eagle_logits_, eagle_hidden_states_4x_pre_norm = self._eagle_forward(
-                    eagle_inputs_3,
-                    output_weight,
-                    inference_params=inference_params,
-                    packed_seq_params=packed_seq_params,
-                    inference_context=eagle_inference_context,
-                    **(extra_block_kwargs or {}),
-                )
-                eagle_logits_3.append(eagle_logits_)
-            eagle_logits_3 = torch.cat(eagle_logits_3, dim=0)
-
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
-                loss_ = loss_[:, i + 3 :]
-                loss[:, i + 4 :] += self.eagle_loss_decay_factor**4 * loss_
-
-            if self.eagle_report_acc and not self.training:
-                acc = []
-                with torch.no_grad():
-                    for i in range(self.eagle_config.parallel_draft_step):
-                        gathered_logits = gather_from_tensor_model_parallel_region(
-                            eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
+            if self.eagle_report_acc and not self.training:
+                acc = []
+                with torch.no_grad():
+                    for i in range(self.eagle_config.parallel_draft_step):
+                        gathered_logits = gather_from_tensor_model_parallel_region(
+                            eagle_logits[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
+                        )
+                        gathered_logits = gathered_logits[i + ttt_step : -1]
+                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
+                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
+                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
+                        top1_p = (
+                            torch.eq(labels[:, i + ttt_step + 1 :], eagle_top1).sum()
+                            / eagle_top1.numel()
+                        )
+                        acc.append(top1_p)
+
+                if get_tensor_model_parallel_rank() == 0:
+                    print(
+                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} "
+                        f"EAGLE 1st Top-1: {acc}",
+                        flush=True,
                     )
-                        gathered_logits = gathered_logits[i + 3 : -1]
-                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                        top1_p = torch.eq(labels[:, i + 4 :], eagle_top1).sum() / eagle_top1.numel()
-                        acc.append(top1_p)
-
-                if get_tensor_model_parallel_rank() == 0:
-                    print(
-                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 4th Top-1: {acc}",
-                        flush=True,
-                    )
 
         return loss
 
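For reference, the consolidated loop weights each draft loss by `eagle_loss_decay_factor ** (ttt_step + i)` and shifts the target slice by `i + ttt_step + 1`. A small illustrative printout (example values, not from the source):

```python
decay, ttt_steps, parallel_draft_step = 0.9, 2, 2  # example values
for t in range(ttt_steps):
    for i in range(parallel_draft_step):
        print(
            f"ttt_step={t} draft={i}: weight={decay ** (t + i):.2f}, "
            f"added into loss[:, {i + t + 1}:]"
        )
```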