Commit 6c62207

refactor

Signed-off-by: h-guo18 <[email protected]>
1 parent: 895ceaf

File tree

2 files changed: +181, -178 lines

examples/speculative_decoding/distill_trainer.py

Lines changed: 177 additions & 2 deletions
@@ -19,19 +19,25 @@
 
 import torch
 import torch.distributed as dist
+from torch.distributed.device_mesh import DeviceMesh
 from tqdm import tqdm
+from transformers import AutoModelForCausalLM
+from transformers.optimization import get_linear_schedule_with_warmup
 
 import modelopt.torch.opt as mto
+import modelopt.torch.speculative as mtsp
 
 mto.enable_huggingface_checkpointing()
 
 # Hyperparameters for profiling
-EPOCHS = 10
+EPOCHS = 1
 LOG_INTERVAL = 100
 SAVE_INTERVAL = 20000
+MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+DRAFT_VOCAB_SIZE = 32000
 # VALIDATE_INTERVAL = 20
 
-# We define the distill signal from teacher as the map of variable name to its shape and dtype.
+# Shape and dtype description of the distillation signal
 DistillMetadata = dict[str, tuple[torch.Size, torch.dtype]]
 
 
@@ -208,3 +214,172 @@ def train(self, dataloader):
         dist.barrier()
         # clean up processes
         dist.destroy_process_group()
+
+
+class EagleTPTrainer(BaseDistillTrainer):
+    @property
+    def current_rank_device(self):
+        if self.rank in self.args.student_ranks:
+            return self.args.student_devices[self.rank]
+        else:
+            return self.args.teacher_devices[self.rank - len(self.args.student_ranks)]
+
+    def load_teacher_model(self):
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_PATH,
+            torch_dtype="auto",
+            tp_plan="auto",
+            device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"),
+        )
+        self.args.eagle_config["eagle_architecture_config"].update(
+            {
+                "hidden_size": model.config.hidden_size,
+                "vocab_size": model.config.vocab_size,
+                "draft_vocab_size": DRAFT_VOCAB_SIZE,
+            }
+        )
+        mtsp.convert(model, [("eagle", self.args.eagle_config)])
+        model.eval()
+        self._print_model_placement(model)
+        return model
+
+    def load_student_model(self):
+        """Load student model on a single device and keep needed modules from teacher."""
+        # Load to CPU first to avoid OOM
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_PATH, torch_dtype="auto", device_map="cpu"
+        )
+        # Hidden size and vocab size must match the base model
+        self.args.eagle_config["eagle_architecture_config"].update(
+            {
+                "hidden_size": model.config.hidden_size,
+                "vocab_size": model.config.vocab_size,
+                "draft_vocab_size": DRAFT_VOCAB_SIZE,
+            }
+        )
+        mtsp.convert(
+            model,
+            [("eagle", self.args.eagle_config)],
+        )
+        if model.config.vocab_size > DRAFT_VOCAB_SIZE:
+            model_name = os.path.basename(os.path.normpath(MODEL_PATH))
+            vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt")
+            try:
+                vocab_cache = torch.load(vocab_cache_path)
+                assert len(vocab_cache) == DRAFT_VOCAB_SIZE
+                model.eagle_module.d2t = vocab_cache
+                print(f"Loaded draft vocab cache from {vocab_cache_path}.")
+            except Exception as e:
+                raise e
+
+        # TODO: copy needed modules and del the rest
+        model.model._modules.pop("layers")
+        model.to(self.current_rank_device)
+
+        model.train()
+        model = torch.nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[self.current_rank_device],
+            process_group=self.args.student_pgroup,
+            find_unused_parameters=True,
+        )
+        optimizer = torch.optim.AdamW(model.parameters(), lr=self.args.lr)
+        scheduler = get_linear_schedule_with_warmup(
+            optimizer, num_warmup_steps=0, num_training_steps=117380
+        )
+        self._print_model_placement(model)
+        return model, optimizer, scheduler
+
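For context on the `d2t` cache that `load_student_model` loads: with a reduced draft vocabulary, `d2t` maps draft-token ids back into the base model's vocabulary. The following is a hedged sketch of one way such an index could be built, assuming the draft vocabulary keeps the most frequent target tokens; the construction and sizes are assumptions, not from this commit.

# Hypothetical sketch: build a draft-to-target (d2t) index for a reduced
# draft vocabulary made of the most frequent target tokens.
import torch

TARGET_VOCAB = 32000
DRAFT_VOCAB = 8000  # illustrative; this commit sets DRAFT_VOCAB_SIZE = 32000

# Suppose token_counts[i] is how often target token i appears in training data.
token_counts = torch.randint(0, 1000, (TARGET_VOCAB,))
d2t = torch.topk(token_counts, DRAFT_VOCAB).indices  # draft id -> target id

# Mapping a draft-vocab prediction back to a target-vocab token id:
draft_token = torch.tensor(5)
target_token = d2t[draft_token]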
+    def teacher_step(self, model, inputs):
+        base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward(
+            **inputs,
+            freeze_base_model=True,
+            past_key_values=None,
+        )
+        # aux_hidden_states could be on multiple devices; gather and concatenate them.
+        aux_hidden_states = torch.cat(
+            [t.to(base_model_logits.device) for t in model.pop_aux_hidden_states()], dim=-1
+        )
+        base_model_hidden_states = base_model_hidden_states.chunk(len(self.args.student_ranks))
+        base_model_logits = base_model_logits.chunk(len(self.args.student_ranks))
+        aux_hidden_states = aux_hidden_states.chunk(len(self.args.student_ranks))
+
+        return [
+            {
+                "base_model_hidden_states": base_model_hidden_states[i],
+                "aux_hidden_states": aux_hidden_states[i],
+                "base_model_logits": base_model_logits[i],
+            }
+            for i in range(len(self.args.student_ranks))
+        ]
+
+    def student_step(
+        self,
+        inputs,
+        base_model_hidden_states,
+        aux_hidden_states,
+        base_model_logits,
+    ):
+        self.optimizer.zero_grad()
+        # Second-stage forward using the unified model
+        inputs = {k: v.chunk(len(self.args.student_ranks))[self.rank] for k, v in inputs.items()}
+        output = self.model(
+            **inputs,
+            # Providing base model outputs bypasses the base model forward.
+            base_model_outputs={
+                "base_model_hidden_states": base_model_hidden_states,
+                "aux_hidden_states": aux_hidden_states.clone().detach(),
+                "base_model_logits": base_model_logits.clone().detach(),
+            },
+        )
+        loss = output.loss
+        # print(f"Rank {self.rank} loss: {loss.item()}")
+        train_acc = output.train_acc
+
+        # Backward
+        loss.backward()
+        self.optimizer.step()
+        self.scheduler.step()
+        return round(loss.item(), 3), train_acc
+
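The `teacher_step`/`student_step` pair above relies on a chunk-and-scatter pattern: each teacher output tensor is split along the batch dimension into one slice per student rank. A self-contained sketch of that pattern with illustrative shapes:

import torch

num_student_ranks = 2  # illustrative
logits = torch.randn(4, 128, 32000)  # (batch, seq, vocab) from the teacher

# One payload per student rank, split along the batch dimension (dim 0).
slices = logits.chunk(num_student_ranks)
payloads = [{"base_model_logits": s} for s in slices]
assert payloads[0]["base_model_logits"].shape[0] == 2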
+
+# class EagleMPTrainer(EagleTPTrainer, BaseDistillTrainer):
+#     @property
+#     def current_rank_devices(self):
+#         if self.rank == self.args.student_rank:
+#             return [self.args.student_device]
+#         else:
+#             return self.args.teacher_devices
+
+#     def load_teacher_model(self):
+#         model = AutoModelForCausalLM.from_pretrained(
+#             MODEL_PATH,
+#             torch_dtype="auto",
+#             device_map="sequential",
+#             max_memory=dict.fromkeys(
+#                 self.args.teacher_devices, "999GiB"
+#             ),  # To use only given devices
+#         )
+#         self.args.eagle_config["eagle_architecture_config"].update(
+#             {
+#                 "hidden_size": model.config.hidden_size,
+#                 "vocab_size": model.config.vocab_size,
+#                 "draft_vocab_size": DRAFT_VOCAB_SIZE,
+#             }
+#         )
+#         mtsp.convert(model, [("eagle", self.args.eagle_config)])
+
+#         if model.config.vocab_size > DRAFT_VOCAB_SIZE:
+#             model_name = os.path.basename(os.path.normpath(MODEL_PATH))
+#             vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt")
+#             try:
+#                 vocab_cache = torch.load(vocab_cache_path)
+#                 assert len(vocab_cache) == DRAFT_VOCAB_SIZE
+#                 model.eagle_module.d2t = vocab_cache
+#                 print(f"Loaded draft vocab cache from {vocab_cache_path}.")
+#             except Exception as e:
+#                 raise e
+
+#         model.eval()
+#         self._print_model_placement(model)
+#         return model
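For reference, `current_rank_device` in `EagleTPTrainer` assumes a flat rank layout: all student ranks first, then the teacher ranks. A minimal standalone sketch of that mapping; the rank and device values are hypothetical:

# Hypothetical layout: ranks 0-1 are students, ranks 2-3 serve the TP teacher.
student_ranks = [0, 1]
student_devices = ["cuda:0", "cuda:1"]
teacher_devices = ["cuda:2", "cuda:3"]

def current_rank_device(rank: int) -> str:
    if rank in student_ranks:
        return student_devices[rank]
    return teacher_devices[rank - len(student_ranks)]

assert current_rank_device(3) == "cuda:3"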
