From 281357d72866525ebc9715cc556ebce34b8faf92 Mon Sep 17 00:00:00 2001
From: Asha Anoosheh
Date: Thu, 25 Sep 2025 05:53:37 -0700
Subject: [PATCH] Copy changes made to Megatron-LM

Signed-off-by: Asha Anoosheh
---
 modelopt/torch/distill/plugins/megatron.py | 29 +++++++++++++++++-----
 1 file changed, 23 insertions(+), 6 deletions(-)

diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py
index 7078cca36..c1fa45f6b 100644
--- a/modelopt/torch/distill/plugins/megatron.py
+++ b/modelopt/torch/distill/plugins/megatron.py
@@ -59,7 +59,7 @@ class DistillationConfig:
         logit_kl_temperature: Temperature for the logit KL-divergence loss.
     """
 
-    intermediate_layer_pairs: list[tuple[str, str]] = field(default_factory=list)
+    intermediate_layer_pairs: list[tuple[str, ...]] = field(default_factory=list)
     logit_layers: tuple[str, str] = ("output_layer", "output_layer")
     skip_lm_loss: bool = True
     kd_loss_scale: float = 1.0
@@ -69,12 +69,28 @@ class DistillationConfig:
 
     def __post_init__(self):
         assert len(self.logit_layers) == 2, f"{self.logit_layers=}"
-        assert all(len(pair) == 2 for pair in self.intermediate_layer_pairs), (
+        assert all(len(pair) in (2, 3) for pair in self.intermediate_layer_pairs), (
             f"{self.intermediate_layer_pairs=}"
         )
         assert self.kd_loss_scale > 0, f"{self.kd_loss_scale=}"
         assert self.logit_kl_temperature > 0, f"{self.logit_kl_temperature=}"
 
+    @staticmethod
+    def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable]:
+        """Parse an intermediate entry into a student layer, teacher layer, and loss function."""
+        if len(entry) == 3:
+            student_layer, teacher_layer, loss_fn_name = entry
+            if loss_fn_name == "cosine":
+                loss_fn = HiddenStateCosineLoss
+            elif loss_fn_name == "mse":
+                loss_fn = MSELoss
+            else:
+                raise ValueError(f"Unknown intermediate loss function: {loss_fn_name}")
+        else:
+            student_layer, teacher_layer = entry
+            loss_fn = HiddenStateCosineLoss  # default to cosine loss
+        return student_layer, teacher_layer, loss_fn
+
 
 def load_distillation_config(
     config_path: str | None, student_cfg: "TransformerConfig", teacher_cfg: "TransformerConfig"
@@ -105,7 +121,8 @@ def load_distillation_config(
     # NOTE: Projection layer shared among intermediate layer pairs.
     projection_layer = ProjectionLayer(student_cfg, teacher_cfg)
 
-    for student_layer, teacher_layer in cfg.intermediate_layer_pairs:
+    for entry in cfg.intermediate_layer_pairs:
+        student_layer, teacher_layer, loss_fn = cfg.parse_intermediate_entry(entry)
         if parallel_state.get_tensor_and_context_parallel_rank() == 0:
             logger.info(
                 "Distillation: Adding intermediate loss between"
@@ -114,7 +131,7 @@
             )
         student_layer = _adjust_layer_index_for_pp(student_layer, student_cfg)
         teacher_layer = _adjust_layer_index_for_pp(teacher_layer, teacher_cfg)
-        criterion[(student_layer, teacher_layer)] = HiddenStateCosineLoss(
+        criterion[(student_layer, teacher_layer)] = loss_fn(
             student_cfg, projection_layer=projection_layer
         )
 
@@ -202,9 +219,9 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
         predictions, targets = self.pre_forward(predictions, targets)
 
         loss = F.mse_loss(predictions, targets, reduction="none")
-        loss = loss.sum(dim=-1)
+        loss = loss.mean(dim=-1)
 
-        return self.post_forward(loss)
+        return self.post_forward(loss, is_sequence_parallel=self._config.sequence_parallel)
 
 
 class HiddenStateCosineLoss(BaseLoss):
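
Reviewer note (not part of the diff): the minimal Python sketch below illustrates the two entry shapes now accepted by `intermediate_layer_pairs` and how the new `parse_intermediate_entry` helper resolves the optional third element into a loss class. The stub loss classes and the example layer names are hypothetical placeholders, since the real `HiddenStateCosineLoss` and `MSELoss` in `modelopt/torch/distill/plugins/megatron.py` require the Megatron-Core dependencies.

# Standalone sketch of the new entry format (stubs and layer names are
# illustrative only; they stand in for the plugin's real loss classes).
from collections.abc import Callable


class HiddenStateCosineLoss:  # stub for the plugin's cosine hidden-state loss
    pass


class MSELoss:  # stub for the plugin's MSE hidden-state loss
    pass


def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable]:
    """Mirror of DistillationConfig.parse_intermediate_entry from this patch."""
    if len(entry) == 3:
        # New form: (student_layer, teacher_layer, loss_name)
        student_layer, teacher_layer, loss_fn_name = entry
        if loss_fn_name == "cosine":
            loss_fn = HiddenStateCosineLoss
        elif loss_fn_name == "mse":
            loss_fn = MSELoss
        else:
            raise ValueError(f"Unknown intermediate loss function: {loss_fn_name}")
    else:
        # Old form: (student_layer, teacher_layer)
        student_layer, teacher_layer = entry
        loss_fn = HiddenStateCosineLoss  # default to cosine loss
    return student_layer, teacher_layer, loss_fn


# Hypothetical example entries: old 2-tuples still parse (defaulting to
# cosine), and new 3-tuples select the loss by name.
example_pairs = [
    ("decoder.layers.11", "decoder.layers.23"),                     # -> cosine
    ("decoder.final_layernorm", "decoder.final_layernorm", "mse"),  # -> MSE
]
for entry in example_pairs:
    print(parse_intermediate_entry(entry))

Existing 2-tuple configs therefore continue to work unchanged and default to the cosine hidden-state loss, while a 3-tuple can opt into the MSE variant per layer pair.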