os.environ["TOKENIZERS_PARALLELISM"] = "false"

from abc import abstractmethod
+ from contextlib import nullcontext

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers.optimization import get_linear_schedule_with_warmup
+ from transformers.utils import ModelOutput

import modelopt.torch.opt as mto
import modelopt.torch.speculative as mtsp
from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG

+ try:
+     import wandb
+ except ImportError:
+     wandb = None
+
+
mto.enable_huggingface_checkpointing()
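For context, mto.enable_huggingface_checkpointing() patches the Hugging Face save/load
path so that modelopt-converted models keep their modelopt state across a
save_pretrained/from_pretrained round trip. A minimal sketch of that round trip
(the model id and output directory are placeholders):

    import modelopt.torch.opt as mto
    from transformers import AutoModelForCausalLM

    mto.enable_huggingface_checkpointing()
    model = AutoModelForCausalLM.from_pretrained("some/base-model")  # placeholder id
    # ... convert with mtsp and train ...
    model.save_pretrained("out_dir")  # modelopt state saved alongside the weights
    restored = AutoModelForCausalLM.from_pretrained("out_dir")  # state restored on load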
# Hyperparameters for profiling
@@ -51,12 +59,13 @@ class BaseDistillTrainer:
        student_step: student step function.
    """

-     def __init__(self, rank, args, tokenizer):
+     def __init__(self, rank, args, tokenizer, dataloader):
        self.rank = rank
        args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
        args.student_pgroup = dist.new_group(ranks=args.student_ranks)
        self.args = args
        self.tokenizer = tokenizer
+         self.dataloader = dataloader
        if rank in args.student_ranks:
            self.model = self.prepare_student_model()
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr)
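For orientation, the two dist.new_group calls split one world into disjoint teacher
and student process groups; point-to-point messages later flow across that boundary.
A minimal sketch, assuming an illustrative 4-rank world split 2/2 (note that every
rank must execute both new_group calls):

    import torch.distributed as dist

    teacher_ranks, student_ranks = [0, 1], [2, 3]         # assumed layout
    teacher_pgroup = dist.new_group(ranks=teacher_ranks)  # teacher-side collectives
    student_pgroup = dist.new_group(ranks=student_ranks)  # student-side collectives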
@@ -71,46 +80,49 @@ def _print_model_placement(self, module):
        for name, param in module.named_parameters():
            print(f"(Rank {self.rank}) {name} ---> {param.device}")

-     @property
-     def current_rank_device(self):
-         pass
-
-     @property
-     def distill_metadata(self):
-         pass
-
    def _reset_all_mem_stats(self):
        torch.cuda.reset_max_memory_allocated(self.current_rank_device)

    def _print_mem_stats(self):
        max_mem = torch.cuda.max_memory_allocated(self.current_rank_device)
        print(f"GPU {self.current_rank_device}: Max memory allocated: {max_mem / 1024**3:.2f} GB")

+     @property
+     def current_rank_device(self):
+         """Return the device of the current rank."""
+
+     @property
+     def distill_metadata(self):
+         """Return a DistillMetadata describing the distillation message received by the student."""
+
    @abstractmethod
-     def load_teacher_model(self):
-         pass
+     def prepare_teacher_model(self):
+         """Return the converted teacher model with the correct parallelization."""

    @abstractmethod
-     def load_student_model(self):
-         pass
+     def prepare_student_model(self):
+         """Return the converted student model with the correct parallelization."""

    @abstractmethod
-     def teacher_step(self, *args, **kwargs) -> dict[str, torch.Tensor]:
-         pass
+     def teacher_step(self, *args, **kwargs) -> list[dict[str, torch.Tensor]]:
+         """Run one teacher step and return a distillation message for each student rank."""

    @abstractmethod
-     def student_step(self, *args, **kwargs):
-         pass
+     def student_step(self, *args, **kwargs) -> ModelOutput:
+         """Run the forward pass of one student step and return a ModelOutput."""

-     def save_pretrained(self, path=None):
+     def save_pretrained(self, save_path):
+         """Save the model and tokenizer."""
        if self.rank == self.args.student_ranks[0]:
-             path = self.args.out_path if path is None else path
-             self.model.save_pretrained(path)
-             self.tokenizer.save_pretrained(path)
-             print(f"Pretrained model saved to {path}")
+             if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
+                 self.model.module.save_pretrained(save_path)
+             else:
+                 self.model.save_pretrained(save_path)
+             self.tokenizer.save_pretrained(save_path)
+             print(f"Pretrained model saved to {save_path}")

    def _check_valid_message(self, message: dict[str, torch.Tensor]):
-         # Check if keys and length match between message and distill_metadata
+         """Check that the message matches the format of distill_metadata."""
        if set(message.keys()) != set(self.distill_metadata.keys()):
            raise ValueError(
                f"Message keys: {set(message.keys())}\n"
@@ -142,8 +154,8 @@ def _recv_from_teacher(self):
        for req in reqs:
            req.wait()

-     def _get_distill_kwargs(self):
-         """Return a copy of received buffer for student training."""
+     def _clone_recv_buffer(self):
+         """Return a copy of the received tensors for student step input."""
        return {k: v.clone().detach() for k, v in self.student_recv_buffer.items()}
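The clone().detach() matters because the receive buffer is preallocated and
overwritten by the next irecv, and the training copy must not alias it or carry
autograd history. A small self-contained illustration of the aliasing hazard:

    import torch

    buffer = torch.zeros(4)          # persistent recv buffer, reused every step
    alias = buffer                   # would be corrupted by the next message
    safe = buffer.clone().detach()   # independent storage, detached from autograd
    buffer.fill_(1.0)                # simulate the next incoming message
    assert alias.sum() == 4 and safe.sum() == 0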

    def _send_to_student(self, teacher_outputs):
@@ -160,49 +172,63 @@ def _send_to_student(self, teacher_outputs):
        for req in reqs:
            req.wait()
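Both transfer helpers ride on PyTorch's nonblocking point-to-point API: isend and
irecv return request handles, and wait() blocks until the transfer completes and
the buffer is safe to reuse. A minimal sketch of one tensor crossing ranks (assumes
the default process group is initialized; ranks are illustrative):

    import torch
    import torch.distributed as dist

    def send_one(tensor: torch.Tensor, dst: int) -> None:
        req = dist.isend(tensor, dst=dst)  # nonblocking send
        req.wait()                         # block until the buffer can be reused

    def recv_one(buffer: torch.Tensor, src: int) -> None:
        req = dist.irecv(buffer, src=src)  # nonblocking receive into a buffer
        req.wait()                         # block until the data has arrived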

-     def train(self, dataloader):
+     def _get_logging_context(self):
+         print(
+             f"Rank {self.rank} is logging: {wandb is not None and self.rank == self.args.student_ranks[0]}"
+         )
+         if wandb is not None and self.rank == self.args.student_ranks[0]:
+             return wandb.init(
+                 entity=os.environ["WANDB_ENTITY"],
+                 project=os.environ["WANDB_PROJECT"],
+                 config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size},
+             )
+         return nullcontext()
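One subtlety worth noting: when _get_logging_context falls back to nullcontext(),
the `with ... as run:` target is None, which is safe here because only student
rank 0 (the rank that got a real wandb run) ever reaches run.log. A tiny
demonstration of the fallback:

    from contextlib import nullcontext

    with nullcontext() as run:
        assert run is None  # non-logging ranks must not call run.log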
+
+     def train(self):
        """Main training entry point for the composed model."""
        self._reset_all_mem_stats()

        if self.rank in self.args.student_ranks:
-             import wandb
-
-             wandb.login()
-
-             with wandb.init(
-                 entity=os.environ["WANDB_ENTITY"],
-                 project=os.environ["WANDB_PROJECT"],
-                 config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size},
-             ) as run:
+             with self._get_logging_context() as run:
                self._init_student_recv_buffer()
-                 wandb.watch(self.model, log="all")

+                 # Student training loop
                for epoch in range(EPOCHS):
                    pbar = (
-                         tqdm(dataloader) if self.rank == self.args.student_ranks[0] else dataloader
+                         tqdm(self.dataloader)
+                         if self.rank == self.args.student_ranks[0]
+                         else self.dataloader
                    )
                    for i, batch in enumerate(pbar):
-                         global_step = epoch * len(dataloader) + i
+                         global_step = epoch * len(self.dataloader) + i
                        inputs = {k: v.to(self.model.device) for k, v in batch.items()}
+
+                         # Receive distill messages from the teacher
                        self._recv_from_teacher()
-                         loss, train_acc = self.student_step(inputs, **self._get_distill_kwargs())

+                         # Run the forward pass of the student step
+                         output = self.student_step(inputs, **self._clone_recv_buffer())
+                         loss = output.loss
+
+                         # Run the backward step
+                         loss.backward()
+                         self.optimizer.step()
+                         self.scheduler.step()
+
+                         # Log and save only on student rank 0
                        if self.rank != self.args.student_ranks[0]:
                            continue

-                         pbar.set_description(f"Epoch {epoch} Loss:{loss} Acc:{train_acc}")
+                         train_metrics = {
+                             "loss": round(loss.item(), 3),
+                             "lr": self.optimizer.param_groups[0]["lr"],
+                             # Attach all float metrics from the model output
+                             **{k: round(v, 3) for k, v in output.items() if isinstance(v, float)},
+                         }
+
+                         pbar.set_description(f"Epoch {epoch} Loss {train_metrics['loss']}")
                        if global_step % LOG_INTERVAL == 0:
-                             run.log(
-                                 {
-                                     "loss": loss,
-                                     "train_acc_step0": train_acc[0],
-                                     "train_acc_step1": train_acc[1],
-                                     "train_acc_step2": train_acc[2],
-                                     "train_acc_step3": train_acc[3],
-                                     "lr": self.optimizer.param_groups[0]["lr"],
-                                 },
-                                 step=global_step,
-                             )
+                             run.log(train_metrics, step=global_step)
                        if global_step > 0 and global_step % SAVE_INTERVAL == 0:
                            self.save_pretrained(
                                f"{self.args.out_path}/epoch_{epoch}_step_{global_step}"
@@ -211,13 +237,10 @@ def train(self, dataloader):
        else:
            # Inference loop
            for epoch in range(EPOCHS):
-                 for i, batch in enumerate(dataloader):
-                     global_step = epoch * len(dataloader) + i
+                 for i, batch in enumerate(self.dataloader):
                    inputs = {k: v.to(self.model.device) for k, v in batch.items()}
-                     inputs["position_ids"] = None
                    with torch.inference_mode():
-                         teacher_outputs = self.teacher_step(self.model, inputs)
-                         self._send_to_student(teacher_outputs)
+                         self._send_to_student(self.teacher_step(self.model, inputs))

        self._print_mem_stats()
        # Make sure all processes have finished before destroying the process group.
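Taken together, train() runs two disjoint loops over the same dataloader: teacher
ranks stream distillation messages while student ranks consume them. A hedged
sketch of how a per-rank launcher might drive it (worker and its arguments are
hypothetical):

    def worker(rank, args, tokenizer, dataloader):
        # One process per rank; the rank decides the teacher/student role.
        trainer = EagleTPTrainer(rank, args, tokenizer, dataloader)
        trainer.train()
        trainer.save_pretrained(args.out_path)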
@@ -227,14 +250,15 @@ def train(self, dataloader):
class EagleTPTrainer(BaseDistillTrainer):
-     def __init__(self, rank, args, tokenizer):
+     def __init__(self, rank, args, tokenizer, dataloader):
+         # Load eagle config
        args.eagle_config = EAGLE3_DEFAULT_CFG["config"]
        if args.eagle_config_path:
            with open(args.eagle_config_path) as f:
                custom_config = json.load(f)
            args.eagle_config["eagle_architecture_config"].update(custom_config)

-         super().__init__(rank, args, tokenizer)
+         super().__init__(rank, args, tokenizer, dataloader)

    @property
    def current_rank_device(self):
@@ -245,6 +269,7 @@ def current_rank_device(self):

    @property
    def distill_metadata(self) -> DistillMetadata:
+         """Description of the distillation message received by the student."""
        return {
            "base_model_hidden_states": (
                torch.Size(
@@ -279,12 +304,14 @@ def distill_metadata(self) -> DistillMetadata:
        }
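Since the DistillMetadata type itself sits outside this hunk, the assumed shape of
the mapping is name -> (tensor shape, dtype), which is what a receiver needs to
preallocate its buffers. A hedged sketch with illustrative sizes:

    import torch

    # Assumption: DistillMetadata == dict[str, tuple[torch.Size, torch.dtype]]
    metadata = {
        "base_model_hidden_states": (torch.Size([1, 512, 4096]), torch.bfloat16),
        "base_model_logits": (torch.Size([1, 512, 128256]), torch.bfloat16),
    }
    recv_buffers = {
        name: torch.empty(shape, dtype=dtype)  # device placement omitted
        for name, (shape, dtype) in metadata.items()
    }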

    def prepare_teacher_model(self):
+         # Load the model with TP across teacher ranks.
        model = AutoModelForCausalLM.from_pretrained(
            self.args.model_path,
            torch_dtype="auto",
            tp_plan="auto",
            device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"),
        )
+         # Load the EAGLE config and convert the model.
        self.args.eagle_config["eagle_architecture_config"].update(
            {
                "hidden_size": model.config.hidden_size,
"hidden_size" : model .config .hidden_size ,
@@ -298,7 +325,6 @@ def prepare_teacher_model(self):
298
325
return model
299
326
300
327
def prepare_student_model (self ):
301
- """Load student model on a single device and keep needed modules from teacher."""
302
328
# Load to CPU first to avoid OOM
303
329
model = AutoModelForCausalLM .from_pretrained (
304
330
self .args .model_path , torch_dtype = "auto" , device_map = "cpu"
@@ -331,15 +357,19 @@ def prepare_student_model(self):
331
357
return model
332
358
333
359
def teacher_step (self , model , inputs ):
360
+ # Collect base model outputs.
334
361
base_model_hidden_states , base_model_logits , _ , _ = model ._base_model_forward (
335
362
** inputs ,
336
363
freeze_base_model = True ,
337
364
past_key_values = None ,
338
365
)
339
- # aux_hidden_states could be on multiple devices. Gather them and cat.
366
+
367
+ # Aux_hidden_states could be on multiple devices. Gather before cat.
340
368
aux_hidden_states = torch .cat (
341
369
[t .to (base_model_logits .device ) for t in model .pop_aux_hidden_states ()], dim = - 1
342
370
)
371
+
372
+ # Chunk the tensors for each student rank.
343
373
base_model_hidden_states = base_model_hidden_states .chunk (len (self .args .student_ranks ))
344
374
base_model_logits = base_model_logits .chunk (len (self .args .student_ranks ))
345
375
aux_hidden_states = aux_hidden_states .chunk (len (self .args .student_ranks ))
@@ -356,28 +386,12 @@ def teacher_step(self, model, inputs):
    def student_step(
        self,
        inputs,
-         base_model_hidden_states,
-         aux_hidden_states,
-         base_model_logits,
-     ):
+         **distill_msgs,
+     ) -> ModelOutput:
        self.optimizer.zero_grad()
-         # Second stage forward using the unified model
+
+         # Chunk inputs for each student rank.
        inputs = {k: v.chunk(len(self.args.student_ranks))[self.rank] for k, v in inputs.items()}
-         output = self.model(
-             **inputs,
-             # Provide base model outputs to bypass the base model forward.
-             base_model_outputs={
-                 "base_model_hidden_states": base_model_hidden_states,
-                 "aux_hidden_states": aux_hidden_states.clone().detach(),
-                 "base_model_logits": base_model_logits.clone().detach(),
-             },
-         )
-         loss = output.loss
-         # print(f"Rank {self.rank} loss: {loss.item()}")
-         train_acc = output.train_acc
-
-         # Backward
-         loss.backward()
-         self.optimizer.step()
-         self.scheduler.step()
-         return round(loss.item(), 3), train_acc
+
+         # Second stage forward with provided base model outputs.
+         return self.model(**inputs, base_model_outputs=distill_msgs)
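After this refactor, student_step is a pure forward pass: the distill_msgs keyword
arguments carry exactly the keys declared in distill_metadata, and passing them as
base_model_outputs lets the student skip the base-model forward. A hedged sketch of
the resulting call shape (the function name is hypothetical):

    def student_forward(model, inputs, distill_msgs):
        # distill_msgs holds e.g. base_model_hidden_states, aux_hidden_states,
        # and base_model_logits, received from the teacher ranks.
        output = model(**inputs, base_model_outputs=distill_msgs)
        return output.loss  # backward/step now happen in the shared train() loop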