
Commit 6ec1c4b

add new trainer
Signed-off-by: h-guo18 <[email protected]>
1 parent 2e822c6 commit 6ec1c4b

File tree

2 files changed: +528 -0 lines changed

Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
from abc import abstractmethod

import torch
import torch.distributed as dist
from tqdm import tqdm

import modelopt.torch.opt as mto

mto.enable_huggingface_checkpointing()

# Hyperparameters for profiling
EPOCHS = 20
LOG_INTERVAL = 25
SAVE_INTERVAL = 20000
# VALIDATE_INTERVAL = 20

# The distill signal from the teacher is defined as a map from variable name to its shape and dtype.
DistillMetadata = dict[str, tuple[torch.Size, torch.dtype]]
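# Example (illustrative only, not part of this commit): a teacher that sends per-token
# hidden states and logits for a batch of 2 sequences of length 128 might be described as
#   {
#       "hidden_states": (torch.Size([2, 128, 4096]), torch.bfloat16),
#       "logits": (torch.Size([2, 128, 32000]), torch.bfloat16),
#   }
# where the hidden size (4096) and vocabulary size (32000) are placeholder values.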


class BaseDistillTrainer:
    """Base class for distillation trainers. Initialized and called on every rank.

    Args:
        rank: rank of the current process.
        args: training arguments (reads teacher_ranks, student_rank, student_device,
            out_path, lr, and batch_size).
        tokenizer: tokenizer to save alongside checkpoints.
        distill_metadata: expected shape and dtype of each tensor sent by the teacher.
    """

    def __init__(self, rank, args, tokenizer, distill_metadata: DistillMetadata):
        self.rank = rank
        args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
        self.args = args
        self.tokenizer = tokenizer
        self.distill_metadata = distill_metadata

    def _print_model_placement(self, module):
        for name, param in module.named_parameters():
            print(f"(Rank {self.rank}) {name} ---> {param.device} ")

    @property
    def current_rank_devices(self):
        """Devices used by the current rank; subclasses are expected to override this."""
        pass

    def _reset_all_mem_stats(self):
        for d in self.current_rank_devices:
            torch.cuda.reset_max_memory_allocated(d)

    def _print_mem_stats(self):
        for d in self.current_rank_devices:
            max_mem = torch.cuda.max_memory_allocated(d)
            print(f"GPU {d}: Max memory allocated: {max_mem / 1024**3:.2f} GB")

    @abstractmethod
    def load_teacher_model(self):
        pass

    @abstractmethod
    def load_student_model(self):
        pass

    @abstractmethod
    def teacher_step(self, *args, **kwargs) -> dict[str, torch.Tensor]:
        pass

    @abstractmethod
    def student_step(self, *args, **kwargs):
        pass

    def save_pretrained(self, path=None):
        if self.rank == self.args.student_rank:
            path = self.args.out_path if path is None else path
            self.model.save_pretrained(path)
            self.tokenizer.save_pretrained(path)
            print(f"Pretrained model saved to {path}")

    def _check_valid_message(self, message: dict[str, torch.Tensor]):
        # Check if keys and length match between message and distill_metadata
        if set(message.keys()) != set(self.distill_metadata.keys()):
            raise ValueError(
                f"Message keys from teacher: {set(message.keys())} \n"
                f"do not match expected keys {set(self.distill_metadata.keys())}"
            )
        if len(message) != len(self.distill_metadata):
            raise ValueError(
                f"Message length from teacher: {len(message)} \n"
                f"does not match expected {len(self.distill_metadata)}"
            )
        for k, v in message.items():
            if v.shape != self.distill_metadata[k][0] or v.dtype != self.distill_metadata[k][1]:
                raise ValueError(
                    f"Invalid message from teacher. {k} has shape {v.shape} and dtype {v.dtype}, \n"
                    f"expected {self.distill_metadata[k]}"
                )

    def _init_student_recv_buffer(self):
        self.student_recv_buffer = {
            k: torch.empty(v[0], device=self.args.student_device, dtype=v[1])
            for k, v in self.distill_metadata.items()
        }

    def _recv_from_teacher(self):
        reqs = [
            dist.irecv(buffer, src=self.args.teacher_ranks[0])
            for buffer in self.student_recv_buffer.values()
        ]
        for req in reqs:
            req.wait()

    def _get_distill_kwargs(self):
        return {k: v.clone().detach() for k, v in self.student_recv_buffer.items()}

    def _send_to_student(self, teacher_outputs):
        if self.rank != self.args.teacher_ranks[0]:
            return
        self._check_valid_message(teacher_outputs)
        reqs = [
            dist.isend(buffer, dst=self.args.student_rank) for buffer in teacher_outputs.values()
        ]
        for req in reqs:
            req.wait()

    # def _validate_ar(self, steps=3, osl=20, num_samples=20):
    #     if self.rank != self.args.student_rank:
    #         return
    #     # Load MT-Bench prompts from HuggingFace
    #     ds = load_dataset("HuggingFaceH4/mt_bench_prompts")["train"]
    #     self.model.eval()
    #     self.model.to(self.args.student_device)
    #     ars = validate_ar(
    #         self.model, self.tokenizer, ds, steps, osl, num_samples, self.args.student_device
    #     )
    #     # Print results
    #     avg_ar = sum(ars) / len(ars)
    #     print("\n==== AR Validation Results on MT-Bench ====")
    #     print(f"Number of samples: {len(ars)}")
    #     print(f"Output Sequence Length: {osl}")
    #     print(f"Steps: {steps}")
    #     print(f"Average AR: {avg_ar:.4f}")
    #     self.model.train()

    def train(self, dataloader):
        """Main training entrance of the composed model."""
        self._reset_all_mem_stats()

        if self.rank == self.args.student_rank:
            import wandb

            wandb.login()

            with wandb.init(
                entity=os.environ["WANDB_ENTITY"],
                project=os.environ["WANDB_PROJECT"],
                config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size},
            ) as run:
                self.model, self.optimizer = self.load_student_model()
                self._init_student_recv_buffer()
                wandb.watch(self.model, log="all")

                for epoch in range(EPOCHS):
                    pbar = tqdm(dataloader)
                    for i, batch in enumerate(pbar):
                        global_step = epoch * len(dataloader) + i
                        inputs = {k: v.to(self.model.device) for k, v in batch.items()}
                        self._recv_from_teacher()
                        loss, train_acc = self.student_step(inputs, **self._get_distill_kwargs())
                        pbar.set_description(f"Epoch {epoch} Loss:{loss} Acc:{train_acc}")

                        if global_step % LOG_INTERVAL == 0:
                            run.log(
                                {
                                    "loss": loss,
                                    "train_acc_step0": train_acc[0],
                                    "train_acc_step1": train_acc[1],
                                    "train_acc_step2": train_acc[2],
                                    "train_acc_step3": train_acc[3],
                                },
                                step=global_step,
                            )

                        # This is not working for some reason.
                        # if global_step > 0 and global_step % VALIDATE_INTERVAL == 0:
                        #     self._validate_ar()

                        if global_step > 0 and global_step % SAVE_INTERVAL == 0:
                            self.save_pretrained(
                                f"{self.args.out_path}/epoch_{epoch}_step_{global_step}"
                            )

        else:
            self.model = self.load_teacher_model()
            # Inference loop
            for epoch in range(EPOCHS):
                for i, batch in enumerate(dataloader):
                    global_step = epoch * len(dataloader) + i
                    inputs = {k: v.to(self.model.device) for k, v in batch.items()}
                    inputs["position_ids"] = None
                    with torch.inference_mode():
                        teacher_outputs = self.teacher_step(self.model, inputs)
                    self._send_to_student(teacher_outputs)

        self._print_mem_stats()
        # Make sure all processes have finished before destroying the process group.
        dist.barrier()
        # Clean up processes.
        dist.destroy_process_group()
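
For orientation only (not part of this diff): a subclass is expected to implement the abstract hooks above. Below is a minimal sketch, assuming one teacher rank and one student rank on separate GPUs; ToyDistillTrainer and its toy linear models are hypothetical names, and a real subclass would run actual teacher/student distillation steps rather than this toy KL loss.

# Illustrative subclass sketch; assumes BaseDistillTrainer (defined above) is in scope.
import torch
import torch.nn.functional as F


class ToyDistillTrainer(BaseDistillTrainer):
    """Toy subclass: the teacher sends logits, the student matches them with a KL loss."""

    @property
    def current_rank_devices(self):
        # One GPU per rank in this sketch.
        return [self.rank]

    def load_teacher_model(self):
        # Placeholder teacher: a frozen linear layer on this rank's device.
        model = torch.nn.Linear(16, 32).to(f"cuda:{self.rank}")
        model.eval()
        return model

    def load_student_model(self):
        model = torch.nn.Linear(16, 32).to(self.args.student_device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=self.args.lr)
        return model, optimizer

    def teacher_step(self, model, inputs) -> dict[str, torch.Tensor]:
        # Must return tensors whose names, shapes, and dtypes match self.distill_metadata.
        return {"logits": model(inputs["x"])}

    def student_step(self, inputs, logits):
        # `logits` arrives via _get_distill_kwargs(), keyed by the distill_metadata entry name.
        student_logits = self.model(inputs["x"])
        loss = F.kl_div(
            F.log_softmax(student_logits, dim=-1),
            F.softmax(logits, dim=-1),
            reduction="batchmean",
        )
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # train() expects (loss, per-step accuracies); dummy accuracies here.
        return loss.item(), [0.0, 0.0, 0.0, 0.0]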
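And a rough driver for the sketch above, again illustrative: each rank constructs the same trainer and calls train(). The attributes on args (teacher_ranks, student_rank, student_device, out_path, lr, batch_size) mirror what the base class reads, but the values, the torchrun launch, and the toy dataset are assumptions, not something this commit defines.

# Hypothetical driver; run with: torchrun --nproc_per_node=2 launch_sketch.py
# The student rank logs to wandb, so WANDB_ENTITY and WANDB_PROJECT must be set.
import argparse

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer


def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()

    # Hypothetical argument namespace expected by BaseDistillTrainer.
    args = argparse.Namespace(
        teacher_ranks=[0],
        student_rank=1,
        student_device="cuda:1",
        out_path="out/toy_distill",
        lr=1e-4,
        batch_size=8,
    )

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # Shapes here must agree with what teacher_step() actually sends.
    distill_metadata = {"logits": (torch.Size([args.batch_size, 32]), torch.float32)}

    dataset = TensorDataset(torch.randn(256, 16))
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        drop_last=True,
        collate_fn=lambda items: {"x": torch.stack([i[0] for i in items])},
    )

    trainer = ToyDistillTrainer(rank, args, tokenizer, distill_metadata)
    trainer.train(dataloader)


if __name__ == "__main__":
    main()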

0 commit comments