
Commit 8d6a49b

polish

Signed-off-by: h-guo18 <[email protected]>
1 parent: 35cc9a8

File tree

1 file changed (+42 −44 lines)


examples/speculative_decoding/distill_trainer.py

Lines changed: 42 additions & 44 deletions
@@ -66,13 +66,13 @@ def __init__(self, rank, args, tokenizer, dataloader):
 
         # Prepare models
         if rank in args.student_ranks:
-            self.model = self.prepare_student_model()
+            self.model = self._prepare_student_model()
             self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr)
             self.scheduler = get_linear_schedule_with_warmup(
                 self.optimizer, num_warmup_steps=0, num_training_steps=117380
             )
         else:
-            self.model = self.prepare_teacher_model()
+            self.model = self._prepare_teacher_model()
         self._print_model_placement(self.model)
 
     def _print_model_placement(self, module):
@@ -95,11 +95,11 @@ def distill_metadata(self):
         """Return a DistillMetadata that describes the distillation message received by the student."""
 
     @abstractmethod
-    def prepare_teacher_model(self):
+    def _prepare_teacher_model(self):
         """Return the converted teacher model with correct parallelization."""
 
     @abstractmethod
-    def prepare_student_model(self):
+    def _prepare_student_model(self):
         """Return the converted student model with correct parallelization."""
 
     @abstractmethod
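
The renamed hooks are internal template methods: as the first hunk shows, __init__ dispatches to _prepare_student_model() or _prepare_teacher_model() depending on the rank, and the leading underscore marks them as subclass hooks rather than public API. A minimal sketch of that pattern; the class name and the trimmed __init__ are stand-ins for whatever the file actually defines around these hunks:

    # Sketch only: "BaseDistillTrainer" is a stand-in name; the real class in
    # distill_trainer.py carries more state (optimizer, scheduler, tokenizer).
    from abc import ABC, abstractmethod

    class BaseDistillTrainer(ABC):
        def __init__(self, rank, args):
            self.rank, self.args = rank, args
            if rank in args.student_ranks:
                self.model = self._prepare_student_model()  # student ranks train the draft model
            else:
                self.model = self._prepare_teacher_model()  # remaining ranks serve the base model

        @abstractmethod
        def _prepare_teacher_model(self):
            """Return the converted teacher model with correct parallelization."""

        @abstractmethod
        def _prepare_student_model(self):
            """Return the converted student model with correct parallelization."""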
@@ -272,43 +272,7 @@ def current_rank_device(self):
         else:
             return self.args.teacher_devices[self.rank - len(self.args.student_ranks)]
 
-    @property
-    def distill_metadata(self) -> DistillMetadata:
-        """Description of the distillation signal received by student."""
-        return {
-            "base_model_hidden_states": (
-                torch.Size(
-                    [
-                        int(self.args.batch_size / len(self.args.student_ranks)),
-                        self.args.training_seq_len,
-                        2048,
-                    ]
-                ),
-                torch.bfloat16,
-            ),
-            "aux_hidden_states": (
-                torch.Size(
-                    [
-                        int(self.args.batch_size / len(self.args.student_ranks)),
-                        self.args.training_seq_len,
-                        2048 * 3,
-                    ]
-                ),
-                torch.bfloat16,
-            ),
-            "base_model_logits": (
-                torch.Size(
-                    [
-                        int(self.args.batch_size / len(self.args.student_ranks)),
-                        self.args.training_seq_len,
-                        self.args.draft_vocab_size,
-                    ]
-                ),
-                torch.bfloat16,
-            ),
-        }
-
-    def prepare_teacher_model(self):
+    def _prepare_teacher_model(self):
         # Load model with TP among teacher ranks.
         model = AutoModelForCausalLM.from_pretrained(
             self.args.model_path,
@@ -324,12 +288,11 @@ def prepare_teacher_model(self):
                 "draft_vocab_size": model.config.vocab_size,
             }
         )
-        self.args.draft_vocab_size = model.config.vocab_size
         mtsp.convert(model, [("eagle", self.args.eagle_config)])
         model.eval()
         return model
 
-    def prepare_student_model(self):
+    def _prepare_student_model(self):
         # Load to CPU first to avoid OOM
         model = AutoModelForCausalLM.from_pretrained(
             self.args.model_path, torch_dtype="auto", device_map="cpu"
@@ -342,7 +305,6 @@ def prepare_student_model(self):
                 "draft_vocab_size": model.config.vocab_size,
             }
         )
-        self.args.draft_vocab_size = model.config.vocab_size
         mtsp.convert(
             model,
             [("eagle", self.args.eagle_config)],
@@ -361,6 +323,42 @@ def prepare_student_model(self):
         )
         return model
 
+    @property
+    def distill_metadata(self) -> DistillMetadata:
+        """Description of the distillation signal received by the student."""
+        return {
+            "base_model_hidden_states": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        self.args.eagle_config["eagle_architecture_config"]["hidden_size"],
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+            "aux_hidden_states": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        self.args.eagle_config["eagle_architecture_config"]["hidden_size"] * 3,
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+            "base_model_logits": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        self.args.eagle_config["eagle_architecture_config"]["draft_vocab_size"],
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+        }
+
     def teacher_step(self, model, inputs):
         # Collect base model outputs.
         base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward(
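
Each distill_metadata entry is a (torch.Size, torch.dtype) pair, which is exactly enough to preallocate one receive buffer per tensor on the student side. For example, with batch_size=8, two student ranks, training_seq_len=1024, and hidden_size=2048, base_model_hidden_states comes out to torch.Size([4, 1024, 2048]) in bfloat16. A sketch of one plausible consumer, assuming torch.distributed point-to-point receives as the transport (the file's actual wiring may differ):

    # Illustrative consumer of DistillMetadata (name -> (shape, dtype)).
    # The metadata layout is taken from the diff; dist.recv as the transport
    # is an assumption, not necessarily what distill_trainer.py does.
    import torch
    import torch.distributed as dist

    def recv_distill_signal(distill_metadata, src_rank, device):
        buffers = {
            name: torch.empty(shape, dtype=dtype, device=device)
            for name, (shape, dtype) in distill_metadata.items()
        }
        for buf in buffers.values():
            dist.recv(buf, src=src_rank)  # blocking receive into the preallocated buffer
        return buffers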
