@@ -197,6 +197,7 @@ def set_multi_step_attention_mask(attn_mask, step):

    ttt_step=2
    parallel_draft_step=2
+   ->step=3

          | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- |
    (out) | h0 h1 h2 h3 h4 h5 h6 h7 | h0 h1 h2 h3 h4 h5 h6 h7 | -- -- G2 G3 G4 G5 G6 G7 | -- -- G2 G3 G4 G5 G6 G7 |
@@ -239,13 +240,12 @@ def set_multi_step_attention_mask(attn_mask, step):
    =======================================================================================================================
    """  # noqa: E501
    s = attn_mask.shape[-1]
-    for step_idx in range(2, step + 1):
-        # step_idx starts from 2nd step
+    for step_idx in range(step):
        mask_0 = attn_mask.clone().detach()
-        mask_0[:, :, step_idx - 2, :] = True
+        mask_0[:, :, step_idx, :] = True
        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
-        for i in range(step_idx - 1, s - 1):
+        for i in range(step_idx + 1, s - 1):
            mask_1[:, :, i, i] = False

        attn_mask = torch.cat((mask_0, mask_1), dim=-1)
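
For context, a minimal sketch of how the reworked loop behaves with the zero-indexed step argument, assuming a boolean mask of shape [batch, 1, s, s] and that set_multi_step_attention_mask is importable from this module (the mask contents and names below are illustrative, not part of the diff):

    import torch

    b, s = 1, 8
    attn_mask = torch.ones(b, 1, s, s).triu(1).bool()  # stand-in mask; semantics depend on the caller

    # step now counts extension iterations: step=0 leaves the mask untouched,
    # and each additional step appends another s columns for the next draft pass.
    for step in range(4):
        extended = set_multi_step_attention_mask(attn_mask, step)
        assert extended.shape == (b, 1, s, (step + 1) * s)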
@@ -759,8 +759,8 @@ def _get_eagle_module_inputs(
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        features: torch.Tensor | None = None,
-        ttt_step: int = 1,
-        parallel_draft_step: int = 1,
+        ttt_step: int = 0,
+        parallel_draft_step: int = 0,
    ):
        """Getting EAGLE module inputs."""
        b = hidden_states.shape[1]
@@ -784,10 +784,10 @@ def _get_eagle_module_inputs(

        eagle_inputs["input_ids"] = (
            padded_input_ids
-            if parallel_draft_step == 1
+            if parallel_draft_step == 0
            else torch.full(
                padded_input_ids.shape,
-                getattr(self, f"mask_token_{parallel_draft_step - 2}"),
+                getattr(self, f"mask_token_{parallel_draft_step - 1}"),
                device=padded_input_ids.device,
                dtype=padded_input_ids.dtype,
            )
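
With the zero-based draft index, the mapping from draft pass to input tokens is as follows; a small illustrative sketch (the mask_token_* attribute names follow the diff, everything else is assumed):

    # Draft pass 0 feeds the real (padded) input_ids; draft pass k > 0 feeds the
    # learned mask token registered as mask_token_{k - 1}.
    for parallel_draft_step in range(3):
        if parallel_draft_step == 0:
            source = "padded_input_ids"
        else:
            source = f"mask_token_{parallel_draft_step - 1}"
        print(parallel_draft_step, "->", source)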
@@ -805,7 +805,7 @@ def _get_eagle_module_inputs(
            feature = gathered_features[-s:]
        eagle_inputs["hidden_states"] = (
            gathered_hidden_states
-            if ttt_step == 1
+            if ttt_step == 0
            else torch.cat(
                (
                    torch.zeros(
@@ -824,12 +824,12 @@ def _get_eagle_module_inputs(
        )

        eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
-            attn_mask, (ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step
+            attn_mask, ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_step
        )

        eagle_inputs["rotary_pos_emb"] = torch.cat(
            [rotary_pos_emb]
-            * ((ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step),
+            * (ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_step + 1),
            dim=0,
        )

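
A small worked example of the step arithmetic after the switch to zero-based ttt_step and parallel_draft_step, assuming the config's parallel_draft_step is 2 (illustrative values only):

    config_parallel_draft_step = 2  # assumed value of self.eagle_config.parallel_draft_step

    for ttt_step in range(2):
        for i in range(config_parallel_draft_step):
            # argument passed to set_multi_step_attention_mask
            step = ttt_step * config_parallel_draft_step + i
            # number of rotary_pos_emb copies concatenated along dim 0
            copies = step + 1
            print(ttt_step, i, step, copies)
    # (0, 0) -> step 0, 1 copy;   (0, 1) -> step 1, 2 copies
    # (1, 0) -> step 2, 3 copies; (1, 1) -> step 3, 4 copies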
@@ -970,6 +970,7 @@ def forward(
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict | None = None,
        return_eagle_inputs: bool = False,
+        ttt_steps=4,
        **kwargs,
    ) -> torch.Tensor:
        if position_ids is None or attention_mask is None:
@@ -1013,7 +1014,8 @@ def forward(

        # EAGLE kv cache
        eagle_inference_context = StaticInferenceContext(
-            input_ids.shape[0], input_ids.shape[1] * self.eagle_config.parallel_draft_step * 4
+            input_ids.shape[0],
+            input_ids.shape[1] * self.eagle_config.parallel_draft_step * ttt_steps,
        )

        if self.eagle_offline:
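
The hard-coded factor of 4 becomes the new ttt_steps argument; a rough sizing sketch under assumed values (not taken from the diff):

    # Each of the ttt_steps iterations runs parallel_draft_step EAGLE forwards, and every
    # forward appends roughly one sequence worth of keys/values to the static cache.
    batch, seq_len = 2, 1024
    parallel_draft_step, ttt_steps = 2, 4

    max_cached_tokens = seq_len * parallel_draft_step * ttt_steps  # 8192 per sequence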
@@ -1043,228 +1045,88 @@ def forward(
            hidden_states, apply_fc=True
        )

-        # In calibration mode, we want to make sure all weights have been exercised.
-        # This makes sure all quantized weights have amax calibrated
-        if inference_params is None or self.calibration_mode:
-            eagle_logits_0 = []
+        if labels is not None:
+            if labels.shape[1] == input_ids.shape[1] - 1:
+                # For offline training, labels may be 1 token shorter than input_ids.
+                # We will just pad a 0 to the labels to make the seq_len the same as
+                # input_ids. This will introduce a small error in training if logit_distillation
+                # is False, and testing accuracy is wrong for the last token.
+                right_token_pad = torch.zeros(
+                    (labels.shape[0], 1),
+                    dtype=labels.dtype,
+                    device=labels.device,
+                )
+                labels = torch.cat((labels, right_token_pad), dim=-1)
+
+            # If eagle_freeze_base_model is set to True,
+            # the base model is frozen.
+            loss = self.compute_language_model_loss(labels, logits_sbh)
+            if self.eagle_freeze_base_model:
+                loss = 0.0 * loss
+
+        eagle_hidden_states_pre_norm = None
+        for ttt_step in range(ttt_steps):
+            eagle_logits = []
            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_inputs_0 = self._get_eagle_module_inputs(
+                eagle_inputs = self._get_eagle_module_inputs(
                    input_ids=input_ids,
                    hidden_states=eagle_module_input_hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
-                    ttt_step=1,
-                    parallel_draft_step=i + 1,
+                    features=eagle_hidden_states_pre_norm,
+                    ttt_step=ttt_step,
+                    parallel_draft_step=i,
                )

-                _, eagle_logits_, eagle_hidden_states_0_pre_norm = self._eagle_forward(
-                    eagle_inputs_0,
+                _, eagle_logits_, eagle_hidden_states_pre_norm_ = self._eagle_forward(
+                    eagle_inputs,
                    output_weight,
                    inference_params=inference_params,
                    packed_seq_params=packed_seq_params,
                    inference_context=eagle_inference_context,
                    **(extra_block_kwargs or {}),
                )

-                eagle_logits_0.append(eagle_logits_)
-            eagle_logits_0 = torch.cat(eagle_logits_0, dim=0)
-
-            # If labels are not provided, return the original logits. We only return after
-            # all eagle weights have been exercised for quantization calibration purpose.
-            if labels is None:
-                return logits_sbh.transpose(0, 1).contiguous()
-            elif labels.shape[1] == input_ids.shape[1] - 1:
-                # For offline training, labels may be 1 token shorter than input_ids.
-                # We will just pad a 0 to the labels to make the seq_len the same as
-                # input_ids. This will introduce a small error in training if logit_distillation
-                # is False, and testing accuracy is wrong for the last token.
-                right_token_pad = torch.zeros(
-                    (labels.shape[0], 1),
-                    dtype=labels.dtype,
-                    device=labels.device,
-                )
-                labels = torch.cat((labels, right_token_pad), dim=-1)
-
-            # If eagle_freeze_base_model is set to True,
-            # the base model is frozen.
-            loss = self.compute_language_model_loss(labels, logits_sbh)
-            loss = 0.0 * loss
-
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
-                loss_ = loss_[:, i:]
-                loss[:, i + 1 :] += self.eagle_loss_decay_factor * loss_
-
-            if self.eagle_report_acc and not self.training:
-                acc = []
-                with torch.no_grad():
-                    for i in range(self.eagle_config.parallel_draft_step):
-                        gathered_logits = gather_from_tensor_model_parallel_region(
-                            eagle_logits_0[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                        )
-                        gathered_logits = gathered_logits[i:-1]
-                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                        top1_p = torch.eq(labels[:, i + 1 :], eagle_top1).sum() / eagle_top1.numel()
-                        acc.append(top1_p)
-
-                if get_tensor_model_parallel_rank() == 0:
-                    print(
-                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 1st Top-1: {acc}",
-                        flush=True,
-                    )
-
-            # Second round of EAGLE loss
-            eagle_logits_1 = []
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_inputs_1 = self._get_eagle_module_inputs(
-                    input_ids=input_ids,
-                    hidden_states=eagle_module_input_hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    features=eagle_hidden_states_0_pre_norm,
-                    ttt_step=2,
-                    parallel_draft_step=i + 1,
-                )
+                eagle_logits.append(eagle_logits_)
+            eagle_logits = torch.cat(eagle_logits, dim=0)
+            eagle_hidden_states_pre_norm = eagle_hidden_states_pre_norm_

-                _, eagle_logits_, eagle_hidden_states_2x_pre_norm = self._eagle_forward(
-                    eagle_inputs_1,
-                    output_weight,
-                    inference_params=inference_params,
-                    packed_seq_params=packed_seq_params,
-                    inference_context=eagle_inference_context,
-                    **(extra_block_kwargs or {}),
-                )
-                eagle_logits_1.append(eagle_logits_)
-            eagle_logits_1 = torch.cat(eagle_logits_1, dim=0)
-
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
-                loss_ = loss_[:, i + 1 :]
-                loss[:, i + 2 :] += self.eagle_loss_decay_factor**2 * loss_
-
-            if self.eagle_report_acc and not self.training:
-                acc = []
-                with torch.no_grad():
-                    for i in range(self.eagle_config.parallel_draft_step):
-                        gathered_logits = gather_from_tensor_model_parallel_region(
-                            eagle_logits_1[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                        )
-                        gathered_logits = gathered_logits[i + 1 : -1]
-                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                        top1_p = torch.eq(labels[:, i + 2 :], eagle_top1).sum() / eagle_top1.numel()
-                        acc.append(top1_p)
-
-                if get_tensor_model_parallel_rank() == 0:
-                    print(
-                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 2nd Top-1: {acc}",
-                        flush=True,
-                    )
+            # If labels are not provided, return the original logits. We only return after
+            # all eagle weights have been exercised for quantization calibration purpose.
+            if labels is None:
+                return logits_sbh.transpose(0, 1).contiguous()

-            # Third EAGLE loss
-            eagle_logits_2 = []
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_inputs_2 = self._get_eagle_module_inputs(
-                    input_ids=input_ids,
-                    hidden_states=eagle_module_input_hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    features=eagle_hidden_states_2x_pre_norm,
-                    ttt_step=3,
-                    parallel_draft_step=i + 1,
-                )
-
-                _, eagle_logits_, eagle_hidden_states_3x_pre_norm = self._eagle_forward(
-                    eagle_inputs_2,
-                    output_weight,
-                    inference_params=inference_params,
-                    packed_seq_params=packed_seq_params,
-                    inference_context=eagle_inference_context,
-                    **(extra_block_kwargs or {}),
-                )
-                eagle_logits_2.append(eagle_logits_)
-            eagle_logits_2 = torch.cat(eagle_logits_2, dim=0)
-
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
-                loss_ = loss_[:, i + 2 :]
-                loss[:, i + 3 :] += self.eagle_loss_decay_factor**3 * loss_
-
-            if self.eagle_report_acc and not self.training:
-                acc = []
-                with torch.no_grad():
-                    for i in range(self.eagle_config.parallel_draft_step):
-                        gathered_logits = gather_from_tensor_model_parallel_region(
-                            eagle_logits_2[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                        )
-                        gathered_logits = gathered_logits[i + 2 : -1]
-                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                        top1_p = torch.eq(labels[:, i + 3 :], eagle_top1).sum() / eagle_top1.numel()
-                        acc.append(top1_p)
-
-                if get_tensor_model_parallel_rank() == 0:
-                    print(
-                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 3rd Top-1: {acc}",
-                        flush=True,
+            for i in range(self.eagle_config.parallel_draft_step):
+                eagle_logit = eagle_logits[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
+                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logit)
+                loss_ = loss_[:, i + ttt_step :]
+                loss[:, i + ttt_step + 1 :] += (
+                    self.eagle_loss_decay_factor ** (ttt_step + i) * loss_
                )

-            # Forth EAGLE loss
-            eagle_logits_3 = []
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_inputs_3 = self._get_eagle_module_inputs(
-                    input_ids=input_ids,
-                    hidden_states=eagle_module_input_hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    features=eagle_hidden_states_3x_pre_norm,
-                    ttt_step=4,
-                    parallel_draft_step=i + 1,
-                )
-
-                _, eagle_logits_, eagle_hidden_states_4x_pre_norm = self._eagle_forward(
-                    eagle_inputs_3,
-                    output_weight,
-                    inference_params=inference_params,
-                    packed_seq_params=packed_seq_params,
-                    inference_context=eagle_inference_context,
-                    **(extra_block_kwargs or {}),
-                )
-                eagle_logits_3.append(eagle_logits_)
-            eagle_logits_3 = torch.cat(eagle_logits_3, dim=0)
-
-            for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logits = eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits)
-                loss_ = loss_[:, i + 3 :]
-                loss[:, i + 4 :] += self.eagle_loss_decay_factor**4 * loss_
-
-            if self.eagle_report_acc and not self.training:
-                acc = []
-                with torch.no_grad():
-                    for i in range(self.eagle_config.parallel_draft_step):
-                        gathered_logits = gather_from_tensor_model_parallel_region(
-                            eagle_logits_3[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
+            if self.eagle_report_acc and not self.training:
+                acc = []
+                with torch.no_grad():
+                    for i in range(self.eagle_config.parallel_draft_step):
+                        gathered_logits = gather_from_tensor_model_parallel_region(
+                            eagle_logits[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
+                        )
+                        gathered_logits = gathered_logits[i + ttt_step : -1]
+                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
+                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
+                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
+                        top1_p = (
+                            torch.eq(labels[:, i + ttt_step + 1 :], eagle_top1).sum()
+                            / eagle_top1.numel()
+                        )
+                        acc.append(top1_p)
+
+                if get_tensor_model_parallel_rank() == 0:
+                    print(
+                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} "
+                        f"EAGLE 1st Top-1: {acc}",
+                        flush=True,
                    )
-                        gathered_logits = gathered_logits[i + 3 : -1]
-                        eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
-                        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                            eagle_top1 += self.eagle_module.d2t[eagle_top1]
-                        top1_p = torch.eq(labels[:, i + 4 :], eagle_top1).sum() / eagle_top1.numel()
-                        acc.append(top1_p)
-
-                if get_tensor_model_parallel_rank() == 0:
-                    print(
-                        f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 4th Top-1: {acc}",
-                        flush=True,
-                    )

        return loss

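The four unrolled EAGLE rounds collapse into the ttt_steps loop above, with the label shift and decay weight derived directly from (ttt_step, i). A standalone sketch of that indexing, using an assumed decay value and draft count (not taken from the diff):

    decay = 0.9               # assumed stand-in for self.eagle_loss_decay_factor
    parallel_draft_step = 2   # assumed config value

    for ttt_step in range(4):
        for i in range(parallel_draft_step):
            shift = i + ttt_step            # eagle loss is sliced as loss_[:, shift:]
            target = i + ttt_step + 1       # and accumulated into loss[:, target:]
            weight = decay ** (ttt_step + i)
            print(f"ttt_step={ttt_step} i={i} shift={shift} target={target} weight={weight:.3f}")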