@@ -66,13 +66,13 @@ def __init__(self, rank, args, tokenizer, dataloader):

         # Prepare models
         if rank in args.student_ranks:
-            self.model = self.prepare_student_model()
+            self.model = self._prepare_student_model()
             self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr)
             self.scheduler = get_linear_schedule_with_warmup(
                 self.optimizer, num_warmup_steps=0, num_training_steps=117380
             )
         else:
-            self.model = self.prepare_teacher_model()
+            self.model = self._prepare_teacher_model()
         self._print_model_placement(self.model)

     def _print_model_placement(self, module):
@@ -95,11 +95,11 @@ def distill_metadata(self):
         """Return a DistillMetadata that describes the distillation message received by the student."""

     @abstractmethod
-    def prepare_teacher_model(self):
+    def _prepare_teacher_model(self):
         """Return the converted teacher model with correct parallelization."""

     @abstractmethod
-    def prepare_student_model(self):
+    def _prepare_student_model(self):
         """Return the converted student model with correct parallelization."""

     @abstractmethod
@@ -272,43 +272,7 @@ def current_rank_device(self):
         else:
             return self.args.teacher_devices[self.rank - len(self.args.student_ranks)]

-    @property
-    def distill_metadata(self) -> DistillMetadata:
-        """Description of the distillation signal received by the student."""
-        return {
-            "base_model_hidden_states": (
-                torch.Size(
-                    [
-                        int(self.args.batch_size / len(self.args.student_ranks)),
-                        self.args.training_seq_len,
-                        2048,
-                    ]
-                ),
-                torch.bfloat16,
-            ),
-            "aux_hidden_states": (
-                torch.Size(
-                    [
-                        int(self.args.batch_size / len(self.args.student_ranks)),
-                        self.args.training_seq_len,
-                        2048 * 3,
-                    ]
-                ),
-                torch.bfloat16,
-            ),
-            "base_model_logits": (
-                torch.Size(
-                    [
-                        int(self.args.batch_size / len(self.args.student_ranks)),
-                        self.args.training_seq_len,
-                        self.args.draft_vocab_size,
-                    ]
-                ),
-                torch.bfloat16,
-            ),
-        }
-
-    def prepare_teacher_model(self):
+    def _prepare_teacher_model(self):
         # Load model with TP among teacher ranks.
         model = AutoModelForCausalLM.from_pretrained(
             self.args.model_path,
@@ -324,12 +288,11 @@ def prepare_teacher_model(self):
                 "draft_vocab_size": model.config.vocab_size,
             }
         )
-        self.args.draft_vocab_size = model.config.vocab_size
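+        # mtsp.convert mutates `model` in place (its return value is ignored
+        # below), attaching the EAGLE draft module to the base model.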
         mtsp.convert(model, [("eagle", self.args.eagle_config)])
         model.eval()
         return model

-    def prepare_student_model(self):
+    def _prepare_student_model(self):
         # Load to CPU first to avoid OOM
         model = AutoModelForCausalLM.from_pretrained(
             self.args.model_path, torch_dtype="auto", device_map="cpu"
@@ -342,7 +305,6 @@ def prepare_student_model(self):
                 "draft_vocab_size": model.config.vocab_size,
             }
         )
-        self.args.draft_vocab_size = model.config.vocab_size
         mtsp.convert(
             model,
             [("eagle", self.args.eagle_config)],
@@ -361,6 +323,42 @@ def prepare_student_model(self):
         )
         return model

+    @property
+    def distill_metadata(self) -> DistillMetadata:
+        """Description of the distillation signal received by the student."""
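+        # Each entry maps a tensor name to its (shape, dtype); the leading dim
+        # is this student rank's even share of the global batch.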
+        return {
+            "base_model_hidden_states": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        self.args.eagle_config["eagle_architecture_config"]["hidden_size"],
+                    ]
+                ),
+                torch.bfloat16,
+            ),
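+            # The factor of 3 below assumes the aux hidden states are
+            # concatenated from three intermediate layers, as in EAGLE-3.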
+            "aux_hidden_states": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        self.args.eagle_config["eagle_architecture_config"]["hidden_size"] * 3,
+                    ]
+                ),
+                torch.bfloat16,
+            ),
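+            # draft_vocab_size is now read from the EAGLE architecture config
+            # rather than from a mutated self.args field (see the removed
+            # assignments above).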
+            "base_model_logits": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        self.args.eagle_config["eagle_architecture_config"]["draft_vocab_size"],
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+        }
+
     def teacher_step(self, model, inputs):
         # Collect base model outputs.
         base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward(