
Commit ddf5d5d

polish

h-guo18 committed
Signed-off-by: h-guo18 <[email protected]>
1 parent 0e70be4 commit ddf5d5d

File tree

2 files changed: +25 -36 lines changed


examples/speculative_decoding/distill_trainer.py

Lines changed: 12 additions & 29 deletions
@@ -26,8 +26,8 @@
 mto.enable_huggingface_checkpointing()
 
 # Hyperparameters for profiling
-EPOCHS = 1
-LOG_INTERVAL = 1
+EPOCHS = 10
+LOG_INTERVAL = 100
 SAVE_INTERVAL = 20000
 # VALIDATE_INTERVAL = 20
 
@@ -125,6 +125,7 @@ def _recv_from_teacher(self):
             req.wait()
 
     def _get_distill_kwargs(self):
+        """Return a copy of received buffer for student training."""
        return {k: v.clone().detach() for k, v in self.student_recv_buffer.items()}
 
    def _send_to_student(self, teacher_outputs):
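
A note on the new docstring's point: _get_distill_kwargs hands the student a cloned, detached copy of the receive buffer, so the next irecv can overwrite the buffer without touching tensors the student is still using, and the teacher outputs stay out of the autograd graph. A minimal sketch of the pattern (buffer names and shapes here are illustrative, not the repo's exact fields):

import torch

# Pre-allocated receive buffer, reused by every irecv (illustrative shapes).
student_recv_buffer = {
    "teacher_logits": torch.empty(2, 128, 32000),
    "teacher_hidden": torch.empty(2, 128, 2048),
}

def get_distill_kwargs(buffer):
    # Snapshot the current contents so later receives can safely overwrite the
    # buffer, and detach so the teacher tensors are treated as constants.
    return {k: v.clone().detach() for k, v in buffer.items()}

distill_kwargs = get_distill_kwargs(student_recv_buffer)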
@@ -141,25 +142,6 @@ def _send_to_student(self, teacher_outputs):
         for req in reqs:
             req.wait()
 
-    # def _validate_ar(self, steps=3, osl=20, num_samples=20):
-    #     if self.rank != self.args.student_rank:
-    #         return
-    #     # Load MT-Bench prompts from HuggingFace
-    #     ds = load_dataset("HuggingFaceH4/mt_bench_prompts")["train"]
-    #     self.model.eval()
-    #     self.model.to(self.args.student_device)
-    #     ars = validate_ar(
-    #         self.model, self.tokenizer, ds, steps, osl, num_samples, self.args.student_device
-    #     )
-    #     # Print results
-    #     avg_ar = sum(ars) / len(ars)
-    #     print("\n==== AR Validation Results on MT-Bench ====")
-    #     print(f"Number of samples: {len(ars)}")
-    #     print(f"Output Sequence Length: {osl}")
-    #     print(f"Steps: {steps}")
-    #     print(f"Average AR: {avg_ar:.4f}")
-    #     self.model.train()
-
     def train(self, dataloader):
         """Main training entrance of the composed model."""
         self._reset_all_mem_stats()
@@ -174,19 +156,24 @@ def train(self, dataloader):
             project=os.environ["WANDB_PROJECT"],
             config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size},
         ) as run:
-            self.model, self.optimizer = self.load_student_model()
+            self.model, self.optimizer, self.scheduler = self.load_student_model()
             self._init_student_recv_buffer()
             wandb.watch(self.model, log="all")
 
             for epoch in range(EPOCHS):
-                pbar = tqdm(dataloader)
+                pbar = (
+                    tqdm(dataloader) if self.rank == self.args.student_ranks[0] else dataloader
+                )
                 for i, batch in enumerate(pbar):
                     global_step = epoch * len(dataloader) + i
                     inputs = {k: v.to(self.model.device) for k, v in batch.items()}
                     self._recv_from_teacher()
                     loss, train_acc = self.student_step(inputs, **self._get_distill_kwargs())
-                    pbar.set_description(f"Epoch {epoch} Loss:{loss} Acc:{train_acc}")
 
+                    if self.rank != self.args.student_ranks[0]:
+                        continue
+
+                    pbar.set_description(f"Epoch {epoch} Loss:{loss} Acc:{train_acc}")
                     if global_step % LOG_INTERVAL == 0:
                         run.log(
                             {
@@ -195,14 +182,10 @@ def train(self, dataloader):
                                 "train_acc_step1": train_acc[1],
                                 "train_acc_step2": train_acc[2],
                                 "train_acc_step3": train_acc[3],
+                                "lr": self.optimizer.param_groups[0]["lr"],
                             },
                             step=global_step,
                         )
-
-                    # This is not working for some reason.
-                    # if global_step > 0 and global_step % VALIDATE_INTERVAL == 0:
-                    #     self._validate_ar()
-
                     if global_step > 0 and global_step % SAVE_INTERVAL == 0:
                         self.save_pretrained(
                             f"{self.args.out_path}/epoch_{epoch}_step_{global_step}"
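
The reworked loop keeps every student rank running student_step, but only the first student rank (student_ranks[0]) owns the tqdm bar, the wandb logging, and the checkpointing; the other ranks continue right after the step. A minimal standalone sketch of that rank-gating pattern (rank, student_ranks, and train_step are stand-ins for the trainer's own objects):

import torch.distributed as dist
from tqdm import tqdm

def train_epoch(dataloader, student_ranks, train_step, epoch=0):
    # Every rank does the compute; only one rank drives the progress bar and logging.
    rank = dist.get_rank() if dist.is_initialized() else 0
    is_logging_rank = rank == student_ranks[0]
    pbar = tqdm(dataloader) if is_logging_rank else dataloader

    for i, batch in enumerate(pbar):
        loss, train_acc = train_step(batch)
        if not is_logging_rank:
            continue  # non-logging ranks skip pbar/wandb work
        pbar.set_description(f"Epoch {epoch} Loss:{loss} Acc:{train_acc}")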

examples/speculative_decoding/train.py

Lines changed: 13 additions & 7 deletions
@@ -23,17 +23,19 @@
 from eagle_utils import DataCollatorWithPadding, make_eagle_supervised_data_module
 from torch.distributed.device_mesh import DeviceMesh
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.optimization import get_linear_schedule_with_warmup
 
 import modelopt.torch.speculative as mtsp
 from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
 
 # Hyperparameters for profiling
 torch.manual_seed(0)
-INPUT_LENGTH = 512
-DRAFT_VOCAB_SIZE = 128256
-# DRAFT_VOCAB_SIZE = 32000
+INPUT_LENGTH = 1024
+# DRAFT_VOCAB_SIZE = 128256
+DRAFT_VOCAB_SIZE = 32000
 # MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.1-8B-Instruct"
-MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.2-1B-Instruct"
+# MODEL_PATH = "/lustre/fsw/portfolios/coreai/projects/coreai_dlalgo_modelopt/hf-local/meta-llama/Llama-3.2-1B-Instruct"
+MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 # MODEL_PATH = "openai/gpt-oss-20b"
 # MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.3-70B-Instruct"
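
The DRAFT_VOCAB_SIZE values mirror the two base-model families above: 128256 is the Llama-3 vocabulary and 32000 is the Llama-2-style vocabulary used by TinyLlama. A small sketch (not part of the repo) for checking the base model's vocabulary size when switching MODEL_PATH, assuming the draft vocab is meant to match it here:

from transformers import AutoConfig

MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Read the vocab size from the base model config instead of hardcoding it:
# 32000 for TinyLlama, 128256 for the Llama-3 checkpoints listed above.
DRAFT_VOCAB_SIZE = AutoConfig.from_pretrained(MODEL_PATH).vocab_size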

@@ -121,9 +123,12 @@ def load_student_model(self):
             process_group=self.args.student_pgroup,
             find_unused_parameters=True,
         )
-        optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=self.args.lr)
+        scheduler = get_linear_schedule_with_warmup(
+            optimizer, num_warmup_steps=0, num_training_steps=117380
+        )
         self._print_model_placement(model)
-        return model, optimizer
+        return model, optimizer, scheduler
 
     def teacher_step(self, model, inputs):
         base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward(
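
The new optimizer block pairs AdamW with a linear decay schedule and no warmup; num_training_steps=117380 is hardcoded, so it presumably has to track the actual dataset size. A minimal sketch (not the repo's code) of the same wiring with the step count derived from the dataloader instead:

import torch
from transformers.optimization import get_linear_schedule_with_warmup

def build_optimizer_and_scheduler(model, lr, dataloader, epochs):
    # AdamW plus a linear decay from lr to 0 over the whole run, no warmup.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=epochs * len(dataloader),
    )
    return optimizer, scheduler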
@@ -168,12 +173,13 @@ def student_step(
             },
         )
         loss = output.loss
-        print(f"Rank {self.rank} loss: {loss.item()}")
+        # print(f"Rank {self.rank} loss: {loss.item()}")
         train_acc = output.train_acc
 
         # Backward
         loss.backward()
         self.optimizer.step()
+        self.scheduler.step()
         return round(loss.item(), 3), train_acc
 
 
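With a per-step LR schedule, student_step now calls scheduler.step() once per batch, after optimizer.step(), which is the order PyTorch expects. A stripped-down sketch of the update sequence (gradient zeroing is shown for completeness and is presumably handled elsewhere in the trainer):

def apply_student_update(loss, optimizer, scheduler):
    # Backward pass, parameter update, then LR update; scheduler.step() should
    # come after optimizer.step(), otherwise PyTorch emits a warning.
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()  # not shown in the diff; assumed to live elsewhere
    return round(loss.item(), 3)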