
Commit 5daa239

polish
Signed-off-by: h-guo18 <[email protected]>
1 parent 1a63254 commit 5daa239

2 files changed: +25 -42 lines changed

examples/speculative_decoding/distill_trainer.py

Lines changed: 13 additions & 9 deletions

@@ -40,10 +40,8 @@
 mto.enable_huggingface_checkpointing()

 # Hyperparameters for profiling
-EPOCHS = 1
 LOG_INTERVAL = 100
 SAVE_INTERVAL = 20000
-# VALIDATE_INTERVAL = 20

 # Shape and dtype description of the distillation signal
 DistillMetadata = dict[str, tuple[torch.Size, torch.dtype]]
@@ -61,11 +59,11 @@ class BaseDistillTrainer:

     def __init__(self, rank, args, tokenizer, dataloader):
         self.rank = rank
-        args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
-        args.student_pgroup = dist.new_group(ranks=args.student_ranks)
         self.args = args
         self.tokenizer = tokenizer
         self.dataloader = dataloader
+
+        # Prepare models
         if rank in args.student_ranks:
             self.model = self.prepare_student_model()
             self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr)
@@ -180,7 +178,11 @@ def _get_logging_context(self):
             return wandb.init(
                 entity=os.environ["WANDB_ENTITY"],
                 project=os.environ["WANDB_PROJECT"],
-                config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size},
+                config={
+                    "epochs": self.args.epoch,
+                    "lr": self.args.lr,
+                    "batch_size": self.args.batch_size,
+                },
             )
         return nullcontext()

@@ -193,7 +195,7 @@ def train(self):
             self._init_student_recv_buffer()

             # Student training loop
-            for epoch in range(EPOCHS):
+            for epoch in range(self.args.epoch):
                 pbar = (
                     tqdm(self.dataloader)
                     if self.rank == self.args.student_ranks[0]
@@ -236,7 +238,7 @@ def train(self):

         else:
             # Inference Loop
-            for epoch in range(EPOCHS):
+            for epoch in range(self.args.epoch):
                 for i, batch in enumerate(self.dataloader):
                     inputs = {k: v.to(self.model.device) for k, v in batch.items()}
                     with torch.inference_mode():
@@ -390,8 +392,10 @@ def student_step(
     ) -> ModelOutput:
         self.optimizer.zero_grad()

-        # Chunk inputs for each student rank.
+        # Chunk input_ids and attention_mask for each student rank.
         inputs = {k: v.chunk(len(self.args.student_ranks))[self.rank] for k, v in inputs.items()}

         # Second stage forward with provided base model outputs.
-        return self.model(**inputs, base_model_outputs=distill_msgs)
+        output = self.model(**inputs, base_model_outputs=distill_msgs)
+
+        return output
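
For context on the student_step hunk above: each call slices the batch tensors evenly across the student ranks before the second-stage forward. Below is a minimal standalone sketch of that chunking pattern; the rank count, tensor shapes, and vocabulary size are illustrative assumptions, not values taken from this repo.

import torch

# Assumed setup: 4 student ranks, a global batch of 8 sequences of length 16.
student_ranks = [0, 1, 2, 3]
rank = 2  # this process's position among the student ranks
inputs = {
    "input_ids": torch.randint(0, 32000, (8, 16)),
    "attention_mask": torch.ones(8, 16, dtype=torch.long),
}

# Same pattern as student_step: keep only this rank's slice of the batch.
inputs = {k: v.chunk(len(student_ranks))[rank] for k, v in inputs.items()}
print(inputs["input_ids"].shape)  # torch.Size([2, 16]): the global batch of 8 split 4 ways

Since Tensor.chunk splits along dim 0, the global --batch_size is divided evenly across the student ranks, consistent with its help text ("Total batch size across all parallel ranks").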

examples/speculative_decoding/train.py

Lines changed: 12 additions & 33 deletions

@@ -41,6 +41,8 @@ def _setup_distributed(rank, args, backend="nccl"):
     print(
         f"Starting process rank={rank}, device={torch.cuda.current_device()}, world_size={args.world_size}"
     )
+    args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
+    args.student_pgroup = dist.new_group(ranks=args.student_ranks)


 def train(rank, args):
@@ -67,47 +69,24 @@ def train(rank, args):

 def main():
     parser = argparse.ArgumentParser(description="Multi-GPU distributed two-stage forward example")
+    parser.add_argument("--model_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    parser.add_argument("--student_devices", type=list, default=[0, 1, 2, 3])
+    parser.add_argument("--teacher_devices", type=list, default=[4, 5])
     parser.add_argument(
-        "--model_path",
-        type=str,
-        default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-        help="Path to the model.",
-    )
-    parser.add_argument(
-        "--student_devices", type=list, default=[0, 1, 2, 3], help="Devices for student model"
-    )
-    parser.add_argument(
-        "--teacher_devices", type=list, default=[4, 5], help="Devices for teacher model"
-    )
-    parser.add_argument(
-        "--data_path",
-        type=str,
-        default="data/magpie_llama3.2_1b_generated/data.cleaned.jsonl",
-        help="Path to the training data.",
-    )
-    parser.add_argument(
-        "--training_seq_len",
-        type=str,
-        default=1024,
-        help="Training sequence length.",
-    )
-    parser.add_argument(
-        "--eagle_config_path",
-        type=str,
-        default="eagle_config.json",
-        help="Path to the eagle config.",
+        "--data_path", type=str, default="data/magpie_llama3.2_1b_generated/data.cleaned.jsonl"
     )
+    parser.add_argument("--training_seq_len", type=str, default=1024)
+    parser.add_argument("--eagle_config_path", type=str, default="eagle_config.json")
     parser.add_argument(
         "--lazy_preprocess", type=bool, default=True, help="Whether to use lazy preprocessing."
     )
-    parser.add_argument(
-        "--out_path", type=str, default="ckpts/fast-trained", help="Path to save the model."
-    )
-    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
+    parser.add_argument("--out_path", type=str, default="ckpts/fast-trained")
+    parser.add_argument("--lr", type=float, default=1e-5)
+    parser.add_argument("--epoch", type=int, default=1)
     parser.add_argument(
         "--batch_size", type=int, default=4, help="Total batch size across all parallel ranks."
     )
-    parser.add_argument("--master_port", type=str, default="12357", help="Master port.")
+    parser.add_argument("--master_port", type=str, default="12357")

     args = parser.parse_args()
     # TODO: add sanity check for args
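
One note on the new_group calls that moved from BaseDistillTrainer.__init__ into _setup_distributed: in torch.distributed, dist.new_group is a collective call that every rank in the job must execute, in the same order, even ranks that will not belong to the new group, so creating both groups once during distributed setup (next to init_process_group) is the usual pattern and presumably the reason for the move. A minimal sketch of that pattern follows; the function name, signature, and single-node rendezvous settings are assumptions for illustration, not the repo's exact code.

import os
import torch.distributed as dist

def setup_groups_sketch(rank, world_size, teacher_ranks, student_ranks, master_port="12357"):
    # Assumed single-node rendezvous settings, for illustration only.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", master_port)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    # new_group is collective: all ranks call both lines, in the same order,
    # regardless of whether they are members of either group.
    teacher_pgroup = dist.new_group(ranks=teacher_ranks)
    student_pgroup = dist.new_group(ranks=student_ranks)
    return teacher_pgroup, student_pgroup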
