Conversation

@AAnoosheh (Contributor) commented Sep 15, 2025

What does this PR do?

Plugin feature: Updated Megatron KD plugin module

Overview: Refactors the Megatron KD plugin around a typed DistillationConfig (with YAML loading and defaults), adds pipeline-parallel support, unifies the loss classes on a single model-config-based API, and forwards arbitrary keyword arguments to per-layer loss functions via compute_kd_loss.

Usage

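A rough usage sketch, pieced together from the walkthrough below (load_distillation_config, adjust_distillation_model_for_mcore, compute_kd_loss). The config path, the student/teacher configs, the already-wrapped `model`, the data iterator, and the `loss_mask` kwarg are placeholders rather than anything defined by this PR:

from modelopt.torch.distill.plugins.megatron import (
    adjust_distillation_model_for_mcore,
    load_distillation_config,
)

# Build the typed config (YAML-backed, with defaults) and patch the wrapped model for MCore/PP.
# `student_cfg` / `teacher_cfg` are assumed to be MCore TransformerConfig objects.
distill_cfg = load_distillation_config("distill.yaml", student_cfg, teacher_cfg)
adjust_distillation_model_for_mcore(model, distill_cfg)  # `model` assumed to be a DistillationModel

for batch in data_iterator:
    output = model(**batch)  # pipeline-aware student/teacher forward
    # Extra kwargs (here, a hypothetical loss_mask) are forwarded to each per-layer loss fn.
    loss = model.compute_kd_loss(loss_mask=batch["loss_mask"])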

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Y
  • Did you write any new necessary tests?: N
  • Did you add or update any necessary documentation?: N
  • Did you update Changelog?: N

Additional Information

Summary by CodeRabbit

  • New Features

    • Introduced a typed DistillationConfig with YAML loading and sensible defaults.
    • Added pipeline-parallel distillation support, including stage-aware output handling and tensor shape adjustments.
    • Enabled flexible per-layer loss parameters via keyword arguments.
    • Structured loss reporting (kd_loss, logits_loss, intermediate_loss) for clearer monitoring.
  • Refactor

    • Unified loss APIs around a single model-config–based interface; constructors updated accordingly.
    • Standardized loss balancing to handle multiple components.
  • Documentation

    • Updated docstrings to reflect new configuration options and loss parameterization.

Signed-off-by: Asha Anoosheh <[email protected]>
@AAnoosheh self-assigned this Sep 15, 2025

copy-pr-bot bot commented Sep 15, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.



coderabbitai bot commented Sep 15, 2025

Walkthrough

The distillation API was updated to pass arbitrary kwargs to per-layer loss functions via compute_kd_loss. In the Megatron plugin, a new DistillationConfig dataclass replaces dict configs, loss classes were refactored to a model_config-based API, pipeline-parallel handling was added, and loss balancing now returns structured components.

Changes

  • KD loss API broadening: modelopt/torch/distill/distillation_model.py
    Replaced the explicit labels arg with variadic **loss_fn_kwargs, which are forwarded to the loss functions; updated docstrings; no other logic changes.
  • Megatron distillation refactor & pipeline support: modelopt/torch/distill/plugins/megatron.py
    Introduced the DistillationConfig dataclass; load_distillation_config now returns a config object; added layer-index adjustment for pipeline parallelism; refactored losses to take a model_config; updated ProjectionLayer; post_forward now returns (loss, tp_reduce, is_sequence_parallel); added pipeline-aware student/teacher input/output handling; updated adjust_distillation_model_for_mcore to consume the config; the loss balancer now returns a dict with kd_loss, logits_loss, and intermediate_loss; expanded typing/imports.
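
For the loss-API row above, a minimal sketch of how the broadened call might look; `distill_model` and `loss_mask` are illustrative names, and only skip_balancer is an existing parameter of the API:

# Arbitrary keyword args now flow through compute_kd_loss to every per-layer loss function.
balanced = distill_model.compute_kd_loss(loss_mask=loss_mask)
# With skip_balancer=True the per-layer losses come back unaggregated; otherwise the Megatron
# balancer reports kd_loss, logits_loss, and intermediate_loss for monitoring.
raw_losses = distill_model.compute_kd_loss(skip_balancer=True, loss_mask=loss_mask)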

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant User
    participant Trainer
    participant MegatronPlugin as Megatron Plugin
    participant DistillCfg as DistillationConfig
    participant DistillModel as DistillationModel
    participant Student
    participant Teacher
    participant Losses as LossFns/Balancer

    User->>Trainer: start training
    Trainer->>MegatronPlugin: load_distillation_config(path, student_cfg, teacher_cfg)
    MegatronPlugin-->>Trainer: DistillationConfig
    Trainer->>MegatronPlugin: adjust_distillation_model_for_mcore(DistillationModel, DistillationConfig)
    MegatronPlugin->>DistillModel: patch for PP, hide teacher, LM-loss bypass
    Note over DistillModel: Pipeline-aware forward hooks installed

    loop each batch
        Trainer->>DistillModel: forward(inputs)
        activate DistillModel
        DistillModel->>Student: forward(student_inputs)
        DistillModel->>Teacher: forward(teacher_inputs)
        Teacher-->>DistillModel: teacher outputs
        Student-->>DistillModel: student outputs
        DistillModel-->>Trainer: concatenated or student-only outputs (PP-aware)
        deactivate DistillModel

        Trainer->>DistillModel: compute_kd_loss(**loss_fn_kwargs)
        DistillModel->>Losses: per-layer losses(out_s, out_t, **kwargs)
        Losses-->>DistillModel: logits/intermediate losses
        DistillModel->>Losses: balance(loss_dict, scales, skip_balancer?)
        Losses-->>DistillModel: {kd_loss, logits_loss, intermediate_loss}
        DistillModel-->>Trainer: kd_loss or dict
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

A rabbit taps its keys with glee,
Configs hop from dicts to d-classes free.
Losses align, kwargs in tow,
Pipelines parade in a parallel flow.
KL whispers, cosine hums light—
Distilled moonlight, models grow bright. 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 68.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title "Update distill Megatron plugin" is concise and directly reflects the primary change set (updates and refactors to the Megatron distillation plugin, including the new DistillationConfig and loss plumbing), so it accurately signals the PR's intent to reviewers scanning history; it is neither misleading nor generic.



codecov bot commented Sep 15, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.88%. Comparing base (85b309f) to head (935e666).
⚠️ Report is 6 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #319      +/-   ##
==========================================
+ Coverage   73.87%   73.88%   +0.01%     
==========================================
  Files         172      172              
  Lines       17439    17443       +4     
==========================================
+ Hits        12883    12888       +5     
+ Misses       4556     4555       -1     


@AAnoosheh marked this pull request as ready for review September 16, 2025 09:22
@AAnoosheh requested a review from a team as a code owner September 16, 2025 09:22
@coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/distill/distillation_model.py (1)

271-276: Potential incompatibility: loss_fn must return a Tensor.

Several new Megatron losses return a tuple (loss, tp_reduce, is_sequence_parallel). Passing that through here will break balancing and masking.

Two safe options:

  • A) Keep loss fns returning Tensor only (preferred for compatibility).
  • B) Teach compute_kd_loss to accept tuples and normalize to a Tensor before reduction/balancing.

If you choose (A), see my Megatron comments to revert post_forward to return only a Tensor.
If (B), I can draft a minimal normalizer that handles TP/SP flags. Want me to?
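
For reference, one possible shape of such a normalizer (purely illustrative, not a committed design):

# Accepts either a Tensor or a (loss, tp_reduce, is_sequence_parallel) tuple and returns the
# Tensor; the TP/SP flags would still have to be honored by the caller.
def _normalize_loss_output(out):
    if isinstance(out, tuple):
        loss, *_flags = out
        return loss
    return out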

modelopt/torch/distill/plugins/megatron.py (1)

372-421: Loss balancer returns a dict and assumes original loss exists; both will break training.

Issues:

  • forward must return a scalar Tensor per DistillationLossBalancer contract.
  • Unconditional pop of student loss raises KeyError when not provided or when skip_original_loss is True.
  • Summing Tensors via Python sum with empty start causes type errors.
  • Comparing Tensors with “> 0” is ambiguous if not 0‑dim.

Apply robust, contract‑preserving changes:

-    def forward(self, loss_dict: dict[str, Tensor]) -> Tensor:
+    def forward(self, loss_dict: dict[str, Tensor]) -> Tensor:
         """Forward function.
@@
-        original_loss = loss_dict.pop(mtd.loss_balancers.STUDENT_LOSS_KEY)
-        for _key in loss_dict:
-            if _key.startswith(LogitsKLLoss.__name__):
-                logits_key = _key  # should only be one
-        logits_loss = loss_dict.pop(logits_key)
-        intermediate_loss = sum(loss_dict.values()) / max(len(loss_dict), 1)
+        # Work on a copy to avoid mutating caller state.
+        loss_dict = dict(loss_dict)
+        original_loss = loss_dict.pop(mtd.loss_balancers.STUDENT_LOSS_KEY, None)
+        # Extract logits loss
+        logits_keys = [k for k in loss_dict if k.startswith(LogitsKLLoss.__name__)]
+        if len(logits_keys) != 1:
+            raise ValueError(f"Expected exactly one logits loss, found: {logits_keys}")
+        logits_loss = loss_dict.pop(logits_keys[0])
+        # Normalize to scalars
+        def _to_scalar(t: Tensor) -> Tensor:
+            return t.mean() if t.dim() > 0 else t
+        logits_loss = _to_scalar(logits_loss)
+        interm_values = list(loss_dict.values())
+        if interm_values:
+            intermediate_loss = _to_scalar(torch.stack([_to_scalar(v) for v in interm_values]).mean())
+        else:
+            intermediate_loss = torch.zeros_like(logits_loss)
@@
-        if intermediate_loss > 0:
-            dynamic_scale = logits_loss.item() / intermediate_loss.item()
-            intermediate_loss_scaled = intermediate_loss * dynamic_scale
-        else:
-            intermediate_loss = logits_loss.new_tensor(intermediate_loss)
-            intermediate_loss_scaled = intermediate_loss
+        if intermediate_loss.detach().abs().item() > 0:
+            dynamic_scale = logits_loss.detach().item() / intermediate_loss.detach().item()
+            intermediate_loss_scaled = intermediate_loss * dynamic_scale
+        else:
+            intermediate_loss_scaled = torch.zeros_like(logits_loss)
@@
-        if self._skip_original_loss:
-            total_loss = logits_loss + intermediate_loss_scaled
-        else:
-            kd_loss = logits_loss + intermediate_loss_scaled
-            kd_loss *= original_loss.item() / kd_loss.item()
-            total_loss = original_loss + kd_loss * self._kd_loss_scale
+        if self._skip_original_loss or original_loss is None:
+            total_loss = logits_loss + intermediate_loss_scaled
+        else:
+            kd_loss = logits_loss + intermediate_loss_scaled
+            kd_loss = kd_loss * (original_loss.detach().item() / max(kd_loss.detach().item(), 1e-12))
+            total_loss = original_loss + kd_loss * self._kd_loss_scale
@@
-        out_dict = {
-            "kd_loss": total_loss,
-            "logits_loss": logits_loss,
-            "intermediate_loss": intermediate_loss,
-        }
-        return out_dict
+        # Optional: expose components for logging
+        self.last_components = {
+            "total_loss": total_loss.detach(),
+            "logits_loss": logits_loss.detach(),
+            "intermediate_loss": intermediate_loss.detach(),
+        }
+        return total_loss
🧹 Nitpick comments (4)
modelopt/torch/distill/distillation_model.py (2)

251-257: Docstring still references “reduce=True/False”.

The public API uses skip_balancer, not reduce. Update to avoid confusion.

Apply:

-        Returns:
-            If reduce is True, the scalar total loss weighted between ``student_loss`` and the distillation losses.
-            If reduce is False, a dict of student model output loss and layer-wise distillation losses.
+        Returns:
+            If ``skip_balancer`` is False, the total loss as returned by the configured loss balancer
+            (typically a scalar Tensor). If ``skip_balancer`` is True, a dict of student loss and
+            layer‑wise distillation losses.

123-127: Type hint nit: teacher_model is not a ModuleList.

Return annotation should be nn.Module.

Apply:

-    def teacher_model(self) -> nn.ModuleList:
+    def teacher_model(self) -> nn.Module:
modelopt/torch/distill/plugins/megatron.py (2)

50-69: DistillationConfig: make internal fields non‑init to prevent YAML collisions.

criterion and loss_balancer are derived; disallow user‑provided values to avoid accidental overrides.

-    criterion: Criterion | None = None
-    loss_balancer: mtd.DistillationLossBalancer | None = None
+    criterion: Criterion | None = field(default=None, init=False, repr=False)
+    loss_balancer: mtd.DistillationLossBalancer | None = field(default=None, init=False, repr=False)

553-616: Tensor shapes adjuster: guard against empty shapes.

Guard for empty recv/send shapes to avoid index errors in corner scheduling cases.

-    def adjust_tensor_shapes(
+    def adjust_tensor_shapes(
         recv_tensor_shapes: list[tuple[int, ...]], send_tensor_shapes: list[tuple[int, ...]]
     ):
+        if not recv_tensor_shapes or not send_tensor_shapes:
+            return recv_tensor_shapes, send_tensor_shapes
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 76e8ce2 and 935e666.

📒 Files selected for processing (2)
  • modelopt/torch/distill/distillation_model.py (3 hunks)
  • modelopt/torch/distill/plugins/megatron.py (10 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/distill/plugins/megatron.py (2)
modelopt/torch/distill/distillation_model.py (5)
  • loss_balancer (134-136)
  • DistillationModel (37-288)
  • hide_teacher_model (139-147)
  • teacher_model (124-126)
  • only_student_forward (171-178)
modelopt/torch/distill/loss_balancers.py (1)
  • DistillationLossBalancer (33-71)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: wait-checks / wait
🔇 Additional comments (2)
modelopt/torch/distill/plugins/megatron.py (1)

488-551: Guard set_input_tensor calls or ensure the teacher implements it.

Repo search found no def set_input_tensor. The call type(self).set_input_tensor(self.teacher_model, ...) in modelopt/torch/distill/plugins/megatron.py (_set_input_tensor) will therefore raise if the method is missing; add a hasattr guard with a fallback (e.g., call self.teacher_model.set_input_tensor(...) if present) or raise a clear error.
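
A sketch of that guard (illustrative fragment intended to live inside _set_input_tensor; `input_tensors` is a placeholder name):

# Hypothetical guard mirroring the suggestion above.
if hasattr(self.teacher_model, "set_input_tensor"):
    self.teacher_model.set_input_tensor(input_tensors)
else:
    raise AttributeError(
        "Teacher model has no set_input_tensor(); cannot route pipeline-parallel inputs."
    )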

modelopt/torch/distill/distillation_model.py (1)

242-244: Approve: kwargs passthrough OK — no call sites pass positional extras.
Scanned repo for compute_kd_loss usages; all calls either pass keyword args or no args: tests/unit/torch/distill/test_distill.py (lines 94, 111, 127, 188); examples/chained_optimizations/bert_prune_distill_quantize.py:1035; modelopt/torch/distill/plugins/huggingface.py:96; modelopt/torch/quantization/plugins/transformers_trainer.py:402.

Comment on lines +181 to 187
      def post_forward(
          self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False
      ) -> Tensor:
          """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking."""
          loss = loss.transpose(0, 1).contiguous()
-         return (loss, tp_reduce)
+         return (loss, tp_reduce, is_sequence_parallel)


⚠️ Potential issue

Breaking change: BaseLoss.post_forward now returns a tuple.

This violates the expected Loss API (forward → Tensor) and breaks DistillationModel.compute_kd_loss, which assumes a Tensor. Revert to returning a Tensor and encode TP/SP concerns elsewhere.

-    def post_forward(
-        self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False
-    ) -> Tensor:
+    def post_forward(self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False) -> Tensor:
         """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking."""
         loss = loss.transpose(0, 1).contiguous()
-        return (loss, tp_reduce, is_sequence_parallel)
+        return loss

And adjust callers (see comments below).

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-     def post_forward(
-         self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False
-     ) -> Tensor:
-         """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking."""
-         loss = loss.transpose(0, 1).contiguous()
-         return (loss, tp_reduce, is_sequence_parallel)
+     def post_forward(self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False) -> Tensor:
+         """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking."""
+         loss = loss.transpose(0, 1).contiguous()
+         return loss
🤖 Prompt for AI Agents
In modelopt/torch/distill/plugins/megatron.py around lines 181 to 187,
post_forward currently returns a tuple (loss, tp_reduce, is_sequence_parallel)
which breaks the Loss API that expects a Tensor; revert post_forward to return
only the Tensor (i.e., transpose and contiguous as before) and remove the tuple
packaging; move TP/SP flags out of this return (e.g., set attributes on the
plugin instance, provide accessor methods, or pass flags via the distillation
caller/compute_kd_loss invocation) and update any callers (notably
DistillationModel.compute_kd_loss and other code expecting a Tensor) to retrieve
TP/SP information from the new location rather than from post_forward's return
value.

Comment on lines +217 to 233
      def __init__(
          self, model_config: "TransformerConfig", projection_layer: nn.Module | None = None
      ):
          """Constructor.
          Args:
-             student_config: Student's MCore transformer config.
-             teacher_config: Teacher's MCore transformer config.
+             model_config: MCore transformer config.
              projection_layer: Module which projects student activations to teacher's hidden dim.
          """
-         super().__init__(student_config, teacher_config, projection_layer=True)
+         super().__init__(model_config, projection_layer=projection_layer)

-         if self._tensor_parallel and not self._sequence_parallel:
+         if self._config.tensor_model_parallel_size > 1:
              logger.warning(
                  "``HiddenStateCosineLoss`` only works with tensors with full hidden dim. Ensure the "
-                 "tensor inputs meet this requirement or use `--sequence_parallel` if tensor parallel is enabled."
+                 "tensor inputs meet this requirement. We recommend only applying this loss to LayerNorm outputs, "
+                 "which have full hidden dim even when TP is used."
              )

⚠️ Potential issue

Align HiddenStateCosineLoss with Tensor‑only contract.

Return a Tensor from forward; remove tuple propagation.

-        return self.post_forward(loss, is_sequence_parallel=self._config.sequence_parallel)
+        return self.post_forward(loss)

Also applies to: 255-257

🤖 Prompt for AI Agents
In modelopt/torch/distill/plugins/megatron.py around lines 217-233 (and also
apply same change at 255-257), the HiddenStateCosineLoss implementation
currently propagates a tuple output but the contract requires returning a single
Tensor from forward; update the forward method to return only the loss Tensor
(not a tuple) and remove any tuple wrapping or propagation in __init__ or helper
methods so all call sites receive a Tensor; ensure type hints and docstring
reflect Tensor return and update any downstream unpacking to accept a single
Tensor.

Comment on lines 262 to 273
      def __init__(
-         self,
-         student_config: TransformerConfig,
-         teacher_config: TransformerConfig,
-         temperature: float = 1.0,
-         reverse: bool = False,
+         self, model_config: "TransformerConfig", temperature: float = 1.0, reverse: bool = False
      ):
          """Constructor.
          Args:
-             student_config: Student's MCore transformer config.
-             teacher_config: Teacher's MCore transformer config.
+             model_config: MCore transformer config.
              temperature: Divide tensors by this value prior to calculating loss.
              reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher).
          """
-         super().__init__(student_config, teacher_config)
+         super().__init__(model_config)
          self._temperature = temperature

⚠️ Potential issue

Align LogitsKLLoss with Tensor‑only contract.

Return a Tensor from forward; remove tuple propagation.

-        return self.post_forward(loss, tp_reduce=True)
+        return self.post_forward(loss)

If TP/SP reductions are needed, handle them in the balancer (central place), not by changing loss return types.

Also applies to: 293-364

@kevalmorabia97 self-requested a review September 16, 2025 10:50
@AAnoosheh merged commit 9aedfdf into main Sep 16, 2025
22 checks passed
@AAnoosheh deleted the aanoosheh/update-megatron-kd branch September 16, 2025 10:56
yeyu-nvidia pushed a commit that referenced this pull request Sep 18, 2025
Signed-off-by: Asha Anoosheh <[email protected]>
Signed-off-by: Ye Yu <[email protected]>