
Commit 1a63254

committed
clean up codes
Signed-off-by: h-guo18 <[email protected]>
1 parent 5ae4479 commit 1a63254

File tree

3 files changed: +111 −95 lines changed

examples/speculative_decoding/distill_trainer.py

Lines changed: 97 additions & 83 deletions
@@ -17,18 +17,26 @@
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 from abc import abstractmethod
+from contextlib import nullcontext
 
 import torch
 import torch.distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
 from transformers.optimization import get_linear_schedule_with_warmup
+from transformers.utils import ModelOutput
 
 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
 from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
 
+try:
+    import wandb
+except ImportError:
+    wandb = None
+
+
 mto.enable_huggingface_checkpointing()
 
 # Hyperparameters for profiling
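
Note: the hunk above makes wandb an optional dependency by guarding the import and falling back to wandb = None. A minimal standalone sketch of the same pattern (the get_run_context helper and the project name are illustrative, not part of the repo):

    from contextlib import nullcontext

    try:
        import wandb
    except ImportError:
        wandb = None

    def get_run_context(enabled: bool):
        # Return a wandb run when available, otherwise a no-op context manager.
        if wandb is not None and enabled:
            return wandb.init(project="demo")  # hypothetical project name
        return nullcontext()
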
@@ -51,12 +59,13 @@ class BaseDistillTrainer:
         student_step: student step function.
     """
 
-    def __init__(self, rank, args, tokenizer):
+    def __init__(self, rank, args, tokenizer, dataloader):
         self.rank = rank
         args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
         args.student_pgroup = dist.new_group(ranks=args.student_ranks)
         self.args = args
         self.tokenizer = tokenizer
+        self.dataloader = dataloader
         if rank in args.student_ranks:
             self.model = self.prepare_student_model()
             self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr)
@@ -71,46 +80,49 @@ def _print_model_placement(self, module):
         for name, param in module.named_parameters():
             print(f"(Rank {self.rank}) {name} ---> {param.device} ")
 
-    @property
-    def current_rank_device(self):
-        pass
-
-    @property
-    def distill_metadata(self):
-        pass
-
     def _reset_all_mem_stats(self):
         torch.cuda.reset_max_memory_allocated(self.current_rank_device)
 
     def _print_mem_stats(self):
         max_mem = torch.cuda.max_memory_allocated(self.current_rank_device)
         print(f"GPU {self.current_rank_device}: Max memory allocated: {max_mem / 1024**3:.2f} GB")
 
+    @property
+    def current_rank_device(self):
+        """Return the device of the current rank."""
+
+    @property
+    def distill_metadata(self):
+        """Return a DistillMetadata that describes the distillation message received by the student."""
+
     @abstractmethod
-    def load_teacher_model(self):
-        pass
+    def prepare_teacher_model(self):
+        """Return the converted teacher model with correct parallelization."""
 
     @abstractmethod
-    def load_student_model(self):
-        pass
+    def prepare_student_model(self):
+        """Return the converted student model with correct parallelization."""
 
     @abstractmethod
-    def teacher_step(self, *args, **kwargs) -> dict[str, torch.Tensor]:
-        pass
+    def teacher_step(self, *args, **kwargs) -> list[dict[str, torch.Tensor]]:
+        """Run one teacher step and return distillation messages for each student rank."""
 
     @abstractmethod
-    def student_step(self, *args, **kwargs):
-        pass
+    def student_step(self, *args, **kwargs) -> ModelOutput:
+        """Run the forward pass of one student step and return a ModelOutput object."""
 
-    def save_pretrained(self, path=None):
+    def save_pretrained(self, save_path):
+        """Save the model and tokenizer."""
         if self.rank == self.args.student_ranks[0]:
-            path = self.args.out_path if path is None else path
-            self.model.save_pretrained(path)
-            self.tokenizer.save_pretrained(path)
-            print(f"Pretrained model saved to {path}")
+            if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
+                self.model.module.save_pretrained(save_path)
+            else:
+                self.model.save_pretrained(save_path)
+            self.tokenizer.save_pretrained(save_path)
+            print(f"Pretrained model saved to {save_path}")
 
     def _check_valid_message(self, message: dict[str, torch.Tensor]):
-        # Check if keys and length match between message and distill_metadata
+        """Check that the message matches the format of distill_metadata."""
         if set(message.keys()) != set(self.distill_metadata.keys()):
            raise ValueError(
                f"Message keys: {set(message.keys())} \n"
@@ -142,8 +154,8 @@ def _recv_from_teacher(self):
         for req in reqs:
             req.wait()
 
-    def _get_distill_kwargs(self):
-        """Return a copy of received buffer for student training."""
+    def _clone_recv_buffer(self):
+        """Return a copy of received tensors for student step input."""
         return {k: v.clone().detach() for k, v in self.student_recv_buffer.items()}
 
     def _send_to_student(self, teacher_outputs):
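
Note: _clone_recv_buffer copies tensors out of the persistent receive buffer so the next in-place receive cannot overwrite the student's training inputs. A rough standalone sketch of the idea (names and sizes are made up):

    import torch

    recv_buffer = {"logits": torch.empty(4, 8)}  # preallocated; refilled in place by later receives

    def clone_recv_buffer(buffer):
        # Detach and clone so subsequent in-place receives do not corrupt training inputs.
        return {k: v.clone().detach() for k, v in buffer.items()}

    student_inputs = clone_recv_buffer(recv_buffer)
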
@@ -160,49 +172,63 @@ def _send_to_student(self, teacher_outputs):
         for req in reqs:
             req.wait()
 
-    def train(self, dataloader):
+    def _get_logging_context(self):
+        print(
+            f"Rank {self.rank} is logging: {wandb is not None and self.rank == self.args.student_ranks[0]}"
+        )
+        if wandb is not None and self.rank == self.args.student_ranks[0]:
+            return wandb.init(
+                entity=os.environ["WANDB_ENTITY"],
+                project=os.environ["WANDB_PROJECT"],
+                config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size},
+            )
+        return nullcontext()
+
+    def train(self):
         """Main training entrance of the composed model."""
         self._reset_all_mem_stats()
 
         if self.rank in self.args.student_ranks:
-            import wandb
-
-            wandb.login()
-
-            with wandb.init(
-                entity=os.environ["WANDB_ENTITY"],
-                project=os.environ["WANDB_PROJECT"],
-                config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size},
-            ) as run:
+            with self._get_logging_context() as run:
                 self._init_student_recv_buffer()
-                wandb.watch(self.model, log="all")
 
+                # Student training loop
                 for epoch in range(EPOCHS):
                     pbar = (
-                        tqdm(dataloader) if self.rank == self.args.student_ranks[0] else dataloader
+                        tqdm(self.dataloader)
+                        if self.rank == self.args.student_ranks[0]
+                        else self.dataloader
                     )
                     for i, batch in enumerate(pbar):
-                        global_step = epoch * len(dataloader) + i
+                        global_step = epoch * len(self.dataloader) + i
                         inputs = {k: v.to(self.model.device) for k, v in batch.items()}
+
+                        # Receive distill messages from teacher
                         self._recv_from_teacher()
-                        loss, train_acc = self.student_step(inputs, **self._get_distill_kwargs())
 
+                        # Run forward of student step
+                        output = self.student_step(inputs, **self._clone_recv_buffer())
+                        loss = output.loss
+
+                        # Run backward step
+                        loss.backward()
+                        self.optimizer.step()
+                        self.scheduler.step()
+
+                        # Log and save only on student rank 0
                         if self.rank != self.args.student_ranks[0]:
                             continue
 
-                        pbar.set_description(f"Epoch {epoch} Loss:{loss} Acc:{train_acc}")
+                        train_metrics = {
+                            "loss": round(loss.item(), 3),
+                            "lr": self.optimizer.param_groups[0]["lr"],
+                            # Attach all float metrics
+                            **{k: round(v, 3) for k, v in output.items() if isinstance(v, float)},
+                        }
+
+                        pbar.set_description(f"Epoch {epoch} Loss {train_metrics['loss']}")
                         if global_step % LOG_INTERVAL == 0:
-                            run.log(
-                                {
-                                    "loss": loss,
-                                    "train_acc_step0": train_acc[0],
-                                    "train_acc_step1": train_acc[1],
-                                    "train_acc_step2": train_acc[2],
-                                    "train_acc_step3": train_acc[3],
-                                    "lr": self.optimizer.param_groups[0]["lr"],
-                                },
-                                step=global_step,
-                            )
+                            run.log(train_metrics, step=global_step)
                         if global_step > 0 and global_step % SAVE_INTERVAL == 0:
                             self.save_pretrained(
                                 f"{self.args.out_path}/epoch_{epoch}_step_{global_step}"
@@ -211,13 +237,10 @@ def train(self, dataloader):
         else:
             # Inference Loop
             for epoch in range(EPOCHS):
-                for i, batch in enumerate(dataloader):
-                    global_step = epoch * len(dataloader) + i
+                for i, batch in enumerate(self.dataloader):
                     inputs = {k: v.to(self.model.device) for k, v in batch.items()}
-                    inputs["position_ids"] = None
                     with torch.inference_mode():
-                        teacher_outputs = self.teacher_step(self.model, inputs)
-                        self._send_to_student(teacher_outputs)
+                        self._send_to_student(self.teacher_step(self.model, inputs))
 
         self._print_mem_stats()
         # Make sure all processes finished before destroy.
@@ -227,14 +250,15 @@ def train(self, dataloader):
 
 
 class EagleTPTrainer(BaseDistillTrainer):
-    def __init__(self, rank, args, tokenizer):
+    def __init__(self, rank, args, tokenizer, dataloader):
+        # Load eagle config
         args.eagle_config = EAGLE3_DEFAULT_CFG["config"]
         if args.eagle_config_path:
             with open(args.eagle_config_path) as f:
                 custom_config = json.load(f)
             args.eagle_config["eagle_architecture_config"].update(custom_config)
 
-        super().__init__(rank, args, tokenizer)
+        super().__init__(rank, args, tokenizer, dataloader)
 
     @property
     def current_rank_device(self):
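
Note: the constructor starts from EAGLE3_DEFAULT_CFG and overlays any user-supplied JSON onto the default architecture config. The overlay pattern, sketched with a made-up config file:

    import json

    default_cfg = {"eagle_architecture_config": {"num_layers": 1, "hidden_size": 4096}}  # illustrative values

    with open("eagle_config.json") as f:   # hypothetical override file
        custom = json.load(f)              # e.g. {"num_layers": 2}
    default_cfg["eagle_architecture_config"].update(custom)
    # default_cfg now carries num_layers=2 while keeping the remaining defaults
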
@@ -245,6 +269,7 @@ def current_rank_device(self):
 
     @property
     def distill_metadata(self) -> DistillMetadata:
+        """Description of the distillation signal received by student."""
         return {
             "base_model_hidden_states": (
                 torch.Size(
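
Note: judging from this hunk, distill_metadata maps each message tensor name to its expected shape and dtype so that receive buffers can be preallocated. A hypothetical entry with illustrative sizes (not the actual shapes):

    import torch

    # Sketch: tensor name -> (expected shape, dtype) used to preallocate recv buffers.
    distill_metadata = {
        "base_model_hidden_states": (torch.Size([1, 512, 4096]), torch.bfloat16),
        "base_model_logits": (torch.Size([1, 512, 32000]), torch.bfloat16),
    }
    recv_buffer = {k: torch.empty(shape, dtype=dtype) for k, (shape, dtype) in distill_metadata.items()}
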
@@ -279,12 +304,14 @@ def distill_metadata(self) -> DistillMetadata:
         }
 
     def prepare_teacher_model(self):
+        # Load model with TP among teacher ranks.
         model = AutoModelForCausalLM.from_pretrained(
             self.args.model_path,
             torch_dtype="auto",
             tp_plan="auto",
             device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"),
         )
+        # load eagle config and convert.
         self.args.eagle_config["eagle_architecture_config"].update(
             {
                 "hidden_size": model.config.hidden_size,
@@ -298,7 +325,6 @@ def prepare_teacher_model(self):
         return model
 
     def prepare_student_model(self):
-        """Load student model on a single device and keep needed modules from teacher."""
         # Load to CPU first to avoid OOM
         model = AutoModelForCausalLM.from_pretrained(
             self.args.model_path, torch_dtype="auto", device_map="cpu"
@@ -331,15 +357,19 @@ def prepare_student_model(self):
         return model
 
     def teacher_step(self, model, inputs):
+        # Collect base model outputs.
         base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward(
             **inputs,
             freeze_base_model=True,
             past_key_values=None,
         )
-        # aux_hidden_states could be on multiple devices. Gather them and cat.
+
+        # Aux_hidden_states could be on multiple devices. Gather before cat.
         aux_hidden_states = torch.cat(
             [t.to(base_model_logits.device) for t in model.pop_aux_hidden_states()], dim=-1
         )
+
+        # Chunk the tensors for each student rank.
         base_model_hidden_states = base_model_hidden_states.chunk(len(self.args.student_ranks))
         base_model_logits = base_model_logits.chunk(len(self.args.student_ranks))
         aux_hidden_states = aux_hidden_states.chunk(len(self.args.student_ranks))
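
Note: teacher_step splits each output tensor into one chunk per student rank (torch.Tensor.chunk splits along dim 0 by default). A minimal sketch of the split, with illustrative ranks and shapes:

    import torch

    student_ranks = [2, 3]                      # example ranks
    hidden = torch.randn(4, 16, 8)              # (batch, seq, hidden); illustrative sizes
    chunks = hidden.chunk(len(student_ranks))   # two (2, 16, 8) tensors along the batch dim
    for idx, rank in enumerate(student_ranks):
        print(rank, chunks[idx].shape)          # each student rank receives one chunk
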
@@ -356,28 +386,12 @@ def teacher_step(self, model, inputs):
     def student_step(
         self,
         inputs,
-        base_model_hidden_states,
-        aux_hidden_states,
-        base_model_logits,
-    ):
+        **distill_msgs,
+    ) -> ModelOutput:
         self.optimizer.zero_grad()
-        # Second stage forward using the unified model
+
+        # Chunk inputs for each student rank.
         inputs = {k: v.chunk(len(self.args.student_ranks))[self.rank] for k, v in inputs.items()}
-        output = self.model(
-            **inputs,
-            # providing base model outputs to bypass the base model forward.
-            base_model_outputs={
-                "base_model_hidden_states": base_model_hidden_states,
-                "aux_hidden_states": aux_hidden_states.clone().detach(),
-                "base_model_logits": base_model_logits.clone().detach(),
-            },
-        )
-        loss = output.loss
-        # print(f"Rank {self.rank} loss: {loss.item()}")
-        train_acc = output.train_acc
-
-        # Backward
-        loss.backward()
-        self.optimizer.step()
-        self.scheduler.step()
-        return round(loss.item(), 3), train_acc
+
+        # Second stage forward with provided base model outputs.
+        return self.model(**inputs, base_model_outputs=distill_msgs)

examples/speculative_decoding/train.py

Lines changed: 6 additions & 4 deletions
@@ -60,9 +60,9 @@ def train(rank, args):
         drop_last=True,
     )
 
-    trainer = EagleTPTrainer(rank, args, tokenizer)
-    trainer.train(train_dataloader)
-    # trainer.save_pretrained("ckpts/fast-trained")
+    trainer = EagleTPTrainer(rank, args, tokenizer, train_dataloader)
+    trainer.train()
+    trainer.save_pretrained(args.out_path)
 
 
 def main():
@@ -104,7 +104,9 @@ def main():
         "--out_path", type=str, default="ckpts/fast-trained", help="Path to save the model."
     )
     parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
-    parser.add_argument("--batch_size", type=int, default=4, help="Batch size.")
+    parser.add_argument(
+        "--batch_size", type=int, default=4, help="Total batch size across all parallel ranks."
+    )
     parser.add_argument("--master_port", type=str, default="12357", help="Master port.")
 
     args = parser.parse_args()
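
Note: --batch_size is now documented as the total across all parallel ranks, so each rank would work on an even share of it; a hypothetical helper for illustration only:

    def per_rank_batch_size(total_batch_size: int, num_ranks: int) -> int:
        # Hypothetical helper: split the global batch evenly across parallel ranks.
        assert total_batch_size % num_ranks == 0, "total batch size must divide evenly"
        return total_batch_size // num_ranks

    print(per_rank_batch_size(4, 2))  # -> 2
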
