@@ -66,13 +66,13 @@ def __init__(self, rank, args, tokenizer, dataloader):
66
66
67
67
# Prepare models
68
68
if rank in args .student_ranks :
69
- self .model = self .prepare_student_model ()
69
+ self .model = self ._prepare_student_model ()
70
70
self .optimizer = torch .optim .AdamW (self .model .parameters (), lr = self .args .lr )
71
71
self .scheduler = get_linear_schedule_with_warmup (
72
72
self .optimizer , num_warmup_steps = 0 , num_training_steps = 117380
73
73
)
74
74
else :
75
- self .model = self .prepare_teacher_model ()
75
+ self .model = self ._prepare_teacher_model ()
76
76
self ._print_model_placement (self .model )
77
77
78
78
def _print_model_placement (self , module ):
@@ -95,11 +95,11 @@ def distill_metadata(self):
95
95
"""Return a DistillMetadata that describe the distillation message received by student."""
96
96
97
97
@abstractmethod
98
- def prepare_teacher_model (self ):
98
+ def _prepare_teacher_model (self ):
99
99
"""Return coverted teacher model with correct parallelization."""
100
100
101
101
@abstractmethod
102
- def prepare_student_model (self ):
102
+ def _prepare_student_model (self ):
103
103
"""Return coverted student model with correct parallelization."""
104
104
105
105
@abstractmethod
@@ -272,43 +272,7 @@ def current_rank_device(self):
272
272
else :
273
273
return self .args .teacher_devices [self .rank - len (self .args .student_ranks )]
274
274
275
- @property
276
- def distill_metadata (self ) -> DistillMetadata :
277
- """Description of the distillation signal received by student."""
278
- return {
279
- "base_model_hidden_states" : (
280
- torch .Size (
281
- [
282
- int (self .args .batch_size / len (self .args .student_ranks )),
283
- self .args .training_seq_len ,
284
- 2048 ,
285
- ]
286
- ),
287
- torch .bfloat16 ,
288
- ),
289
- "aux_hidden_states" : (
290
- torch .Size (
291
- [
292
- int (self .args .batch_size / len (self .args .student_ranks )),
293
- self .args .training_seq_len ,
294
- 2048 * 3 ,
295
- ]
296
- ),
297
- torch .bfloat16 ,
298
- ),
299
- "base_model_logits" : (
300
- torch .Size (
301
- [
302
- int (self .args .batch_size / len (self .args .student_ranks )),
303
- self .args .training_seq_len ,
304
- self .args .draft_vocab_size ,
305
- ]
306
- ),
307
- torch .bfloat16 ,
308
- ),
309
- }
310
-
311
- def prepare_teacher_model (self ):
275
+ def _prepare_teacher_model (self ):
312
276
# Load model with TP among teacher ranks.
313
277
model = AutoModelForCausalLM .from_pretrained (
314
278
self .args .model_path ,
@@ -324,12 +288,11 @@ def prepare_teacher_model(self):
324
288
"draft_vocab_size" : model .config .vocab_size ,
325
289
}
326
290
)
327
- self .args .draft_vocab_size = model .config .vocab_size
328
291
mtsp .convert (model , [("eagle" , self .args .eagle_config )])
329
292
model .eval ()
330
293
return model
331
294
332
- def prepare_student_model (self ):
295
+ def _prepare_student_model (self ):
333
296
# Load to CPU first to avoid OOM
334
297
model = AutoModelForCausalLM .from_pretrained (
335
298
self .args .model_path , torch_dtype = "auto" , device_map = "cpu"
@@ -342,7 +305,6 @@ def prepare_student_model(self):
342
305
"draft_vocab_size" : model .config .vocab_size ,
343
306
}
344
307
)
345
- self .args .draft_vocab_size = model .config .vocab_size
346
308
mtsp .convert (
347
309
model ,
348
310
[("eagle" , self .args .eagle_config )],
@@ -361,6 +323,42 @@ def prepare_student_model(self):
361
323
)
362
324
return model
363
325
326
+ @property
327
+ def distill_metadata (self ) -> DistillMetadata :
328
+ """Description of the distillation signal received by student."""
329
+ return {
330
+ "base_model_hidden_states" : (
331
+ torch .Size (
332
+ [
333
+ int (self .args .batch_size / len (self .args .student_ranks )),
334
+ self .args .training_seq_len ,
335
+ self .args .eagle_config ["eagle_architecture_config" ]["hidden_size" ],
336
+ ]
337
+ ),
338
+ torch .bfloat16 ,
339
+ ),
340
+ "aux_hidden_states" : (
341
+ torch .Size (
342
+ [
343
+ int (self .args .batch_size / len (self .args .student_ranks )),
344
+ self .args .training_seq_len ,
345
+ self .args .eagle_config ["eagle_architecture_config" ]["hidden_size" ] * 3 ,
346
+ ]
347
+ ),
348
+ torch .bfloat16 ,
349
+ ),
350
+ "base_model_logits" : (
351
+ torch .Size (
352
+ [
353
+ int (self .args .batch_size / len (self .args .student_ranks )),
354
+ self .args .training_seq_len ,
355
+ self .args .eagle_config ["eagle_architecture_config" ]["draft_vocab_size" ],
356
+ ]
357
+ ),
358
+ torch .bfloat16 ,
359
+ ),
360
+ }
361
+
364
362
def teacher_step (self , model , inputs ):
365
363
# Collect base model outputs.
366
364
base_model_hidden_states , base_model_logits , _ , _ = model ._base_model_forward (
0 commit comments