From 2c7de10291b4de22a489f889233a00cc8d4d9282 Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Mon, 22 Sep 2025 00:35:33 +0000
Subject: [PATCH 1/9] add new trainer

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 .../speculative_decoding/distill_trainer.py | 224 +++++++++++++
 examples/speculative_decoding/train.py      | 304 ++++++++++++++++++
 2 files changed, 528 insertions(+)
 create mode 100644 examples/speculative_decoding/distill_trainer.py
 create mode 100644 examples/speculative_decoding/train.py

diff --git a/examples/speculative_decoding/distill_trainer.py b/examples/speculative_decoding/distill_trainer.py
new file mode 100644
index 000000000..ce6c3a02c
--- /dev/null
+++ b/examples/speculative_decoding/distill_trainer.py
@@ -0,0 +1,224 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+from abc import abstractmethod
+
+import torch
+import torch.distributed as dist
+from tqdm import tqdm
+
+import modelopt.torch.opt as mto
+
+mto.enable_huggingface_checkpointing()
+
+# Hyperparameters for profiling
+EPOCHS = 20
+LOG_INTERVAL = 25
+SAVE_INTERVAL = 20000
+# VALIDATE_INTERVAL = 20
+
+# The distillation signal from the teacher is described as a map from variable name to its shape and dtype.
+DistillMetadata = dict[str, tuple[torch.Size, torch.dtype]]
+
+
+class BaseDistillTrainer:
+    """
+    Base class for the distillation trainer. Initialized and called on every rank.
+    Subclasses implement teacher_step and student_step.
+
+    Args:
+        rank: rank of the current process.
+        args: parsed command-line arguments.
+        tokenizer: tokenizer of the base model, saved alongside checkpoints.
+        distill_metadata: map from tensor name to (shape, dtype) describing the
+            distillation signal sent by the teacher.
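+
+    On every step the teacher rank runs the frozen base model and sends the tensors
+    described by distill_metadata to the student rank with point-to-point isend/irecv;
+    the student pre-allocates matching receive buffers. A minimal usage sketch with a
+    hypothetical concrete subclass (illustrative only; the names below are not part of
+    this PR):
+
+        distill_metadata = {
+            "base_model_logits": (torch.Size([4, 512, 32000]), torch.bfloat16),
+        }
+        trainer = MyEagleTrainer(rank, args, tokenizer, distill_metadata)
+        trainer.train(train_dataloader)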
+ """ + + def __init__(self, rank, args, tokenizer, distill_metadata: DistillMetadata): + self.rank = rank + args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks) + self.args = args + self.tokenizer = tokenizer + self.distill_metadata = distill_metadata + + def _print_model_placement(self, module): + for name, param in module.named_parameters(): + print(f"(Rank {self.rank}) {name} ---> {param.device} ") + + @property + def current_rank_devices(self): + pass + + def _reset_all_mem_stats(self): + for d in self.current_rank_devices: + torch.cuda.reset_max_memory_allocated(d) + + def _print_mem_stats(self): + for d in self.current_rank_devices: + max_mem = torch.cuda.max_memory_allocated(d) + print(f"GPU {d}: Max memory allocated: {max_mem / 1024**3:.2f} GB") + + @abstractmethod + def load_teacher_model(self): + pass + + @abstractmethod + def load_student_model(self): + pass + + @abstractmethod + def teacher_step(self, *args, **kwargs) -> dict[str, torch.Tensor]: + pass + + @abstractmethod + def student_step(self, *args, **kwargs): + pass + + def save_pretrained(self, path=None): + if self.rank == self.args.student_rank: + path = self.args.out_path if path is None else path + self.model.save_pretrained(path) + self.tokenizer.save_pretrained(path) + print(f"Pretrained model saved to {path}") + + def _check_valid_message(self, message: dict[str, torch.Tensor]): + # Check if keys and length match between message and distill_metadata + if set(message.keys()) != set(self.distill_metadata.keys()): + raise ValueError( + f"Message keys from teacher: {set(message.keys())} \n" + f"do not match expected keys {set(self.distill_metadata.keys())}" + ) + if len(message) != len(self.distill_metadata): + raise ValueError( + f"Message length from teacher: {len(message)} \n" + f"does not match expected {len(self.distill_metadata)}" + ) + for k, v in message.items(): + if v.shape != self.distill_metadata[k][0] or v.dtype != self.distill_metadata[k][1]: + raise ValueError( + f"Invalid message from teacher. 
{k} has shape {v.shape} and dtype {v.dtype}, \n" + f"expected {self.distill_metadata[k]}" + ) + + def _init_student_recv_buffer(self): + self.student_recv_buffer = { + k: torch.empty(v[0], device=self.args.student_device, dtype=v[1]) + for k, v in self.distill_metadata.items() + } + + def _recv_from_teacher(self): + reqs = [ + dist.irecv(buffer, src=self.args.teacher_ranks[0]) + for buffer in self.student_recv_buffer.values() + ] + for req in reqs: + req.wait() + + def _get_distill_kwargs(self): + return {k: v.clone().detach() for k, v in self.student_recv_buffer.items()} + + def _send_to_student(self, teacher_outputs): + if self.rank != self.args.teacher_ranks[0]: + return + self._check_valid_message(teacher_outputs) + reqs = [ + dist.isend(buffer, dst=self.args.student_rank) for buffer in teacher_outputs.values() + ] + for req in reqs: + req.wait() + + # def _validate_ar(self, steps=3, osl=20, num_samples=20): + # if self.rank != self.args.student_rank: + # return + # # Load MT-Bench prompts from HuggingFace + # ds = load_dataset("HuggingFaceH4/mt_bench_prompts")["train"] + # self.model.eval() + # self.model.to(self.args.student_device) + # ars = validate_ar( + # self.model, self.tokenizer, ds, steps, osl, num_samples, self.args.student_device + # ) + # # Print results + # avg_ar = sum(ars) / len(ars) + # print("\n==== AR Validation Results on MT-Bench ====") + # print(f"Number of samples: {len(ars)}") + # print(f"Output Sequence Length: {osl}") + # print(f"Steps: {steps}") + # print(f"Average AR: {avg_ar:.4f}") + # self.model.train() + + def train(self, dataloader): + """Main training entrance of the composed model.""" + self._reset_all_mem_stats() + + if self.rank == self.args.student_rank: + import wandb + + wandb.login() + + with wandb.init( + entity=os.environ["WANDB_ENTITY"], + project=os.environ["WANDB_PROJECT"], + config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size}, + ) as run: + self.model, self.optimizer = self.load_student_model() + self._init_student_recv_buffer() + wandb.watch(self.model, log="all") + + for epoch in range(EPOCHS): + pbar = tqdm(dataloader) + for i, batch in enumerate(pbar): + global_step = epoch * len(dataloader) + i + inputs = {k: v.to(self.model.device) for k, v in batch.items()} + self._recv_from_teacher() + loss, train_acc = self.student_step(inputs, **self._get_distill_kwargs()) + pbar.set_description(f"Epoch {epoch} Loss:{loss} Acc:{train_acc}") + + if global_step % LOG_INTERVAL == 0: + run.log( + { + "loss": loss, + "train_acc_step0": train_acc[0], + "train_acc_step1": train_acc[1], + "train_acc_step2": train_acc[2], + "train_acc_step3": train_acc[3], + }, + step=global_step, + ) + + # This is not working for some reason. + # if global_step > 0 and global_step % VALIDATE_INTERVAL == 0: + # self._validate_ar() + + if global_step > 0 and global_step % SAVE_INTERVAL == 0: + self.save_pretrained( + f"{self.args.out_path}/epoch_{epoch}_step_{global_step}" + ) + + else: + self.model = self.load_teacher_model() + # Inference Loop + for epoch in range(EPOCHS): + for i, batch in enumerate(dataloader): + global_step = epoch * len(dataloader) + i + inputs = {k: v.to(self.model.device) for k, v in batch.items()} + inputs["position_ids"] = None + with torch.inference_mode(): + teacher_outputs = self.teacher_step(self.model, inputs) + self._send_to_student(teacher_outputs) + + self._print_mem_stats() + # Makesure all processes finished before destroy. 
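+        # Both the student branch (training loop) and the teacher branch (inference loop)
+        # reach this point, so neither side tears down the process group while the other
+        # may still have outstanding sends or receives.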
+ dist.barrier() + # clean up processess + dist.destroy_process_group() diff --git a/examples/speculative_decoding/train.py b/examples/speculative_decoding/train.py new file mode 100644 index 000000000..a302711f8 --- /dev/null +++ b/examples/speculative_decoding/train.py @@ -0,0 +1,304 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from distill_trainer import BaseDistillTrainer +from eagle_utils import DataCollatorWithPadding, make_eagle_supervised_data_module +from torch.distributed.device_mesh import DeviceMesh +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG + +# Hyperparameters for profiling +INPUT_LENGTH = 512 +# DRAFT_VOCAB_SIZE = 128256 +DRAFT_VOCAB_SIZE = 32000 +# MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.1-8B-Instruct" +MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.2-1B-Instruct" +# MODEL_PATH = "openai/gpt-oss-20b" +# MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.3-70B-Instruct" + + +def _setup_distributed(rank, args, backend="nccl"): + """Initialize distributed environment""" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = args.master_port + os.environ["LOCAL_RANK"] = str(rank) + # Initialize process group + dist.init_process_group(backend, rank=rank, world_size=args.world_size) + if rank == args.student_rank: + torch.cuda.set_device(args.student_device) + else: + torch.cuda.set_device(args.teacher_devices[rank - 1]) + print( + f"Starting process rank={rank}, device={torch.cuda.current_device()}, world_size={args.world_size}" + ) + + +class EagleTPTrainer(BaseDistillTrainer): + @property + def current_rank_devices(self): + if self.rank == self.args.student_rank: + return [self.args.student_device] + else: + return [self.args.teacher_devices[self.rank - 1]] + + def load_teacher_model(self): + model = AutoModelForCausalLM.from_pretrained( + MODEL_PATH, + torch_dtype="auto", + tp_plan="auto", + device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"), + ) + mtsp.convert(model, [("eagle", self.args.eagle_config)]) + model.eval() + self._print_model_placement(model) + return model + + def load_student_model(self, keep_modules_from_teacher=["embed_tokens", "lm_head"]): + """Load student model on a single device and keep needed modules from teacher.""" + # Load to CPU first to avoid OOM + model = AutoModelForCausalLM.from_pretrained( + MODEL_PATH, torch_dtype="auto", device_map="cpu" + ) + # Hidden size and vocab size must match base model + self.args.eagle_config["eagle_architecture_config"].update( + { + "hidden_size": model.config.hidden_size, + "vocab_size": model.config.vocab_size, + 
"draft_vocab_size": DRAFT_VOCAB_SIZE, + } + ) + mtsp.convert( + model, + [("eagle", self.args.eagle_config)], + ) + if model.config.vocab_size > DRAFT_VOCAB_SIZE: + model_name = os.path.basename(os.path.normpath(MODEL_PATH)) + vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt") + try: + vocab_cache = torch.load(vocab_cache_path) + assert len(vocab_cache) == DRAFT_VOCAB_SIZE + model.eagle_module.d2t = vocab_cache + print(f"Loaded draft vocab cache from {vocab_cache_path}.") + except Exception as e: + raise e + + # We copy needed modules and del the rest + model.eagle_module.to(self.args.student_device) + for name, _ in list(model._modules.items()): + if name in keep_modules_from_teacher: + getattr(model, name).to(self.args.student_device) + + model.train() + optimizer = torch.optim.Adam(model.eagle_module.parameters(), lr=self.args.lr) + self._print_model_placement(model) + return model, optimizer + + def teacher_step(self, model, inputs): + base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward( + **inputs, + freeze_base_model=True, + past_key_values=None, + ) + # aux_hidden_states could be on multiple devices. Gather them and cat. + aux_hidden_states = torch.cat( + [t.to(base_model_logits.device) for t in model.pop_aux_hidden_states()], dim=-1 + ) + return { + "base_model_hidden_states": base_model_hidden_states, + "aux_hidden_states": aux_hidden_states, + "base_model_logits": base_model_logits, + } + + def student_step( + self, + inputs, + base_model_hidden_states, + aux_hidden_states, + base_model_logits, + ): + self.optimizer.zero_grad() + # Second stage forward using the unified model + output = self.model( + **inputs, + # providing base model outputs to bypass the base model forward. + base_model_outputs={ + "base_model_hidden_states": base_model_hidden_states, + "aux_hidden_states": aux_hidden_states.clone().detach(), + "base_model_logits": base_model_logits.clone().detach(), + }, + ) + loss = output.loss + train_acc = output.train_acc + + # Backward + loss.backward() + self.optimizer.step() + return round(loss.item(), 3), train_acc + + +class EagleMPTrainer(EagleTPTrainer, BaseDistillTrainer): + @property + def current_rank_devices(self): + if self.rank == self.args.student_rank: + return [self.args.student_device] + else: + return self.args.teacher_devices + + def load_teacher_model(self): + model = AutoModelForCausalLM.from_pretrained( + MODEL_PATH, + torch_dtype="auto", + device_map="sequential", + max_memory=dict.fromkeys( + self.args.teacher_devices, "999GiB" + ), # To use only given devices + ) + self.args.eagle_config["eagle_architecture_config"].update( + { + "hidden_size": model.config.hidden_size, + "vocab_size": model.config.vocab_size, + "draft_vocab_size": DRAFT_VOCAB_SIZE, + } + ) + mtsp.convert(model, [("eagle", self.args.eagle_config)]) + + if model.config.vocab_size > DRAFT_VOCAB_SIZE: + model_name = os.path.basename(os.path.normpath(MODEL_PATH)) + vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt") + try: + vocab_cache = torch.load(vocab_cache_path) + assert len(vocab_cache) == DRAFT_VOCAB_SIZE + model.eagle_module.d2t = vocab_cache + print(f"Loaded draft vocab cache from {vocab_cache_path}.") + except Exception as e: + raise e + + model.eval() + self._print_model_placement(model) + return model + + +def train(rank, args): + _setup_distributed(rank, args) + + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) + data_module = make_eagle_supervised_data_module(tokenizer, args) + + 
train_dataloader = torch.utils.data.DataLoader( + data_module["train_dataset"], + batch_size=args.batch_size, + shuffle=True, + num_workers=0, + collate_fn=DataCollatorWithPadding(train_length=INPUT_LENGTH), + drop_last=True, + ) + + trainer_cls = { + "tp": EagleTPTrainer, + "mp": EagleMPTrainer, + }[args.teacher_parallel] + + distill_metadata = { + "base_model_hidden_states": ( + torch.Size([args.batch_size, INPUT_LENGTH, 2048]), + torch.bfloat16, + ), + "aux_hidden_states": ( + torch.Size([args.batch_size, INPUT_LENGTH, 2048 * 3]), + torch.bfloat16, + ), + "base_model_logits": ( + torch.Size([args.batch_size, INPUT_LENGTH, DRAFT_VOCAB_SIZE]), + torch.bfloat16, + ), + } + + trainer = trainer_cls(rank, args, tokenizer, distill_metadata) + trainer.train(train_dataloader) + # trainer.save_pretrained("ckpts/fast-trained") + + +def main(): + parser = argparse.ArgumentParser(description="Multi-GPU distributed two-stage forward example") + + parser.add_argument("--student_device", type=int, default=0, help="Device for student model") + parser.add_argument( + "--teacher_devices", type=list, default=[1], help="Devices for teacher model" + ) + parser.add_argument( + "--teacher_parallel", + type=str, + choices=["tp", "mp"], + default="mp", + help="Parallel type for teacher model. TP and MP supported.", + ) + parser.add_argument( + "--data_path", + type=str, + default="data/magpie_llama3.2_1b_generated/data.cleaned.jsonl", + help="Path to the training data.", + ) + parser.add_argument( + "--lazy_preprocess", type=bool, default=True, help="Whether to use lazy preprocessing." + ) + parser.add_argument( + "--out_path", type=str, default="ckpts/fast-trained", help="Path to save the model." + ) + parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.") + parser.add_argument("--batch_size", type=int, default=4, help="Batch size.") + parser.add_argument("--master_port", type=str, default="12357", help="Master port.") + + args = parser.parse_args() + args.eagle_config = EAGLE3_DEFAULT_CFG["config"] + # TODO: add sanity check for args + + def set_ranks(student_device, teacher_devices, teacher_parallel): + # TODO(hg): add "no-parallel" option, fallback when only one teacher device is provided. + # TODO(hg): add "FSDP" option. 
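+        # Rank layout produced below: the student is always rank 0. With "tp" every
+        # teacher device gets its own rank, e.g. teacher_devices=[1, 2, 3] ->
+        # world_size=4, teacher_ranks=[1, 2, 3]; with "mp" a single teacher rank
+        # (world_size=2) holds the model sharded across all teacher devices.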
+ if teacher_parallel == "tp": + world_size = len(teacher_devices) + 1 + student_rank = 0 + teacher_ranks = list(range(1, len(teacher_devices) + 1)) + elif teacher_parallel == "mp": + world_size = 2 + student_rank = 0 + teacher_ranks = [1] + else: + raise NotImplementedError(f"Parallel type {teacher_parallel} not supported.") + return world_size, student_rank, teacher_ranks + + args.world_size, args.student_rank, args.teacher_ranks = set_ranks( + args.student_device, args.teacher_devices, args.teacher_parallel + ) + + # Launch multiple processes + mp.spawn( + train, + args=(args,), + nprocs=args.world_size, + join=True, + ) + + +if __name__ == "__main__": + main() From fa5df91e57e18cf35610e35b0090912e22c88a77 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sat, 27 Sep 2025 00:58:59 +0000 Subject: [PATCH 2/9] add student model ddp Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../speculative_decoding/distill_trainer.py | 43 ++-- examples/speculative_decoding/train.py | 190 ++++++++++-------- 2 files changed, 132 insertions(+), 101 deletions(-) diff --git a/examples/speculative_decoding/distill_trainer.py b/examples/speculative_decoding/distill_trainer.py index ce6c3a02c..4fbeea221 100644 --- a/examples/speculative_decoding/distill_trainer.py +++ b/examples/speculative_decoding/distill_trainer.py @@ -26,8 +26,8 @@ mto.enable_huggingface_checkpointing() # Hyperparameters for profiling -EPOCHS = 20 -LOG_INTERVAL = 25 +EPOCHS = 1 +LOG_INTERVAL = 1 SAVE_INTERVAL = 20000 # VALIDATE_INTERVAL = 20 @@ -48,6 +48,7 @@ class BaseDistillTrainer: def __init__(self, rank, args, tokenizer, distill_metadata: DistillMetadata): self.rank = rank args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks) + args.student_pgroup = dist.new_group(ranks=args.student_ranks) self.args = args self.tokenizer = tokenizer self.distill_metadata = distill_metadata @@ -57,17 +58,15 @@ def _print_model_placement(self, module): print(f"(Rank {self.rank}) {name} ---> {param.device} ") @property - def current_rank_devices(self): + def current_rank_device(self): pass def _reset_all_mem_stats(self): - for d in self.current_rank_devices: - torch.cuda.reset_max_memory_allocated(d) + torch.cuda.reset_max_memory_allocated(self.current_rank_device) def _print_mem_stats(self): - for d in self.current_rank_devices: - max_mem = torch.cuda.max_memory_allocated(d) - print(f"GPU {d}: Max memory allocated: {max_mem / 1024**3:.2f} GB") + max_mem = torch.cuda.max_memory_allocated(self.current_rank_device) + print(f"GPU {self.current_rank_device}: Max memory allocated: {max_mem / 1024**3:.2f} GB") @abstractmethod def load_teacher_model(self): @@ -86,7 +85,7 @@ def student_step(self, *args, **kwargs): pass def save_pretrained(self, path=None): - if self.rank == self.args.student_rank: + if self.rank == self.args.student_ranks[0]: path = self.args.out_path if path is None else path self.model.save_pretrained(path) self.tokenizer.save_pretrained(path) @@ -96,24 +95,24 @@ def _check_valid_message(self, message: dict[str, torch.Tensor]): # Check if keys and length match between message and distill_metadata if set(message.keys()) != set(self.distill_metadata.keys()): raise ValueError( - f"Message keys from teacher: {set(message.keys())} \n" + f"Message keys: {set(message.keys())} \n" f"do not match expected keys {set(self.distill_metadata.keys())}" ) if len(message) != len(self.distill_metadata): raise ValueError( - f"Message length from teacher: {len(message)} \n" + f"Message 
length: {len(message)} \n" f"does not match expected {len(self.distill_metadata)}" ) for k, v in message.items(): if v.shape != self.distill_metadata[k][0] or v.dtype != self.distill_metadata[k][1]: raise ValueError( - f"Invalid message from teacher. {k} has shape {v.shape} and dtype {v.dtype}, \n" + f"Invalid message. {k} has shape {v.shape} and dtype {v.dtype}, \n" f"expected {self.distill_metadata[k]}" ) def _init_student_recv_buffer(self): self.student_recv_buffer = { - k: torch.empty(v[0], device=self.args.student_device, dtype=v[1]) + k: torch.empty(v[0], device=self.current_rank_device, dtype=v[1]) for k, v in self.distill_metadata.items() } @@ -131,12 +130,16 @@ def _get_distill_kwargs(self): def _send_to_student(self, teacher_outputs): if self.rank != self.args.teacher_ranks[0]: return - self._check_valid_message(teacher_outputs) - reqs = [ - dist.isend(buffer, dst=self.args.student_rank) for buffer in teacher_outputs.values() - ] - for req in reqs: - req.wait() + # TODO: use broadcast + assert len(teacher_outputs) == len(self.args.student_ranks), ( + f"Number of teacher outputs {len(teacher_outputs)} does not \ + match number of student ranks {len(self.args.student_ranks)}" + ) + for s in self.args.student_ranks: + self._check_valid_message(teacher_outputs[s]) + reqs = [dist.isend(buffer, dst=s) for buffer in teacher_outputs[s].values()] + for req in reqs: + req.wait() # def _validate_ar(self, steps=3, osl=20, num_samples=20): # if self.rank != self.args.student_rank: @@ -161,7 +164,7 @@ def train(self, dataloader): """Main training entrance of the composed model.""" self._reset_all_mem_stats() - if self.rank == self.args.student_rank: + if self.rank in self.args.student_ranks: import wandb wandb.login() diff --git a/examples/speculative_decoding/train.py b/examples/speculative_decoding/train.py index a302711f8..3bc1c09cd 100644 --- a/examples/speculative_decoding/train.py +++ b/examples/speculative_decoding/train.py @@ -28,9 +28,10 @@ from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG # Hyperparameters for profiling +torch.manual_seed(0) INPUT_LENGTH = 512 -# DRAFT_VOCAB_SIZE = 128256 -DRAFT_VOCAB_SIZE = 32000 +DRAFT_VOCAB_SIZE = 128256 +# DRAFT_VOCAB_SIZE = 32000 # MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.1-8B-Instruct" MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.2-1B-Instruct" # MODEL_PATH = "openai/gpt-oss-20b" @@ -44,10 +45,10 @@ def _setup_distributed(rank, args, backend="nccl"): os.environ["LOCAL_RANK"] = str(rank) # Initialize process group dist.init_process_group(backend, rank=rank, world_size=args.world_size) - if rank == args.student_rank: - torch.cuda.set_device(args.student_device) + if rank in args.student_ranks: + torch.cuda.set_device(args.student_devices[rank]) else: - torch.cuda.set_device(args.teacher_devices[rank - 1]) + torch.cuda.set_device(args.teacher_devices[rank - len(args.student_ranks)]) print( f"Starting process rank={rank}, device={torch.cuda.current_device()}, world_size={args.world_size}" ) @@ -55,11 +56,11 @@ def _setup_distributed(rank, args, backend="nccl"): class EagleTPTrainer(BaseDistillTrainer): @property - def current_rank_devices(self): - if self.rank == self.args.student_rank: - return [self.args.student_device] + def current_rank_device(self): + if self.rank in self.args.student_ranks: + return self.args.student_devices[self.rank] else: - return [self.args.teacher_devices[self.rank - 1]] + return self.args.teacher_devices[self.rank - len(self.args.student_ranks)] def 
load_teacher_model(self): model = AutoModelForCausalLM.from_pretrained( @@ -68,12 +69,19 @@ def load_teacher_model(self): tp_plan="auto", device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"), ) + self.args.eagle_config["eagle_architecture_config"].update( + { + "hidden_size": model.config.hidden_size, + "vocab_size": model.config.vocab_size, + "draft_vocab_size": DRAFT_VOCAB_SIZE, + } + ) mtsp.convert(model, [("eagle", self.args.eagle_config)]) model.eval() self._print_model_placement(model) return model - def load_student_model(self, keep_modules_from_teacher=["embed_tokens", "lm_head"]): + def load_student_model(self): """Load student model on a single device and keep needed modules from teacher.""" # Load to CPU first to avoid OOM model = AutoModelForCausalLM.from_pretrained( @@ -102,14 +110,18 @@ def load_student_model(self, keep_modules_from_teacher=["embed_tokens", "lm_head except Exception as e: raise e - # We copy needed modules and del the rest - model.eagle_module.to(self.args.student_device) - for name, _ in list(model._modules.items()): - if name in keep_modules_from_teacher: - getattr(model, name).to(self.args.student_device) + # TODO:copy needed modules and del the rest + model.model._modules.pop("layers") + model.to(self.current_rank_device) model.train() - optimizer = torch.optim.Adam(model.eagle_module.parameters(), lr=self.args.lr) + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.current_rank_device], + process_group=self.args.student_pgroup, + find_unused_parameters=True, + ) + optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr) self._print_model_placement(model) return model, optimizer @@ -123,11 +135,18 @@ def teacher_step(self, model, inputs): aux_hidden_states = torch.cat( [t.to(base_model_logits.device) for t in model.pop_aux_hidden_states()], dim=-1 ) - return { - "base_model_hidden_states": base_model_hidden_states, - "aux_hidden_states": aux_hidden_states, - "base_model_logits": base_model_logits, - } + base_model_hidden_states = base_model_hidden_states.chunk(len(self.args.student_ranks)) + base_model_logits = base_model_logits.chunk(len(self.args.student_ranks)) + aux_hidden_states = aux_hidden_states.chunk(len(self.args.student_ranks)) + + return [ + { + "base_model_hidden_states": base_model_hidden_states[i], + "aux_hidden_states": aux_hidden_states[i], + "base_model_logits": base_model_logits[i], + } + for i in range(len(self.args.student_ranks)) + ] def student_step( self, @@ -138,6 +157,7 @@ def student_step( ): self.optimizer.zero_grad() # Second stage forward using the unified model + inputs = {k: v.chunk(len(self.args.student_ranks))[self.rank] for k, v in inputs.items()} output = self.model( **inputs, # providing base model outputs to bypass the base model forward. 
@@ -148,6 +168,7 @@ def student_step( }, ) loss = output.loss + print(f"Rank {self.rank} loss: {loss.item()}") train_acc = output.train_acc # Backward @@ -156,79 +177,81 @@ def student_step( return round(loss.item(), 3), train_acc -class EagleMPTrainer(EagleTPTrainer, BaseDistillTrainer): - @property - def current_rank_devices(self): - if self.rank == self.args.student_rank: - return [self.args.student_device] - else: - return self.args.teacher_devices - - def load_teacher_model(self): - model = AutoModelForCausalLM.from_pretrained( - MODEL_PATH, - torch_dtype="auto", - device_map="sequential", - max_memory=dict.fromkeys( - self.args.teacher_devices, "999GiB" - ), # To use only given devices - ) - self.args.eagle_config["eagle_architecture_config"].update( - { - "hidden_size": model.config.hidden_size, - "vocab_size": model.config.vocab_size, - "draft_vocab_size": DRAFT_VOCAB_SIZE, - } - ) - mtsp.convert(model, [("eagle", self.args.eagle_config)]) - - if model.config.vocab_size > DRAFT_VOCAB_SIZE: - model_name = os.path.basename(os.path.normpath(MODEL_PATH)) - vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt") - try: - vocab_cache = torch.load(vocab_cache_path) - assert len(vocab_cache) == DRAFT_VOCAB_SIZE - model.eagle_module.d2t = vocab_cache - print(f"Loaded draft vocab cache from {vocab_cache_path}.") - except Exception as e: - raise e - - model.eval() - self._print_model_placement(model) - return model +# class EagleMPTrainer(EagleTPTrainer, BaseDistillTrainer): +# @property +# def current_rank_devices(self): +# if self.rank == self.args.student_rank: +# return [self.args.student_device] +# else: +# return self.args.teacher_devices + +# def load_teacher_model(self): +# model = AutoModelForCausalLM.from_pretrained( +# MODEL_PATH, +# torch_dtype="auto", +# device_map="sequential", +# max_memory=dict.fromkeys( +# self.args.teacher_devices, "999GiB" +# ), # To use only given devices +# ) +# self.args.eagle_config["eagle_architecture_config"].update( +# { +# "hidden_size": model.config.hidden_size, +# "vocab_size": model.config.vocab_size, +# "draft_vocab_size": DRAFT_VOCAB_SIZE, +# } +# ) +# mtsp.convert(model, [("eagle", self.args.eagle_config)]) + +# if model.config.vocab_size > DRAFT_VOCAB_SIZE: +# model_name = os.path.basename(os.path.normpath(MODEL_PATH)) +# vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt") +# try: +# vocab_cache = torch.load(vocab_cache_path) +# assert len(vocab_cache) == DRAFT_VOCAB_SIZE +# model.eagle_module.d2t = vocab_cache +# print(f"Loaded draft vocab cache from {vocab_cache_path}.") +# except Exception as e: +# raise e + +# model.eval() +# self._print_model_placement(model) +# return model def train(rank, args): _setup_distributed(rank, args) - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) - data_module = make_eagle_supervised_data_module(tokenizer, args) + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, model_max_length=INPUT_LENGTH) + data_module = make_eagle_supervised_data_module(tokenizer, args, use_offline_training=False) train_dataloader = torch.utils.data.DataLoader( data_module["train_dataset"], batch_size=args.batch_size, shuffle=True, num_workers=0, - collate_fn=DataCollatorWithPadding(train_length=INPUT_LENGTH), + collate_fn=DataCollatorWithPadding(max_length=INPUT_LENGTH), drop_last=True, ) trainer_cls = { "tp": EagleTPTrainer, - "mp": EagleMPTrainer, + # "mp": EagleMPTrainer, }[args.teacher_parallel] distill_metadata = { "base_model_hidden_states": ( - torch.Size([args.batch_size, 
INPUT_LENGTH, 2048]), + torch.Size([int(args.batch_size / len(args.student_ranks)), INPUT_LENGTH, 2048]), torch.bfloat16, ), "aux_hidden_states": ( - torch.Size([args.batch_size, INPUT_LENGTH, 2048 * 3]), + torch.Size([int(args.batch_size / len(args.student_ranks)), INPUT_LENGTH, 2048 * 3]), torch.bfloat16, ), "base_model_logits": ( - torch.Size([args.batch_size, INPUT_LENGTH, DRAFT_VOCAB_SIZE]), + torch.Size( + [int(args.batch_size / len(args.student_ranks)), INPUT_LENGTH, DRAFT_VOCAB_SIZE] + ), torch.bfloat16, ), } @@ -241,15 +264,17 @@ def train(rank, args): def main(): parser = argparse.ArgumentParser(description="Multi-GPU distributed two-stage forward example") - parser.add_argument("--student_device", type=int, default=0, help="Device for student model") parser.add_argument( - "--teacher_devices", type=list, default=[1], help="Devices for teacher model" + "--student_devices", type=list, default=[0, 1, 2, 3], help="Devices for student model" + ) + parser.add_argument( + "--teacher_devices", type=list, default=[4, 5], help="Devices for teacher model" ) parser.add_argument( "--teacher_parallel", type=str, choices=["tp", "mp"], - default="mp", + default="tp", help="Parallel type for teacher model. TP and MP supported.", ) parser.add_argument( @@ -272,23 +297,26 @@ def main(): args.eagle_config = EAGLE3_DEFAULT_CFG["config"] # TODO: add sanity check for args - def set_ranks(student_device, teacher_devices, teacher_parallel): + def set_ranks(student_devices, teacher_devices, teacher_parallel): # TODO(hg): add "no-parallel" option, fallback when only one teacher device is provided. # TODO(hg): add "FSDP" option. if teacher_parallel == "tp": - world_size = len(teacher_devices) + 1 - student_rank = 0 - teacher_ranks = list(range(1, len(teacher_devices) + 1)) + world_size = len(teacher_devices) + len(student_devices) + student_ranks = list(range(len(student_devices))) + teacher_ranks = list( + range(len(student_devices), len(student_devices) + len(teacher_devices)) + ) elif teacher_parallel == "mp": - world_size = 2 - student_rank = 0 - teacher_ranks = [1] + raise NotImplementedError("MP parallel type not supported.") + # world_size = 2 + # student_rank = 0 + # teacher_ranks = [1] else: raise NotImplementedError(f"Parallel type {teacher_parallel} not supported.") - return world_size, student_rank, teacher_ranks + return world_size, student_ranks, teacher_ranks - args.world_size, args.student_rank, args.teacher_ranks = set_ranks( - args.student_device, args.teacher_devices, args.teacher_parallel + args.world_size, args.student_ranks, args.teacher_ranks = set_ranks( + args.student_devices, args.teacher_devices, args.teacher_parallel ) # Launch multiple processes From 895ceaf406f34b39c4444de2c5b325135a6c39b7 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 2 Oct 2025 00:26:25 +0000 Subject: [PATCH 3/9] polish Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../speculative_decoding/distill_trainer.py | 41 ++++++------------- examples/speculative_decoding/train.py | 20 +++++---- 2 files changed, 25 insertions(+), 36 deletions(-) diff --git a/examples/speculative_decoding/distill_trainer.py b/examples/speculative_decoding/distill_trainer.py index 4fbeea221..b12f5b65e 100644 --- a/examples/speculative_decoding/distill_trainer.py +++ b/examples/speculative_decoding/distill_trainer.py @@ -26,8 +26,8 @@ mto.enable_huggingface_checkpointing() # Hyperparameters for profiling -EPOCHS = 1 -LOG_INTERVAL = 1 +EPOCHS = 10 +LOG_INTERVAL = 100 
SAVE_INTERVAL = 20000 # VALIDATE_INTERVAL = 20 @@ -125,6 +125,7 @@ def _recv_from_teacher(self): req.wait() def _get_distill_kwargs(self): + """Return a copy of received buffer for student training.""" return {k: v.clone().detach() for k, v in self.student_recv_buffer.items()} def _send_to_student(self, teacher_outputs): @@ -141,25 +142,6 @@ def _send_to_student(self, teacher_outputs): for req in reqs: req.wait() - # def _validate_ar(self, steps=3, osl=20, num_samples=20): - # if self.rank != self.args.student_rank: - # return - # # Load MT-Bench prompts from HuggingFace - # ds = load_dataset("HuggingFaceH4/mt_bench_prompts")["train"] - # self.model.eval() - # self.model.to(self.args.student_device) - # ars = validate_ar( - # self.model, self.tokenizer, ds, steps, osl, num_samples, self.args.student_device - # ) - # # Print results - # avg_ar = sum(ars) / len(ars) - # print("\n==== AR Validation Results on MT-Bench ====") - # print(f"Number of samples: {len(ars)}") - # print(f"Output Sequence Length: {osl}") - # print(f"Steps: {steps}") - # print(f"Average AR: {avg_ar:.4f}") - # self.model.train() - def train(self, dataloader): """Main training entrance of the composed model.""" self._reset_all_mem_stats() @@ -174,19 +156,24 @@ def train(self, dataloader): project=os.environ["WANDB_PROJECT"], config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size}, ) as run: - self.model, self.optimizer = self.load_student_model() + self.model, self.optimizer, self.scheduler = self.load_student_model() self._init_student_recv_buffer() wandb.watch(self.model, log="all") for epoch in range(EPOCHS): - pbar = tqdm(dataloader) + pbar = ( + tqdm(dataloader) if self.rank == self.args.student_ranks[0] else dataloader + ) for i, batch in enumerate(pbar): global_step = epoch * len(dataloader) + i inputs = {k: v.to(self.model.device) for k, v in batch.items()} self._recv_from_teacher() loss, train_acc = self.student_step(inputs, **self._get_distill_kwargs()) - pbar.set_description(f"Epoch {epoch} Loss:{loss} Acc:{train_acc}") + if self.rank != self.args.student_ranks[0]: + continue + + pbar.set_description(f"Epoch {epoch} Loss:{loss} Acc:{train_acc}") if global_step % LOG_INTERVAL == 0: run.log( { @@ -195,14 +182,10 @@ def train(self, dataloader): "train_acc_step1": train_acc[1], "train_acc_step2": train_acc[2], "train_acc_step3": train_acc[3], + "lr": self.optimizer.param_groups[0]["lr"], }, step=global_step, ) - - # This is not working for some reason. 
- # if global_step > 0 and global_step % VALIDATE_INTERVAL == 0: - # self._validate_ar() - if global_step > 0 and global_step % SAVE_INTERVAL == 0: self.save_pretrained( f"{self.args.out_path}/epoch_{epoch}_step_{global_step}" diff --git a/examples/speculative_decoding/train.py b/examples/speculative_decoding/train.py index 3bc1c09cd..dcd0dce34 100644 --- a/examples/speculative_decoding/train.py +++ b/examples/speculative_decoding/train.py @@ -23,17 +23,19 @@ from eagle_utils import DataCollatorWithPadding, make_eagle_supervised_data_module from torch.distributed.device_mesh import DeviceMesh from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.optimization import get_linear_schedule_with_warmup import modelopt.torch.speculative as mtsp from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG # Hyperparameters for profiling torch.manual_seed(0) -INPUT_LENGTH = 512 -DRAFT_VOCAB_SIZE = 128256 -# DRAFT_VOCAB_SIZE = 32000 +INPUT_LENGTH = 1024 +# DRAFT_VOCAB_SIZE = 128256 +DRAFT_VOCAB_SIZE = 32000 # MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.1-8B-Instruct" -MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.2-1B-Instruct" +# MODEL_PATH = "/lustre/fsw/portfolios/coreai/projects/coreai_dlalgo_modelopt/hf-local/meta-llama/Llama-3.2-1B-Instruct" +MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # MODEL_PATH = "openai/gpt-oss-20b" # MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.3-70B-Instruct" @@ -121,9 +123,12 @@ def load_student_model(self): process_group=self.args.student_pgroup, find_unused_parameters=True, ) - optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr) + optimizer = torch.optim.AdamW(model.parameters(), lr=self.args.lr) + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=0, num_training_steps=117380 + ) self._print_model_placement(model) - return model, optimizer + return model, optimizer, scheduler def teacher_step(self, model, inputs): base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward( @@ -168,12 +173,13 @@ def student_step( }, ) loss = output.loss - print(f"Rank {self.rank} loss: {loss.item()}") + # print(f"Rank {self.rank} loss: {loss.item()}") train_acc = output.train_acc # Backward loss.backward() self.optimizer.step() + self.scheduler.step() return round(loss.item(), 3), train_acc From 6c622072e3aaea540f94138f410fc27c56d853ff Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 2 Oct 2025 21:27:06 +0000 Subject: [PATCH 4/9] refactor Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../speculative_decoding/distill_trainer.py | 179 ++++++++++++++++- examples/speculative_decoding/train.py | 180 +----------------- 2 files changed, 181 insertions(+), 178 deletions(-) diff --git a/examples/speculative_decoding/distill_trainer.py b/examples/speculative_decoding/distill_trainer.py index b12f5b65e..ad968a891 100644 --- a/examples/speculative_decoding/distill_trainer.py +++ b/examples/speculative_decoding/distill_trainer.py @@ -19,19 +19,25 @@ import torch import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh from tqdm import tqdm +from transformers import AutoModelForCausalLM +from transformers.optimization import get_linear_schedule_with_warmup import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp mto.enable_huggingface_checkpointing() # Hyperparameters for profiling -EPOCHS = 10 +EPOCHS = 1 
LOG_INTERVAL = 100 SAVE_INTERVAL = 20000 +MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +DRAFT_VOCAB_SIZE = 32000 # VALIDATE_INTERVAL = 20 -# We define the distill signal from teacher as the map of variable name to its shape and dtype. +# Shape and dtype description of the distillation signal DistillMetadata = dict[str, tuple[torch.Size, torch.dtype]] @@ -208,3 +214,172 @@ def train(self, dataloader): dist.barrier() # clean up processess dist.destroy_process_group() + + +class EagleTPTrainer(BaseDistillTrainer): + @property + def current_rank_device(self): + if self.rank in self.args.student_ranks: + return self.args.student_devices[self.rank] + else: + return self.args.teacher_devices[self.rank - len(self.args.student_ranks)] + + def load_teacher_model(self): + model = AutoModelForCausalLM.from_pretrained( + MODEL_PATH, + torch_dtype="auto", + tp_plan="auto", + device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"), + ) + self.args.eagle_config["eagle_architecture_config"].update( + { + "hidden_size": model.config.hidden_size, + "vocab_size": model.config.vocab_size, + "draft_vocab_size": DRAFT_VOCAB_SIZE, + } + ) + mtsp.convert(model, [("eagle", self.args.eagle_config)]) + model.eval() + self._print_model_placement(model) + return model + + def load_student_model(self): + """Load student model on a single device and keep needed modules from teacher.""" + # Load to CPU first to avoid OOM + model = AutoModelForCausalLM.from_pretrained( + MODEL_PATH, torch_dtype="auto", device_map="cpu" + ) + # Hidden size and vocab size must match base model + self.args.eagle_config["eagle_architecture_config"].update( + { + "hidden_size": model.config.hidden_size, + "vocab_size": model.config.vocab_size, + "draft_vocab_size": DRAFT_VOCAB_SIZE, + } + ) + mtsp.convert( + model, + [("eagle", self.args.eagle_config)], + ) + if model.config.vocab_size > DRAFT_VOCAB_SIZE: + model_name = os.path.basename(os.path.normpath(MODEL_PATH)) + vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt") + try: + vocab_cache = torch.load(vocab_cache_path) + assert len(vocab_cache) == DRAFT_VOCAB_SIZE + model.eagle_module.d2t = vocab_cache + print(f"Loaded draft vocab cache from {vocab_cache_path}.") + except Exception as e: + raise e + + # TODO:copy needed modules and del the rest + model.model._modules.pop("layers") + model.to(self.current_rank_device) + + model.train() + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.current_rank_device], + process_group=self.args.student_pgroup, + find_unused_parameters=True, + ) + optimizer = torch.optim.AdamW(model.parameters(), lr=self.args.lr) + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=0, num_training_steps=117380 + ) + self._print_model_placement(model) + return model, optimizer, scheduler + + def teacher_step(self, model, inputs): + base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward( + **inputs, + freeze_base_model=True, + past_key_values=None, + ) + # aux_hidden_states could be on multiple devices. Gather them and cat. 
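+        # With the default EAGLE-3 setup, hidden states from three base-model layers are
+        # collected here, which is why the aux_hidden_states entry in distill_metadata is
+        # sized at 3x the hidden size.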
+ aux_hidden_states = torch.cat( + [t.to(base_model_logits.device) for t in model.pop_aux_hidden_states()], dim=-1 + ) + base_model_hidden_states = base_model_hidden_states.chunk(len(self.args.student_ranks)) + base_model_logits = base_model_logits.chunk(len(self.args.student_ranks)) + aux_hidden_states = aux_hidden_states.chunk(len(self.args.student_ranks)) + + return [ + { + "base_model_hidden_states": base_model_hidden_states[i], + "aux_hidden_states": aux_hidden_states[i], + "base_model_logits": base_model_logits[i], + } + for i in range(len(self.args.student_ranks)) + ] + + def student_step( + self, + inputs, + base_model_hidden_states, + aux_hidden_states, + base_model_logits, + ): + self.optimizer.zero_grad() + # Second stage forward using the unified model + inputs = {k: v.chunk(len(self.args.student_ranks))[self.rank] for k, v in inputs.items()} + output = self.model( + **inputs, + # providing base model outputs to bypass the base model forward. + base_model_outputs={ + "base_model_hidden_states": base_model_hidden_states, + "aux_hidden_states": aux_hidden_states.clone().detach(), + "base_model_logits": base_model_logits.clone().detach(), + }, + ) + loss = output.loss + # print(f"Rank {self.rank} loss: {loss.item()}") + train_acc = output.train_acc + + # Backward + loss.backward() + self.optimizer.step() + self.scheduler.step() + return round(loss.item(), 3), train_acc + + +# class EagleMPTrainer(EagleTPTrainer, BaseDistillTrainer): +# @property +# def current_rank_devices(self): +# if self.rank == self.args.student_rank: +# return [self.args.student_device] +# else: +# return self.args.teacher_devices + +# def load_teacher_model(self): +# model = AutoModelForCausalLM.from_pretrained( +# MODEL_PATH, +# torch_dtype="auto", +# device_map="sequential", +# max_memory=dict.fromkeys( +# self.args.teacher_devices, "999GiB" +# ), # To use only given devices +# ) +# self.args.eagle_config["eagle_architecture_config"].update( +# { +# "hidden_size": model.config.hidden_size, +# "vocab_size": model.config.vocab_size, +# "draft_vocab_size": DRAFT_VOCAB_SIZE, +# } +# ) +# mtsp.convert(model, [("eagle", self.args.eagle_config)]) + +# if model.config.vocab_size > DRAFT_VOCAB_SIZE: +# model_name = os.path.basename(os.path.normpath(MODEL_PATH)) +# vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt") +# try: +# vocab_cache = torch.load(vocab_cache_path) +# assert len(vocab_cache) == DRAFT_VOCAB_SIZE +# model.eagle_module.d2t = vocab_cache +# print(f"Loaded draft vocab cache from {vocab_cache_path}.") +# except Exception as e: +# raise e + +# model.eval() +# self._print_model_placement(model) +# return model diff --git a/examples/speculative_decoding/train.py b/examples/speculative_decoding/train.py index dcd0dce34..ad19f16c5 100644 --- a/examples/speculative_decoding/train.py +++ b/examples/speculative_decoding/train.py @@ -19,23 +19,20 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -from distill_trainer import BaseDistillTrainer +from distill_trainer import DRAFT_VOCAB_SIZE, MODEL_PATH, EagleTPTrainer from eagle_utils import DataCollatorWithPadding, make_eagle_supervised_data_module -from torch.distributed.device_mesh import DeviceMesh -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.optimization import get_linear_schedule_with_warmup +from transformers import AutoTokenizer -import modelopt.torch.speculative as mtsp from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG # Hyperparameters for 
profiling torch.manual_seed(0) INPUT_LENGTH = 1024 # DRAFT_VOCAB_SIZE = 128256 -DRAFT_VOCAB_SIZE = 32000 +# DRAFT_VOCAB_SIZE = 32000 # MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.1-8B-Instruct" # MODEL_PATH = "/lustre/fsw/portfolios/coreai/projects/coreai_dlalgo_modelopt/hf-local/meta-llama/Llama-3.2-1B-Instruct" -MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +# MODEL_PATH ="TinyLlama/TinyLlama-1.1B-Chat-v1.0" # MODEL_PATH = "openai/gpt-oss-20b" # MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.3-70B-Instruct" @@ -56,175 +53,6 @@ def _setup_distributed(rank, args, backend="nccl"): ) -class EagleTPTrainer(BaseDistillTrainer): - @property - def current_rank_device(self): - if self.rank in self.args.student_ranks: - return self.args.student_devices[self.rank] - else: - return self.args.teacher_devices[self.rank - len(self.args.student_ranks)] - - def load_teacher_model(self): - model = AutoModelForCausalLM.from_pretrained( - MODEL_PATH, - torch_dtype="auto", - tp_plan="auto", - device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"), - ) - self.args.eagle_config["eagle_architecture_config"].update( - { - "hidden_size": model.config.hidden_size, - "vocab_size": model.config.vocab_size, - "draft_vocab_size": DRAFT_VOCAB_SIZE, - } - ) - mtsp.convert(model, [("eagle", self.args.eagle_config)]) - model.eval() - self._print_model_placement(model) - return model - - def load_student_model(self): - """Load student model on a single device and keep needed modules from teacher.""" - # Load to CPU first to avoid OOM - model = AutoModelForCausalLM.from_pretrained( - MODEL_PATH, torch_dtype="auto", device_map="cpu" - ) - # Hidden size and vocab size must match base model - self.args.eagle_config["eagle_architecture_config"].update( - { - "hidden_size": model.config.hidden_size, - "vocab_size": model.config.vocab_size, - "draft_vocab_size": DRAFT_VOCAB_SIZE, - } - ) - mtsp.convert( - model, - [("eagle", self.args.eagle_config)], - ) - if model.config.vocab_size > DRAFT_VOCAB_SIZE: - model_name = os.path.basename(os.path.normpath(MODEL_PATH)) - vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt") - try: - vocab_cache = torch.load(vocab_cache_path) - assert len(vocab_cache) == DRAFT_VOCAB_SIZE - model.eagle_module.d2t = vocab_cache - print(f"Loaded draft vocab cache from {vocab_cache_path}.") - except Exception as e: - raise e - - # TODO:copy needed modules and del the rest - model.model._modules.pop("layers") - model.to(self.current_rank_device) - - model.train() - model = torch.nn.parallel.DistributedDataParallel( - model, - device_ids=[self.current_rank_device], - process_group=self.args.student_pgroup, - find_unused_parameters=True, - ) - optimizer = torch.optim.AdamW(model.parameters(), lr=self.args.lr) - scheduler = get_linear_schedule_with_warmup( - optimizer, num_warmup_steps=0, num_training_steps=117380 - ) - self._print_model_placement(model) - return model, optimizer, scheduler - - def teacher_step(self, model, inputs): - base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward( - **inputs, - freeze_base_model=True, - past_key_values=None, - ) - # aux_hidden_states could be on multiple devices. Gather them and cat. 
- aux_hidden_states = torch.cat( - [t.to(base_model_logits.device) for t in model.pop_aux_hidden_states()], dim=-1 - ) - base_model_hidden_states = base_model_hidden_states.chunk(len(self.args.student_ranks)) - base_model_logits = base_model_logits.chunk(len(self.args.student_ranks)) - aux_hidden_states = aux_hidden_states.chunk(len(self.args.student_ranks)) - - return [ - { - "base_model_hidden_states": base_model_hidden_states[i], - "aux_hidden_states": aux_hidden_states[i], - "base_model_logits": base_model_logits[i], - } - for i in range(len(self.args.student_ranks)) - ] - - def student_step( - self, - inputs, - base_model_hidden_states, - aux_hidden_states, - base_model_logits, - ): - self.optimizer.zero_grad() - # Second stage forward using the unified model - inputs = {k: v.chunk(len(self.args.student_ranks))[self.rank] for k, v in inputs.items()} - output = self.model( - **inputs, - # providing base model outputs to bypass the base model forward. - base_model_outputs={ - "base_model_hidden_states": base_model_hidden_states, - "aux_hidden_states": aux_hidden_states.clone().detach(), - "base_model_logits": base_model_logits.clone().detach(), - }, - ) - loss = output.loss - # print(f"Rank {self.rank} loss: {loss.item()}") - train_acc = output.train_acc - - # Backward - loss.backward() - self.optimizer.step() - self.scheduler.step() - return round(loss.item(), 3), train_acc - - -# class EagleMPTrainer(EagleTPTrainer, BaseDistillTrainer): -# @property -# def current_rank_devices(self): -# if self.rank == self.args.student_rank: -# return [self.args.student_device] -# else: -# return self.args.teacher_devices - -# def load_teacher_model(self): -# model = AutoModelForCausalLM.from_pretrained( -# MODEL_PATH, -# torch_dtype="auto", -# device_map="sequential", -# max_memory=dict.fromkeys( -# self.args.teacher_devices, "999GiB" -# ), # To use only given devices -# ) -# self.args.eagle_config["eagle_architecture_config"].update( -# { -# "hidden_size": model.config.hidden_size, -# "vocab_size": model.config.vocab_size, -# "draft_vocab_size": DRAFT_VOCAB_SIZE, -# } -# ) -# mtsp.convert(model, [("eagle", self.args.eagle_config)]) - -# if model.config.vocab_size > DRAFT_VOCAB_SIZE: -# model_name = os.path.basename(os.path.normpath(MODEL_PATH)) -# vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt") -# try: -# vocab_cache = torch.load(vocab_cache_path) -# assert len(vocab_cache) == DRAFT_VOCAB_SIZE -# model.eagle_module.d2t = vocab_cache -# print(f"Loaded draft vocab cache from {vocab_cache_path}.") -# except Exception as e: -# raise e - -# model.eval() -# self._print_model_placement(model) -# return model - - def train(rank, args): _setup_distributed(rank, args) From 5ae447966eb806b575a5962fa9e95e9be4e6d635 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 2 Oct 2025 23:03:40 +0000 Subject: [PATCH 5/9] refactor Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../speculative_decoding/distill_trainer.py | 140 +++++++++--------- examples/speculative_decoding/train.py | 98 ++++-------- 2 files changed, 101 insertions(+), 137 deletions(-) diff --git a/examples/speculative_decoding/distill_trainer.py b/examples/speculative_decoding/distill_trainer.py index ad968a891..31eebd924 100644 --- a/examples/speculative_decoding/distill_trainer.py +++ b/examples/speculative_decoding/distill_trainer.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import json import os os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -26,6 +27,7 @@ import modelopt.torch.opt as mto import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG mto.enable_huggingface_checkpointing() @@ -33,8 +35,6 @@ EPOCHS = 1 LOG_INTERVAL = 100 SAVE_INTERVAL = 20000 -MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" -DRAFT_VOCAB_SIZE = 32000 # VALIDATE_INTERVAL = 20 # Shape and dtype description of the distillation signal @@ -51,13 +51,21 @@ class BaseDistillTrainer: student_step: student step function. """ - def __init__(self, rank, args, tokenizer, distill_metadata: DistillMetadata): + def __init__(self, rank, args, tokenizer): self.rank = rank args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks) args.student_pgroup = dist.new_group(ranks=args.student_ranks) self.args = args self.tokenizer = tokenizer - self.distill_metadata = distill_metadata + if rank in args.student_ranks: + self.model = self.prepare_student_model() + self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr) + self.scheduler = get_linear_schedule_with_warmup( + self.optimizer, num_warmup_steps=0, num_training_steps=117380 + ) + else: + self.model = self.prepare_teacher_model() + self._print_model_placement(self.model) def _print_model_placement(self, module): for name, param in module.named_parameters(): @@ -67,6 +75,10 @@ def _print_model_placement(self, module): def current_rank_device(self): pass + @property + def distill_metadata(self): + pass + def _reset_all_mem_stats(self): torch.cuda.reset_max_memory_allocated(self.current_rank_device) @@ -162,7 +174,6 @@ def train(self, dataloader): project=os.environ["WANDB_PROJECT"], config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size}, ) as run: - self.model, self.optimizer, self.scheduler = self.load_student_model() self._init_student_recv_buffer() wandb.watch(self.model, log="all") @@ -198,7 +209,6 @@ def train(self, dataloader): ) else: - self.model = self.load_teacher_model() # Inference Loop for epoch in range(EPOCHS): for i, batch in enumerate(dataloader): @@ -217,6 +227,15 @@ def train(self, dataloader): class EagleTPTrainer(BaseDistillTrainer): + def __init__(self, rank, args, tokenizer): + args.eagle_config = EAGLE3_DEFAULT_CFG["config"] + if args.eagle_config_path: + with open(args.eagle_config_path) as f: + custom_config = json.load(f) + args.eagle_config["eagle_architecture_config"].update(custom_config) + + super().__init__(rank, args, tokenizer) + @property def current_rank_device(self): if self.rank in self.args.student_ranks: @@ -224,9 +243,44 @@ def current_rank_device(self): else: return self.args.teacher_devices[self.rank - len(self.args.student_ranks)] - def load_teacher_model(self): + @property + def distill_metadata(self) -> DistillMetadata: + return { + "base_model_hidden_states": ( + torch.Size( + [ + int(self.args.batch_size / len(self.args.student_ranks)), + self.args.training_seq_len, + 2048, + ] + ), + torch.bfloat16, + ), + "aux_hidden_states": ( + torch.Size( + [ + int(self.args.batch_size / len(self.args.student_ranks)), + self.args.training_seq_len, + 2048 * 3, + ] + ), + torch.bfloat16, + ), + "base_model_logits": ( + torch.Size( + [ + int(self.args.batch_size / len(self.args.student_ranks)), + self.args.training_seq_len, + self.args.draft_vocab_size, + ] + ), + torch.bfloat16, + ), + } + + def 
prepare_teacher_model(self): model = AutoModelForCausalLM.from_pretrained( - MODEL_PATH, + self.args.model_path, torch_dtype="auto", tp_plan="auto", device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"), @@ -235,42 +289,33 @@ def load_teacher_model(self): { "hidden_size": model.config.hidden_size, "vocab_size": model.config.vocab_size, - "draft_vocab_size": DRAFT_VOCAB_SIZE, + "draft_vocab_size": model.config.vocab_size, } ) + self.args.draft_vocab_size = model.config.vocab_size mtsp.convert(model, [("eagle", self.args.eagle_config)]) model.eval() - self._print_model_placement(model) return model - def load_student_model(self): + def prepare_student_model(self): """Load student model on a single device and keep needed modules from teacher.""" # Load to CPU first to avoid OOM model = AutoModelForCausalLM.from_pretrained( - MODEL_PATH, torch_dtype="auto", device_map="cpu" + self.args.model_path, torch_dtype="auto", device_map="cpu" ) # Hidden size and vocab size must match base model self.args.eagle_config["eagle_architecture_config"].update( { "hidden_size": model.config.hidden_size, "vocab_size": model.config.vocab_size, - "draft_vocab_size": DRAFT_VOCAB_SIZE, + "draft_vocab_size": model.config.vocab_size, } ) + self.args.draft_vocab_size = model.config.vocab_size mtsp.convert( model, [("eagle", self.args.eagle_config)], ) - if model.config.vocab_size > DRAFT_VOCAB_SIZE: - model_name = os.path.basename(os.path.normpath(MODEL_PATH)) - vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt") - try: - vocab_cache = torch.load(vocab_cache_path) - assert len(vocab_cache) == DRAFT_VOCAB_SIZE - model.eagle_module.d2t = vocab_cache - print(f"Loaded draft vocab cache from {vocab_cache_path}.") - except Exception as e: - raise e # TODO:copy needed modules and del the rest model.model._modules.pop("layers") @@ -283,12 +328,7 @@ def load_student_model(self): process_group=self.args.student_pgroup, find_unused_parameters=True, ) - optimizer = torch.optim.AdamW(model.parameters(), lr=self.args.lr) - scheduler = get_linear_schedule_with_warmup( - optimizer, num_warmup_steps=0, num_training_steps=117380 - ) - self._print_model_placement(model) - return model, optimizer, scheduler + return model def teacher_step(self, model, inputs): base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward( @@ -341,45 +381,3 @@ def student_step( self.optimizer.step() self.scheduler.step() return round(loss.item(), 3), train_acc - - -# class EagleMPTrainer(EagleTPTrainer, BaseDistillTrainer): -# @property -# def current_rank_devices(self): -# if self.rank == self.args.student_rank: -# return [self.args.student_device] -# else: -# return self.args.teacher_devices - -# def load_teacher_model(self): -# model = AutoModelForCausalLM.from_pretrained( -# MODEL_PATH, -# torch_dtype="auto", -# device_map="sequential", -# max_memory=dict.fromkeys( -# self.args.teacher_devices, "999GiB" -# ), # To use only given devices -# ) -# self.args.eagle_config["eagle_architecture_config"].update( -# { -# "hidden_size": model.config.hidden_size, -# "vocab_size": model.config.vocab_size, -# "draft_vocab_size": DRAFT_VOCAB_SIZE, -# } -# ) -# mtsp.convert(model, [("eagle", self.args.eagle_config)]) - -# if model.config.vocab_size > DRAFT_VOCAB_SIZE: -# model_name = os.path.basename(os.path.normpath(MODEL_PATH)) -# vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt") -# try: -# vocab_cache = torch.load(vocab_cache_path) -# assert len(vocab_cache) == DRAFT_VOCAB_SIZE -# 
model.eagle_module.d2t = vocab_cache -# print(f"Loaded draft vocab cache from {vocab_cache_path}.") -# except Exception as e: -# raise e - -# model.eval() -# self._print_model_placement(model) -# return model diff --git a/examples/speculative_decoding/train.py b/examples/speculative_decoding/train.py index ad19f16c5..4a01333ce 100644 --- a/examples/speculative_decoding/train.py +++ b/examples/speculative_decoding/train.py @@ -19,22 +19,12 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -from distill_trainer import DRAFT_VOCAB_SIZE, MODEL_PATH, EagleTPTrainer +from distill_trainer import EagleTPTrainer from eagle_utils import DataCollatorWithPadding, make_eagle_supervised_data_module from transformers import AutoTokenizer -from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG - # Hyperparameters for profiling torch.manual_seed(0) -INPUT_LENGTH = 1024 -# DRAFT_VOCAB_SIZE = 128256 -# DRAFT_VOCAB_SIZE = 32000 -# MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.1-8B-Instruct" -# MODEL_PATH = "/lustre/fsw/portfolios/coreai/projects/coreai_dlalgo_modelopt/hf-local/meta-llama/Llama-3.2-1B-Instruct" -# MODEL_PATH ="TinyLlama/TinyLlama-1.1B-Chat-v1.0" -# MODEL_PATH = "openai/gpt-oss-20b" -# MODEL_PATH = "/home/scratch.omniml_data_1/models_ci/meta-llama/Llama-3.3-70B-Instruct" def _setup_distributed(rank, args, backend="nccl"): @@ -56,7 +46,9 @@ def _setup_distributed(rank, args, backend="nccl"): def train(rank, args): _setup_distributed(rank, args) - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, model_max_length=INPUT_LENGTH) + tokenizer = AutoTokenizer.from_pretrained( + args.model_path, model_max_length=args.training_seq_len + ) data_module = make_eagle_supervised_data_module(tokenizer, args, use_offline_training=False) train_dataloader = torch.utils.data.DataLoader( @@ -64,52 +56,28 @@ def train(rank, args): batch_size=args.batch_size, shuffle=True, num_workers=0, - collate_fn=DataCollatorWithPadding(max_length=INPUT_LENGTH), + collate_fn=DataCollatorWithPadding(max_length=args.training_seq_len), drop_last=True, ) - trainer_cls = { - "tp": EagleTPTrainer, - # "mp": EagleMPTrainer, - }[args.teacher_parallel] - - distill_metadata = { - "base_model_hidden_states": ( - torch.Size([int(args.batch_size / len(args.student_ranks)), INPUT_LENGTH, 2048]), - torch.bfloat16, - ), - "aux_hidden_states": ( - torch.Size([int(args.batch_size / len(args.student_ranks)), INPUT_LENGTH, 2048 * 3]), - torch.bfloat16, - ), - "base_model_logits": ( - torch.Size( - [int(args.batch_size / len(args.student_ranks)), INPUT_LENGTH, DRAFT_VOCAB_SIZE] - ), - torch.bfloat16, - ), - } - - trainer = trainer_cls(rank, args, tokenizer, distill_metadata) + trainer = EagleTPTrainer(rank, args, tokenizer) trainer.train(train_dataloader) # trainer.save_pretrained("ckpts/fast-trained") def main(): parser = argparse.ArgumentParser(description="Multi-GPU distributed two-stage forward example") - parser.add_argument( - "--student_devices", type=list, default=[0, 1, 2, 3], help="Devices for student model" + "--model_path", + type=str, + default="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + help="Path to the model.", ) parser.add_argument( - "--teacher_devices", type=list, default=[4, 5], help="Devices for teacher model" + "--student_devices", type=list, default=[0, 1, 2, 3], help="Devices for student model" ) parser.add_argument( - "--teacher_parallel", - type=str, - choices=["tp", "mp"], - default="tp", - help="Parallel type for teacher model. 
TP and MP supported.", + "--teacher_devices", type=list, default=[4, 5], help="Devices for teacher model" ) parser.add_argument( "--data_path", @@ -117,6 +85,18 @@ def main(): default="data/magpie_llama3.2_1b_generated/data.cleaned.jsonl", help="Path to the training data.", ) + parser.add_argument( + "--training_seq_len", + type=str, + default=1024, + help="Training sequence length.", + ) + parser.add_argument( + "--eagle_config_path", + type=str, + default="eagle_config.json", + help="Path to the eagle config.", + ) parser.add_argument( "--lazy_preprocess", type=bool, default=True, help="Whether to use lazy preprocessing." ) @@ -128,31 +108,17 @@ def main(): parser.add_argument("--master_port", type=str, default="12357", help="Master port.") args = parser.parse_args() - args.eagle_config = EAGLE3_DEFAULT_CFG["config"] # TODO: add sanity check for args - def set_ranks(student_devices, teacher_devices, teacher_parallel): - # TODO(hg): add "no-parallel" option, fallback when only one teacher device is provided. - # TODO(hg): add "FSDP" option. - if teacher_parallel == "tp": - world_size = len(teacher_devices) + len(student_devices) - student_ranks = list(range(len(student_devices))) - teacher_ranks = list( - range(len(student_devices), len(student_devices) + len(teacher_devices)) - ) - elif teacher_parallel == "mp": - raise NotImplementedError("MP parallel type not supported.") - # world_size = 2 - # student_rank = 0 - # teacher_ranks = [1] - else: - raise NotImplementedError(f"Parallel type {teacher_parallel} not supported.") - return world_size, student_ranks, teacher_ranks - - args.world_size, args.student_ranks, args.teacher_ranks = set_ranks( - args.student_devices, args.teacher_devices, args.teacher_parallel - ) + def set_ranks(args): + # TODO(hg): add "no-parallel", "MP", "FSDP". 
+ args.world_size = len(args.teacher_devices) + len(args.student_devices) + args.student_ranks = list(range(len(args.student_devices))) + args.teacher_ranks = list( + range(len(args.student_devices), len(args.student_devices) + len(args.teacher_devices)) + ) + set_ranks(args) # Launch multiple processes mp.spawn( train, From 1a6325495a31271131d77a6552378870ba648da3 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Fri, 3 Oct 2025 01:01:45 +0000 Subject: [PATCH 6/9] clean up codes Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../speculative_decoding/distill_trainer.py | 180 ++++++++++-------- examples/speculative_decoding/train.py | 10 +- .../torch/speculative/plugins/transformers.py | 16 +- 3 files changed, 111 insertions(+), 95 deletions(-) diff --git a/examples/speculative_decoding/distill_trainer.py b/examples/speculative_decoding/distill_trainer.py index 31eebd924..cda9d0975 100644 --- a/examples/speculative_decoding/distill_trainer.py +++ b/examples/speculative_decoding/distill_trainer.py @@ -17,6 +17,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" from abc import abstractmethod +from contextlib import nullcontext import torch import torch.distributed as dist @@ -24,11 +25,18 @@ from tqdm import tqdm from transformers import AutoModelForCausalLM from transformers.optimization import get_linear_schedule_with_warmup +from transformers.utils import ModelOutput import modelopt.torch.opt as mto import modelopt.torch.speculative as mtsp from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG +try: + import wandb +except ImportError: + wandb = None + + mto.enable_huggingface_checkpointing() # Hyperparameters for profiling @@ -51,12 +59,13 @@ class BaseDistillTrainer: student_step: student step function. 
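        tokenizer: tokenizer, saved alongside the trained student checkpoint.
        dataloader: dataloader iterated by both the teacher and the student loops.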
""" - def __init__(self, rank, args, tokenizer): + def __init__(self, rank, args, tokenizer, dataloader): self.rank = rank args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks) args.student_pgroup = dist.new_group(ranks=args.student_ranks) self.args = args self.tokenizer = tokenizer + self.dataloader = dataloader if rank in args.student_ranks: self.model = self.prepare_student_model() self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr) @@ -71,14 +80,6 @@ def _print_model_placement(self, module): for name, param in module.named_parameters(): print(f"(Rank {self.rank}) {name} ---> {param.device} ") - @property - def current_rank_device(self): - pass - - @property - def distill_metadata(self): - pass - def _reset_all_mem_stats(self): torch.cuda.reset_max_memory_allocated(self.current_rank_device) @@ -86,31 +87,42 @@ def _print_mem_stats(self): max_mem = torch.cuda.max_memory_allocated(self.current_rank_device) print(f"GPU {self.current_rank_device}: Max memory allocated: {max_mem / 1024**3:.2f} GB") + @property + def current_rank_device(self): + """Return device of the current rank.""" + + @property + def distill_metadata(self): + """Return a DistillMetadata that describe the distillation message received by student.""" + @abstractmethod - def load_teacher_model(self): - pass + def prepare_teacher_model(self): + """Return coverted teacher model with correct parallelization.""" @abstractmethod - def load_student_model(self): - pass + def prepare_student_model(self): + """Return coverted student model with correct parallelization.""" @abstractmethod - def teacher_step(self, *args, **kwargs) -> dict[str, torch.Tensor]: - pass + def teacher_step(self, *args, **kwargs) -> list[dict[str, torch.Tensor]]: + """Run one student step and return distillation messages for each student rank.""" @abstractmethod - def student_step(self, *args, **kwargs): - pass + def student_step(self, *args, **kwargs) -> ModelOutput: + """Run forward of student step, return a modeloutput object.""" - def save_pretrained(self, path=None): + def save_pretrained(self, save_path): + """Save the model and tokenizer.""" if self.rank == self.args.student_ranks[0]: - path = self.args.out_path if path is None else path - self.model.save_pretrained(path) - self.tokenizer.save_pretrained(path) - print(f"Pretrained model saved to {path}") + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): + self.model.module.save_pretrained(save_path) + else: + self.model.save_pretrained(save_path) + self.tokenizer.save_pretrained(save_path) + print(f"Pretrained model saved to {save_path}") def _check_valid_message(self, message: dict[str, torch.Tensor]): - # Check if keys and length match between message and distill_metadata + """Check if message in the format of distill_metadata.""" if set(message.keys()) != set(self.distill_metadata.keys()): raise ValueError( f"Message keys: {set(message.keys())} \n" @@ -142,8 +154,8 @@ def _recv_from_teacher(self): for req in reqs: req.wait() - def _get_distill_kwargs(self): - """Return a copy of received buffer for student training.""" + def _clone_recv_buffer(self): + """Return a copy of received tensors for student step input.""" return {k: v.clone().detach() for k, v in self.student_recv_buffer.items()} def _send_to_student(self, teacher_outputs): @@ -160,49 +172,63 @@ def _send_to_student(self, teacher_outputs): for req in reqs: req.wait() - def train(self, dataloader): + def _get_logging_context(self): + print( + f"Rank {self.rank} is logging: {wandb is 
not None and self.rank == self.args.student_ranks[0]}" + ) + if wandb is not None and self.rank == self.args.student_ranks[0]: + return wandb.init( + entity=os.environ["WANDB_ENTITY"], + project=os.environ["WANDB_PROJECT"], + config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size}, + ) + return nullcontext() + + def train(self): """Main training entrance of the composed model.""" self._reset_all_mem_stats() if self.rank in self.args.student_ranks: - import wandb - - wandb.login() - - with wandb.init( - entity=os.environ["WANDB_ENTITY"], - project=os.environ["WANDB_PROJECT"], - config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size}, - ) as run: + with self._get_logging_context() as run: self._init_student_recv_buffer() - wandb.watch(self.model, log="all") + # Student training loop for epoch in range(EPOCHS): pbar = ( - tqdm(dataloader) if self.rank == self.args.student_ranks[0] else dataloader + tqdm(self.dataloader) + if self.rank == self.args.student_ranks[0] + else self.dataloader ) for i, batch in enumerate(pbar): - global_step = epoch * len(dataloader) + i + global_step = epoch * len(self.dataloader) + i inputs = {k: v.to(self.model.device) for k, v in batch.items()} + + # Receive distill messages from teacher self._recv_from_teacher() - loss, train_acc = self.student_step(inputs, **self._get_distill_kwargs()) + # Run forward of student step + output = self.student_step(inputs, **self._clone_recv_buffer()) + loss = output.loss + + # Run backward step + loss.backward() + self.optimizer.step() + self.scheduler.step() + + # Log and save only on student rank 0 if self.rank != self.args.student_ranks[0]: continue - pbar.set_description(f"Epoch {epoch} Loss:{loss} Acc:{train_acc}") + train_metrics = { + "loss": round(loss.item(), 3), + "lr": self.optimizer.param_groups[0]["lr"], + # Attach all float metrics + **{k: round(v, 3) for k, v in output.items() if isinstance(v, float)}, + } + + pbar.set_description(f"Epoch {epoch} Loss {train_metrics['loss']}") if global_step % LOG_INTERVAL == 0: - run.log( - { - "loss": loss, - "train_acc_step0": train_acc[0], - "train_acc_step1": train_acc[1], - "train_acc_step2": train_acc[2], - "train_acc_step3": train_acc[3], - "lr": self.optimizer.param_groups[0]["lr"], - }, - step=global_step, - ) + run.log(train_metrics, step=global_step) if global_step > 0 and global_step % SAVE_INTERVAL == 0: self.save_pretrained( f"{self.args.out_path}/epoch_{epoch}_step_{global_step}" @@ -211,13 +237,10 @@ def train(self, dataloader): else: # Inference Loop for epoch in range(EPOCHS): - for i, batch in enumerate(dataloader): - global_step = epoch * len(dataloader) + i + for i, batch in enumerate(self.dataloader): inputs = {k: v.to(self.model.device) for k, v in batch.items()} - inputs["position_ids"] = None with torch.inference_mode(): - teacher_outputs = self.teacher_step(self.model, inputs) - self._send_to_student(teacher_outputs) + self._send_to_student(self.teacher_step(self.model, inputs)) self._print_mem_stats() # Makesure all processes finished before destroy. 
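The overlapped loop above hinges on the point-to-point exchange implemented by _init_student_recv_buffer, _recv_from_teacher and _send_to_student: the DistillMetadata mapping of names to (shape, dtype) lets the student pre-allocate receive buffers and post non-blocking dist.irecv calls that pair with the teacher's dist.isend. Below is a minimal, self-contained sketch of that pattern; it is illustrative only and not part of the patch — the names run_rank and DISTILL_METADATA are made up, and it assumes a two-process group on the gloo backend so it runs on CPU.

# Hypothetical sketch of the metadata-driven teacher -> student handoff (not from the patch).
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

# Same idea as DistillMetadata: name -> (shape, dtype) of each distillation tensor.
DISTILL_METADATA = {
    "hidden_states": (torch.Size([2, 8, 16]), torch.float32),
    "logits": (torch.Size([2, 8, 32]), torch.float32),
}

def run_rank(rank, world_size):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size
    )
    if rank == 1:
        # "Teacher" rank: produce the tensors and send them asynchronously.
        message = {k: torch.randn(s, dtype=d) for k, (s, d) in DISTILL_METADATA.items()}
        reqs = [dist.isend(t.contiguous(), dst=0) for t in message.values()]
    else:
        # "Student" rank: pre-allocate buffers from the metadata and receive into them.
        buffers = {k: torch.empty(s, dtype=d) for k, (s, d) in DISTILL_METADATA.items()}
        reqs = [dist.irecv(t, src=1) for t in buffers.values()]
    for req in reqs:
        req.wait()
    if rank == 0:
        print("student received:", {k: tuple(v.shape) for k, v in buffers.items()})
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run_rank, args=(2,), nprocs=2)

In the trainer the same handoff is driven by distill_metadata, with one chunked message sent to each student rank instead of a fixed dst=0.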
@@ -227,14 +250,15 @@ def train(self, dataloader): class EagleTPTrainer(BaseDistillTrainer): - def __init__(self, rank, args, tokenizer): + def __init__(self, rank, args, tokenizer, dataloader): + # Load eagle config args.eagle_config = EAGLE3_DEFAULT_CFG["config"] if args.eagle_config_path: with open(args.eagle_config_path) as f: custom_config = json.load(f) args.eagle_config["eagle_architecture_config"].update(custom_config) - super().__init__(rank, args, tokenizer) + super().__init__(rank, args, tokenizer, dataloader) @property def current_rank_device(self): @@ -245,6 +269,7 @@ def current_rank_device(self): @property def distill_metadata(self) -> DistillMetadata: + """Description of the distillation signal received by student.""" return { "base_model_hidden_states": ( torch.Size( @@ -279,12 +304,14 @@ def distill_metadata(self) -> DistillMetadata: } def prepare_teacher_model(self): + # Load model with TP among teacher ranks. model = AutoModelForCausalLM.from_pretrained( self.args.model_path, torch_dtype="auto", tp_plan="auto", device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"), ) + # load eagle config and convert. self.args.eagle_config["eagle_architecture_config"].update( { "hidden_size": model.config.hidden_size, @@ -298,7 +325,6 @@ def prepare_teacher_model(self): return model def prepare_student_model(self): - """Load student model on a single device and keep needed modules from teacher.""" # Load to CPU first to avoid OOM model = AutoModelForCausalLM.from_pretrained( self.args.model_path, torch_dtype="auto", device_map="cpu" @@ -331,15 +357,19 @@ def prepare_student_model(self): return model def teacher_step(self, model, inputs): + # Collect base model outputs. base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward( **inputs, freeze_base_model=True, past_key_values=None, ) - # aux_hidden_states could be on multiple devices. Gather them and cat. + + # Aux_hidden_states could be on multiple devices. Gather before cat. aux_hidden_states = torch.cat( [t.to(base_model_logits.device) for t in model.pop_aux_hidden_states()], dim=-1 ) + + # Chunk the tensors for each student rank. base_model_hidden_states = base_model_hidden_states.chunk(len(self.args.student_ranks)) base_model_logits = base_model_logits.chunk(len(self.args.student_ranks)) aux_hidden_states = aux_hidden_states.chunk(len(self.args.student_ranks)) @@ -356,28 +386,12 @@ def teacher_step(self, model, inputs): def student_step( self, inputs, - base_model_hidden_states, - aux_hidden_states, - base_model_logits, - ): + **distill_msgs, + ) -> ModelOutput: self.optimizer.zero_grad() - # Second stage forward using the unified model + + # Chunk inputs for each student rank. inputs = {k: v.chunk(len(self.args.student_ranks))[self.rank] for k, v in inputs.items()} - output = self.model( - **inputs, - # providing base model outputs to bypass the base model forward. - base_model_outputs={ - "base_model_hidden_states": base_model_hidden_states, - "aux_hidden_states": aux_hidden_states.clone().detach(), - "base_model_logits": base_model_logits.clone().detach(), - }, - ) - loss = output.loss - # print(f"Rank {self.rank} loss: {loss.item()}") - train_acc = output.train_acc - - # Backward - loss.backward() - self.optimizer.step() - self.scheduler.step() - return round(loss.item(), 3), train_acc + + # Second stage forward with provided base model outputs. 
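+        # Passing base_model_outputs lets the converted model skip its own base-model forward,
+        # so only the eagle module runs here on the tensors received from the teacher ranks.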
+ return self.model(**inputs, base_model_outputs=distill_msgs) diff --git a/examples/speculative_decoding/train.py b/examples/speculative_decoding/train.py index 4a01333ce..505b6ee18 100644 --- a/examples/speculative_decoding/train.py +++ b/examples/speculative_decoding/train.py @@ -60,9 +60,9 @@ def train(rank, args): drop_last=True, ) - trainer = EagleTPTrainer(rank, args, tokenizer) - trainer.train(train_dataloader) - # trainer.save_pretrained("ckpts/fast-trained") + trainer = EagleTPTrainer(rank, args, tokenizer, train_dataloader) + trainer.train() + trainer.save_pretrained(args.out_path) def main(): @@ -104,7 +104,9 @@ def main(): "--out_path", type=str, default="ckpts/fast-trained", help="Path to save the model." ) parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.") - parser.add_argument("--batch_size", type=int, default=4, help="Batch size.") + parser.add_argument( + "--batch_size", type=int, default=4, help="Total batch size across all parallel ranks." + ) parser.add_argument("--master_port", type=str, default="12357", help="Master port.") args = parser.parse_args() diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index ad0b32074..7afbd1e1c 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -586,10 +586,10 @@ def _base_model_forward( self, input_ids, attention_mask, - position_ids, - past_key_values, - freeze_base_model, - labels, + position_ids=None, + past_key_values=None, + freeze_base_model=True, + labels=None, **kwargs, ): # TODO: This function still use eagle_module. Ideally we should remove it, @@ -726,7 +726,7 @@ def forward( # ====Run eagle forward==== eagle_loss = None - train_accs = [] + train_accs = {} if self.training: # In EAGLE-3, we have an additional FC layer to concentrate hidden states from multiple base model layers b, seq_length, h = base_model_hidden_states.shape @@ -770,7 +770,7 @@ def forward( loss_mask[:, 1:], ) eagle_loss = classification_loss - train_accs.append(acc) + train_accs["train_acc_step0"] = acc # ====Perform training-time-testing with 3 extra eagle forward passes==== for ttt_step in range(self.num_ttt_steps): @@ -811,7 +811,7 @@ def forward( ), ) eagle_loss += classification_loss - train_accs.append(acc) + train_accs[f"train_acc_step{ttt_step + 1}"] = acc # Finally, we merge base model loss and eagle loss, raise error if both are None if base_model_loss is not None and eagle_loss is not None: loss = base_model_loss + eagle_loss @@ -830,7 +830,7 @@ def forward( logits=base_model_logits, past_key_values=past_key_values, hidden_states=base_model_hidden_states, - train_acc=train_accs, + **train_accs, ) def _eagle_loss( From 5daa2399966dc9aaf58b2edaaeff4ea36a0d75cf Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Fri, 3 Oct 2025 01:11:00 +0000 Subject: [PATCH 7/9] polish Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../speculative_decoding/distill_trainer.py | 22 +++++---- examples/speculative_decoding/train.py | 45 +++++-------------- 2 files changed, 25 insertions(+), 42 deletions(-) diff --git a/examples/speculative_decoding/distill_trainer.py b/examples/speculative_decoding/distill_trainer.py index cda9d0975..b1b7fa55c 100644 --- a/examples/speculative_decoding/distill_trainer.py +++ b/examples/speculative_decoding/distill_trainer.py @@ -40,10 +40,8 @@ mto.enable_huggingface_checkpointing() # Hyperparameters for 
profiling -EPOCHS = 1 LOG_INTERVAL = 100 SAVE_INTERVAL = 20000 -# VALIDATE_INTERVAL = 20 # Shape and dtype description of the distillation signal DistillMetadata = dict[str, tuple[torch.Size, torch.dtype]] @@ -61,11 +59,11 @@ class BaseDistillTrainer: def __init__(self, rank, args, tokenizer, dataloader): self.rank = rank - args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks) - args.student_pgroup = dist.new_group(ranks=args.student_ranks) self.args = args self.tokenizer = tokenizer self.dataloader = dataloader + + # Prepare models if rank in args.student_ranks: self.model = self.prepare_student_model() self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr) @@ -180,7 +178,11 @@ def _get_logging_context(self): return wandb.init( entity=os.environ["WANDB_ENTITY"], project=os.environ["WANDB_PROJECT"], - config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size}, + config={ + "epochs": self.args.epoch, + "lr": self.args.lr, + "batch_size": self.args.batch_size, + }, ) return nullcontext() @@ -193,7 +195,7 @@ def train(self): self._init_student_recv_buffer() # Student training loop - for epoch in range(EPOCHS): + for epoch in range(self.args.epoch): pbar = ( tqdm(self.dataloader) if self.rank == self.args.student_ranks[0] @@ -236,7 +238,7 @@ def train(self): else: # Inference Loop - for epoch in range(EPOCHS): + for epoch in range(self.args.epoch): for i, batch in enumerate(self.dataloader): inputs = {k: v.to(self.model.device) for k, v in batch.items()} with torch.inference_mode(): @@ -390,8 +392,10 @@ def student_step( ) -> ModelOutput: self.optimizer.zero_grad() - # Chunk inputs for each student rank. + # Chunk input_ids and attention_mask for each student rank. inputs = {k: v.chunk(len(self.args.student_ranks))[self.rank] for k, v in inputs.items()} # Second stage forward with provided base model outputs. 
- return self.model(**inputs, base_model_outputs=distill_msgs) + output = self.model(**inputs, base_model_outputs=distill_msgs) + + return output diff --git a/examples/speculative_decoding/train.py b/examples/speculative_decoding/train.py index 505b6ee18..b9e5f2148 100644 --- a/examples/speculative_decoding/train.py +++ b/examples/speculative_decoding/train.py @@ -41,6 +41,8 @@ def _setup_distributed(rank, args, backend="nccl"): print( f"Starting process rank={rank}, device={torch.cuda.current_device()}, world_size={args.world_size}" ) + args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks) + args.student_pgroup = dist.new_group(ranks=args.student_ranks) def train(rank, args): @@ -67,47 +69,24 @@ def train(rank, args): def main(): parser = argparse.ArgumentParser(description="Multi-GPU distributed two-stage forward example") + parser.add_argument("--model_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + parser.add_argument("--student_devices", type=list, default=[0, 1, 2, 3]) + parser.add_argument("--teacher_devices", type=list, default=[4, 5]) parser.add_argument( - "--model_path", - type=str, - default="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - help="Path to the model.", - ) - parser.add_argument( - "--student_devices", type=list, default=[0, 1, 2, 3], help="Devices for student model" - ) - parser.add_argument( - "--teacher_devices", type=list, default=[4, 5], help="Devices for teacher model" - ) - parser.add_argument( - "--data_path", - type=str, - default="data/magpie_llama3.2_1b_generated/data.cleaned.jsonl", - help="Path to the training data.", - ) - parser.add_argument( - "--training_seq_len", - type=str, - default=1024, - help="Training sequence length.", - ) - parser.add_argument( - "--eagle_config_path", - type=str, - default="eagle_config.json", - help="Path to the eagle config.", + "--data_path", type=str, default="data/magpie_llama3.2_1b_generated/data.cleaned.jsonl" ) + parser.add_argument("--training_seq_len", type=str, default=1024) + parser.add_argument("--eagle_config_path", type=str, default="eagle_config.json") parser.add_argument( "--lazy_preprocess", type=bool, default=True, help="Whether to use lazy preprocessing." ) - parser.add_argument( - "--out_path", type=str, default="ckpts/fast-trained", help="Path to save the model." - ) - parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.") + parser.add_argument("--out_path", type=str, default="ckpts/fast-trained") + parser.add_argument("--lr", type=float, default=1e-5) + parser.add_argument("--epoch", type=int, default=1) parser.add_argument( "--batch_size", type=int, default=4, help="Total batch size across all parallel ranks." 
) - parser.add_argument("--master_port", type=str, default="12357", help="Master port.") + parser.add_argument("--master_port", type=str, default="12357") args = parser.parse_args() # TODO: add sanity check for args From 35cc9a82597c74fc43f3d655c8da9e4a66f90510 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Fri, 3 Oct 2025 01:18:34 +0000 Subject: [PATCH 8/9] polish Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/distill_trainer.py | 5 ++++- examples/speculative_decoding/train.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/speculative_decoding/distill_trainer.py b/examples/speculative_decoding/distill_trainer.py index b1b7fa55c..afac87767 100644 --- a/examples/speculative_decoding/distill_trainer.py +++ b/examples/speculative_decoding/distill_trainer.py @@ -49,7 +49,8 @@ class BaseDistillTrainer: """ - Base class for distillation trainer. Initalized and called on every rank. + Base distill trainer with basic training loop and overlapped teacher and student steps. + Initalized and called on every rank. Args: rank: rank of the current process args: arguments @@ -252,6 +253,8 @@ def train(self): class EagleTPTrainer(BaseDistillTrainer): + """A subclass of BaseDistillTrainer for online eagle training, with base model TP and student DDP.""" + def __init__(self, rank, args, tokenizer, dataloader): # Load eagle config args.eagle_config = EAGLE3_DEFAULT_CFG["config"] diff --git a/examples/speculative_decoding/train.py b/examples/speculative_decoding/train.py index b9e5f2148..20ce5e7c4 100644 --- a/examples/speculative_decoding/train.py +++ b/examples/speculative_decoding/train.py @@ -92,7 +92,7 @@ def main(): # TODO: add sanity check for args def set_ranks(args): - # TODO(hg): add "no-parallel", "MP", "FSDP". + # TODO(hg): This is for TP-DDP setting only. Add "no-parallel", "MP", "FSDP". 
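        # Student ranks are assigned first, then teacher ranks. With the defaults
        # (student_devices=[0, 1, 2, 3], teacher_devices=[4, 5]) this gives
        # world_size=6, student_ranks=[0, 1, 2, 3], teacher_ranks=[4, 5].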
args.world_size = len(args.teacher_devices) + len(args.student_devices) args.student_ranks = list(range(len(args.student_devices))) args.teacher_ranks = list( From 8d6a49b5599f7d957eb2c6551b9ae7678cc1b0fc Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Fri, 3 Oct 2025 01:26:19 +0000 Subject: [PATCH 9/9] polish Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../speculative_decoding/distill_trainer.py | 86 +++++++++---------- 1 file changed, 42 insertions(+), 44 deletions(-) diff --git a/examples/speculative_decoding/distill_trainer.py b/examples/speculative_decoding/distill_trainer.py index afac87767..d39f9d57e 100644 --- a/examples/speculative_decoding/distill_trainer.py +++ b/examples/speculative_decoding/distill_trainer.py @@ -66,13 +66,13 @@ def __init__(self, rank, args, tokenizer, dataloader): # Prepare models if rank in args.student_ranks: - self.model = self.prepare_student_model() + self.model = self._prepare_student_model() self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr) self.scheduler = get_linear_schedule_with_warmup( self.optimizer, num_warmup_steps=0, num_training_steps=117380 ) else: - self.model = self.prepare_teacher_model() + self.model = self._prepare_teacher_model() self._print_model_placement(self.model) def _print_model_placement(self, module): @@ -95,11 +95,11 @@ def distill_metadata(self): """Return a DistillMetadata that describe the distillation message received by student.""" @abstractmethod - def prepare_teacher_model(self): + def _prepare_teacher_model(self): """Return coverted teacher model with correct parallelization.""" @abstractmethod - def prepare_student_model(self): + def _prepare_student_model(self): """Return coverted student model with correct parallelization.""" @abstractmethod @@ -272,43 +272,7 @@ def current_rank_device(self): else: return self.args.teacher_devices[self.rank - len(self.args.student_ranks)] - @property - def distill_metadata(self) -> DistillMetadata: - """Description of the distillation signal received by student.""" - return { - "base_model_hidden_states": ( - torch.Size( - [ - int(self.args.batch_size / len(self.args.student_ranks)), - self.args.training_seq_len, - 2048, - ] - ), - torch.bfloat16, - ), - "aux_hidden_states": ( - torch.Size( - [ - int(self.args.batch_size / len(self.args.student_ranks)), - self.args.training_seq_len, - 2048 * 3, - ] - ), - torch.bfloat16, - ), - "base_model_logits": ( - torch.Size( - [ - int(self.args.batch_size / len(self.args.student_ranks)), - self.args.training_seq_len, - self.args.draft_vocab_size, - ] - ), - torch.bfloat16, - ), - } - - def prepare_teacher_model(self): + def _prepare_teacher_model(self): # Load model with TP among teacher ranks. 
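        # tp_plan="auto" shards the base model tensor-parallel across the mesh below;
        # building the DeviceMesh from teacher_pgroup keeps those shards on teacher ranks only.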
model = AutoModelForCausalLM.from_pretrained( self.args.model_path, @@ -324,12 +288,11 @@ def prepare_teacher_model(self): "draft_vocab_size": model.config.vocab_size, } ) - self.args.draft_vocab_size = model.config.vocab_size mtsp.convert(model, [("eagle", self.args.eagle_config)]) model.eval() return model - def prepare_student_model(self): + def _prepare_student_model(self): # Load to CPU first to avoid OOM model = AutoModelForCausalLM.from_pretrained( self.args.model_path, torch_dtype="auto", device_map="cpu" @@ -342,7 +305,6 @@ def prepare_student_model(self): "draft_vocab_size": model.config.vocab_size, } ) - self.args.draft_vocab_size = model.config.vocab_size mtsp.convert( model, [("eagle", self.args.eagle_config)], @@ -361,6 +323,42 @@ def prepare_student_model(self): ) return model + @property + def distill_metadata(self) -> DistillMetadata: + """Description of the distillation signal received by student.""" + return { + "base_model_hidden_states": ( + torch.Size( + [ + int(self.args.batch_size / len(self.args.student_ranks)), + self.args.training_seq_len, + self.args.eagle_config["eagle_architecture_config"]["hidden_size"], + ] + ), + torch.bfloat16, + ), + "aux_hidden_states": ( + torch.Size( + [ + int(self.args.batch_size / len(self.args.student_ranks)), + self.args.training_seq_len, + self.args.eagle_config["eagle_architecture_config"]["hidden_size"] * 3, + ] + ), + torch.bfloat16, + ), + "base_model_logits": ( + torch.Size( + [ + int(self.args.batch_size / len(self.args.student_ranks)), + self.args.training_seq_len, + self.args.eagle_config["eagle_architecture_config"]["draft_vocab_size"], + ] + ), + torch.bfloat16, + ), + } + def teacher_step(self, model, inputs): # Collect base model outputs. base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward(