@@ -17,18 +17,26 @@
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 from abc import abstractmethod
+from contextlib import nullcontext
 
 import torch
 import torch.distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
 from transformers.optimization import get_linear_schedule_with_warmup
+from transformers.utils import ModelOutput
 
 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
 from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
 
+try:
+    import wandb
+except ImportError:
+    wandb = None
+
+
 mto.enable_huggingface_checkpointing()
 
 # Hyperparameters for profiling
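The constants referenced later in this diff (`EPOCHS`, `LOG_INTERVAL`, `SAVE_INTERVAL`) are defined under this "Hyperparameters for profiling" comment in the full file. For orientation only, a sketch with illustrative values (the real values live in the file, not in this diff):

```python
EPOCHS = 1           # assumed value, not from this diff
LOG_INTERVAL = 10    # assumed value, not from this diff
SAVE_INTERVAL = 500  # assumed value, not from this diff
```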
@@ -51,12 +59,13 @@ class BaseDistillTrainer:
         student_step: student step function.
     """
 
-    def __init__(self, rank, args, tokenizer):
+    def __init__(self, rank, args, tokenizer, dataloader):
         self.rank = rank
         args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
         args.student_pgroup = dist.new_group(ranks=args.student_ranks)
         self.args = args
         self.tokenizer = tokenizer
+        self.dataloader = dataloader
         if rank in args.student_ranks:
             self.model = self.prepare_student_model()
             self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr)
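The `dist.new_group` calls in `__init__` assume disjoint teacher and student rank sets supplied via `args`. A minimal sketch of the intended topology, with an illustrative 4-GPU split (the concrete rank lists are assumptions, not part of this change):

```python
import torch.distributed as dist

# Assumes dist.init_process_group() has already been called on every rank.
# Illustrative layout: ranks 0-1 serve the teacher, ranks 2-3 train the student.
teacher_ranks = [0, 1]
student_ranks = [2, 3]

# Note: every process must call new_group() for every group,
# including groups it does not belong to.
teacher_pgroup = dist.new_group(ranks=teacher_ranks)
student_pgroup = dist.new_group(ranks=student_ranks)
```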
@@ -71,46 +80,49 @@ def _print_model_placement(self, module):
         for name, param in module.named_parameters():
             print(f"(Rank {self.rank}) {name} ---> {param.device}")
 
-    @property
-    def current_rank_device(self):
-        pass
-
-    @property
-    def distill_metadata(self):
-        pass
-
     def _reset_all_mem_stats(self):
         torch.cuda.reset_max_memory_allocated(self.current_rank_device)
 
     def _print_mem_stats(self):
         max_mem = torch.cuda.max_memory_allocated(self.current_rank_device)
         print(f"GPU {self.current_rank_device}: Max memory allocated: {max_mem / 1024**3:.2f} GB")
 
+    @property
+    def current_rank_device(self):
+        """Return the device of the current rank."""
+
+    @property
+    def distill_metadata(self):
+        """Return a DistillMetadata that describes the distillation message received by the student."""
+
     @abstractmethod
-    def load_teacher_model(self):
-        pass
+    def prepare_teacher_model(self):
+        """Return the converted teacher model with correct parallelization."""
 
     @abstractmethod
-    def load_student_model(self):
-        pass
+    def prepare_student_model(self):
+        """Return the converted student model with correct parallelization."""
 
     @abstractmethod
-    def teacher_step(self, *args, **kwargs) -> dict[str, torch.Tensor]:
-        pass
+    def teacher_step(self, *args, **kwargs) -> list[dict[str, torch.Tensor]]:
+        """Run one teacher step and return distillation messages for each student rank."""
 
     @abstractmethod
-    def student_step(self, *args, **kwargs):
-        pass
+    def student_step(self, *args, **kwargs) -> ModelOutput:
+        """Run the forward pass of the student step and return a ModelOutput object."""
 
-    def save_pretrained(self, path=None):
+    def save_pretrained(self, save_path):
+        """Save the model and tokenizer."""
         if self.rank == self.args.student_ranks[0]:
-            path = self.args.out_path if path is None else path
-            self.model.save_pretrained(path)
-            self.tokenizer.save_pretrained(path)
-            print(f"Pretrained model saved to {path}")
+            if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
+                self.model.module.save_pretrained(save_path)
+            else:
+                self.model.save_pretrained(save_path)
+            self.tokenizer.save_pretrained(save_path)
+            print(f"Pretrained model saved to {save_path}")
 
     def _check_valid_message(self, message: dict[str, torch.Tensor]):
-        # Check if keys and length match between message and distill_metadata
+        """Check that the message matches the format of distill_metadata."""
         if set(message.keys()) != set(self.distill_metadata.keys()):
             raise ValueError(
                 f"Message keys: {set(message.keys())}\n"
@@ -142,8 +154,8 @@ def _recv_from_teacher(self):
         for req in reqs:
             req.wait()
 
-    def _get_distill_kwargs(self):
-        """Return a copy of received buffer for student training."""
+    def _clone_recv_buffer(self):
+        """Return a copy of the received tensors for student step input."""
         return {k: v.clone().detach() for k, v in self.student_recv_buffer.items()}
 
     def _send_to_student(self, teacher_outputs):
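`_clone_recv_buffer` only makes sense together with the preallocated receive buffers that `_recv_from_teacher` fills. A minimal sketch of that buffer protocol, assuming an initialized process group and an illustrative single-entry `distill_metadata` (the shape, dtype, and sending rank here are assumptions):

```python
import torch
import torch.distributed as dist

# Illustrative metadata: message name -> (tensor shape, dtype).
distill_metadata = {"base_model_logits": (torch.Size([1, 512, 32000]), torch.bfloat16)}
teacher_rank = 0  # assumed sender

# Preallocate one receive buffer per message entry.
recv_buffer = {
    name: torch.empty(shape, dtype=dtype, device="cuda")
    for name, (shape, dtype) in distill_metadata.items()
}

# Post one async irecv per tensor, then wait on all handles before reading.
reqs = [dist.irecv(buf, src=teacher_rank) for buf in recv_buffer.values()]
for req in reqs:
    req.wait()

# Clone and detach before training so the buffer can be safely reused next step.
step_inputs = {k: v.clone().detach() for k, v in recv_buffer.items()}
```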
@@ -160,49 +172,63 @@ def _send_to_student(self, teacher_outputs):
         for req in reqs:
             req.wait()
 
-    def train(self, dataloader):
+    def _get_logging_context(self):
+        print(
+            f"Rank {self.rank} is logging: {wandb is not None and self.rank == self.args.student_ranks[0]}"
+        )
+        if wandb is not None and self.rank == self.args.student_ranks[0]:
+            return wandb.init(
+                entity=os.environ["WANDB_ENTITY"],
+                project=os.environ["WANDB_PROJECT"],
+                config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size},
+            )
+        return nullcontext()
+
+    def train(self):
         """Main training entrance of the composed model."""
         self._reset_all_mem_stats()
 
         if self.rank in self.args.student_ranks:
-            import wandb
-
-            wandb.login()
-
-            with wandb.init(
-                entity=os.environ["WANDB_ENTITY"],
-                project=os.environ["WANDB_PROJECT"],
-                config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size},
-            ) as run:
+            with self._get_logging_context() as run:
                 self._init_student_recv_buffer()
-                wandb.watch(self.model, log="all")
 
+                # Student training loop
                 for epoch in range(EPOCHS):
                     pbar = (
-                        tqdm(dataloader) if self.rank == self.args.student_ranks[0] else dataloader
+                        tqdm(self.dataloader)
+                        if self.rank == self.args.student_ranks[0]
+                        else self.dataloader
                     )
                     for i, batch in enumerate(pbar):
-                        global_step = epoch * len(dataloader) + i
+                        global_step = epoch * len(self.dataloader) + i
                         inputs = {k: v.to(self.model.device) for k, v in batch.items()}
+
+                        # Receive distill messages from the teacher
                         self._recv_from_teacher()
-                        loss, train_acc = self.student_step(inputs, **self._get_distill_kwargs())
 
+                        # Run forward of the student step
+                        output = self.student_step(inputs, **self._clone_recv_buffer())
+                        loss = output.loss
+
+                        # Run backward step
+                        loss.backward()
+                        self.optimizer.step()
+                        self.scheduler.step()
+
+                        # Log and save only on student rank 0
                         if self.rank != self.args.student_ranks[0]:
                             continue
 
-                        pbar.set_description(f"Epoch {epoch} Loss:{loss} Acc:{train_acc}")
+                        train_metrics = {
+                            "loss": round(loss.item(), 3),
+                            "lr": self.optimizer.param_groups[0]["lr"],
+                            # Attach all float metrics from the output
+                            **{k: round(v, 3) for k, v in output.items() if isinstance(v, float)},
+                        }
+
+                        pbar.set_description(f"Epoch {epoch} Loss {train_metrics['loss']}")
                         if global_step % LOG_INTERVAL == 0:
-                            run.log(
-                                {
-                                    "loss": loss,
-                                    "train_acc_step0": train_acc[0],
-                                    "train_acc_step1": train_acc[1],
-                                    "train_acc_step2": train_acc[2],
-                                    "train_acc_step3": train_acc[3],
-                                    "lr": self.optimizer.param_groups[0]["lr"],
-                                },
-                                step=global_step,
-                            )
+                            # run is None when wandb is unavailable (nullcontext above)
+                            if run is not None:
+                                run.log(train_metrics, step=global_step)
                         if global_step > 0 and global_step % SAVE_INTERVAL == 0:
                             self.save_pretrained(
                                 f"{self.args.out_path}/epoch_{epoch}_step_{global_step}"
@@ -211,13 +237,10 @@ def train(self, dataloader):
         else:
             # Inference Loop
             for epoch in range(EPOCHS):
-                for i, batch in enumerate(dataloader):
-                    global_step = epoch * len(dataloader) + i
+                for i, batch in enumerate(self.dataloader):
                     inputs = {k: v.to(self.model.device) for k, v in batch.items()}
-                    inputs["position_ids"] = None
                     with torch.inference_mode():
-                        teacher_outputs = self.teacher_step(self.model, inputs)
-                        self._send_to_student(teacher_outputs)
+                        self._send_to_student(self.teacher_step(self.model, inputs))
 
         self._print_mem_stats()
         # Make sure all processes have finished before destroying the process group.
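The `_get_logging_context` helper above leans on a small trick: `with nullcontext() as run:` binds `run` to `None`, so non-logging ranks (and installs without wandb) flow through the same `with` block, and only the `run.log` calls need a `None` guard. A condensed sketch of the pattern (the project name is illustrative):

```python
from contextlib import nullcontext

try:
    import wandb
except ImportError:
    wandb = None

def get_logging_context(is_logging_rank):
    # Real wandb run on the logging rank, no-op context everywhere else.
    if wandb is not None and is_logging_rank:
        return wandb.init(project="eagle-distill")  # illustrative project name
    return nullcontext()

with get_logging_context(is_logging_rank=True) as run:
    if run is not None:  # run is None under nullcontext
        run.log({"loss": 0.0}, step=0)
```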
@@ -227,14 +250,15 @@ def train(self, dataloader):
 
 
 class EagleTPTrainer(BaseDistillTrainer):
-    def __init__(self, rank, args, tokenizer):
+    def __init__(self, rank, args, tokenizer, dataloader):
+        # Load eagle config
         args.eagle_config = EAGLE3_DEFAULT_CFG["config"]
         if args.eagle_config_path:
            with open(args.eagle_config_path) as f:
                custom_config = json.load(f)
            args.eagle_config["eagle_architecture_config"].update(custom_config)
 
-        super().__init__(rank, args, tokenizer)
+        super().__init__(rank, args, tokenizer, dataloader)
 
     @property
     def current_rank_device(self):
@@ -245,6 +269,7 @@ def current_rank_device(self):
 
     @property
     def distill_metadata(self) -> DistillMetadata:
+        """Description of the distillation signal received by the student."""
         return {
             "base_model_hidden_states": (
                 torch.Size(
@@ -279,12 +304,14 @@ def distill_metadata(self) -> DistillMetadata:
         }
 
     def prepare_teacher_model(self):
+        # Load the model with TP across the teacher ranks.
         model = AutoModelForCausalLM.from_pretrained(
             self.args.model_path,
             torch_dtype="auto",
             tp_plan="auto",
             device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"),
         )
+        # Load eagle config and convert.
         self.args.eagle_config["eagle_architecture_config"].update(
             {
                 "hidden_size": model.config.hidden_size,
@@ -298,7 +325,6 @@ def prepare_teacher_model(self):
         return model
 
     def prepare_student_model(self):
-        """Load student model on a single device and keep needed modules from teacher."""
         # Load to CPU first to avoid OOM
         model = AutoModelForCausalLM.from_pretrained(
             self.args.model_path, torch_dtype="auto", device_map="cpu"
@@ -331,15 +357,19 @@ def prepare_student_model(self):
         return model
 
     def teacher_step(self, model, inputs):
+        # Collect base model outputs.
         base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward(
             **inputs,
             freeze_base_model=True,
             past_key_values=None,
         )
-        # aux_hidden_states could be on multiple devices. Gather them and cat.
+
+        # aux_hidden_states could be on multiple devices. Gather before cat.
         aux_hidden_states = torch.cat(
             [t.to(base_model_logits.device) for t in model.pop_aux_hidden_states()], dim=-1
         )
+
+        # Chunk the tensors for each student rank.
         base_model_hidden_states = base_model_hidden_states.chunk(len(self.args.student_ranks))
         base_model_logits = base_model_logits.chunk(len(self.args.student_ranks))
         aux_hidden_states = aux_hidden_states.chunk(len(self.args.student_ranks))
@@ -356,28 +386,12 @@ def teacher_step(self, model, inputs):
     def student_step(
         self,
         inputs,
-        base_model_hidden_states,
-        aux_hidden_states,
-        base_model_logits,
-    ):
+        **distill_msgs,
+    ) -> ModelOutput:
         self.optimizer.zero_grad()
-        # Second stage forward using the unified model
+
+        # Chunk inputs for each student rank.
         inputs = {k: v.chunk(len(self.args.student_ranks))[self.rank] for k, v in inputs.items()}
-        output = self.model(
-            **inputs,
-            # providing base model outputs to bypass the base model forward.
-            base_model_outputs={
-                "base_model_hidden_states": base_model_hidden_states,
-                "aux_hidden_states": aux_hidden_states.clone().detach(),
-                "base_model_logits": base_model_logits.clone().detach(),
-            },
-        )
-        loss = output.loss
-        # print(f"Rank {self.rank} loss: {loss.item()}")
-        train_acc = output.train_acc
-
-        # Backward
-        loss.backward()
-        self.optimizer.step()
-        self.scheduler.step()
-        return round(loss.item(), 3), train_acc
+
+        # Second stage forward with provided base model outputs.
+        return self.model(**inputs, base_model_outputs=distill_msgs)
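Taken together, a hypothetical launch under `torchrun` would construct one trainer per process and call `train()`; teacher ranks enter the inference loop and student ranks the training loop. Everything below (file name, argument wiring, and how the tokenizer/dataloader are built) is an assumed sketch, not part of this diff:

```python
# torchrun --nproc_per_node=4 launch_distill.py   (file name illustrative)
import torch.distributed as dist

def main(args, tokenizer, dataloader):
    # args is assumed to carry model_path, out_path, lr, batch_size,
    # eagle_config_path, teacher_ranks, and student_ranks.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    trainer = EagleTPTrainer(rank, args, tokenizer, dataloader)
    trainer.train()  # saves checkpoints on student rank 0 at SAVE_INTERVAL
```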