Conversation

@AAnoosheh AAnoosheh (Contributor) commented Sep 25, 2025

What does this PR do?

Minor changes

Overview: Copy changes Sharath made to Megatron-LM's distillation code.

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Y
  • Did you write any new necessary tests?: N
  • Did you add or update any necessary documentation?: N
  • Did you update Changelog?: N

Additional Information

Summary by CodeRabbit

  • New Features
    • Per-pair loss selection for intermediate-layer distillation via config, supporting “cosine” or “mse” and defaulting to cosine.
    • More flexible config entries are accepted and parsed into (student layer, teacher layer, loss).
  • Bug Fixes
    • Corrected MSE reduction to use averaging for consistent scaling.
    • Ensured proper behavior under sequence-parallel execution by propagating the relevant flag through the loss computation path.

@AAnoosheh AAnoosheh self-assigned this Sep 25, 2025
@AAnoosheh AAnoosheh requested a review from a team as a code owner September 25, 2025 12:55
@AAnoosheh AAnoosheh requested a review from ChenhanYu September 25, 2025 12:55

coderabbitai bot commented Sep 25, 2025

Walkthrough

Broadened intermediate-layer pairing to accept an optional per-pair loss spec. Added a parser that maps each entry to (student, teacher, loss). Configuration loading now uses the parser and applies cosine or MSE per pair. MSE reduction changed from sum to mean. Final loss handling passes is_sequence_parallel through post_forward.
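
For illustration, a minimal sketch of what the broadened entries might look like; the layer names below are hypothetical, and only the (student_layer, teacher_layer[, loss]) shape and the "cosine"/"mse" loss names come from this change:

# Hypothetical entries for DistillationConfig.intermediate_layer_pairs
intermediate_layer_pairs = [
    ("decoder.layers.0", "decoder.layers.3"),                          # no loss given -> defaults to cosine
    ("decoder.layers.5", "decoder.layers.15", "mse"),                  # explicit per-pair MSE loss
    ("decoder.final_layernorm", "decoder.final_layernorm", "cosine"),  # explicit cosine
]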

Changes

Cohort / File(s) Summary
Megatron distillation plugin
modelopt/torch/distill/plugins/megatron.py
Added DistillationConfig.parse_intermediate_entry to support per-pair loss selection ("cosine"/"mse") with default to cosine; updated config iteration to use parsed loss; replaced fixed HiddenStateCosineLoss with per-pair loss_fn; changed MSE reduction to average; propagated is_sequence_parallel into post_forward for correct sequence-parallel handling.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User as Config
  participant DC as DistillationConfig
  participant D as Distiller
  participant L as LossFactory
  participant P as PostForward

  rect rgba(230,245,255,0.6)
    Note over User,DC: Load intermediate layer entries
    User->>DC: entries: [(s,t[,loss]), ...]
    loop for each entry
      DC->>DC: parse_intermediate_entry(entry)
      DC-->>D: (student_idx, teacher_idx, loss_kind)
      D->>L: build loss_fn(loss_kind)
      L-->>D: cosine or mse loss
      D->>D: register pair with loss_fn
    end
  end

  rect rgba(240,255,230,0.6)
    Note over D,P: Forward & loss aggregation
    D->>D: compute per-pair losses<br/>(MSE uses mean reduction)
    D->>P: post_forward(total_loss, is_sequence_parallel)
    P-->>D: final loss (seq-parallel aware)
  end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I nibble configs, hop through pairs,
Cosine here, a mean MSE there.
Layers whisper teacher lore,
Student ears perk—“parse some more!”
Sequence winds in parallel streams—
I thump the ground: aligned our dreams. 🐇✨

Pre-merge checks

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 75.00%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
  • Title Check (❓ Inconclusive): The title “Copy changes made to Megatron-LM” is too generic and does not convey the specific enhancements to the distillation plugin, such as the per-pair loss configuration and parsing logic central to this changeset. Please update the title to clearly reflect the main functionality added, for example “Enable per-layer loss specification and parsing in Megatron distillation plugin”, so that reviewers immediately understand the key changes.
✅ Passed checks (1 passed)
  • Description Check (✅ Passed): Check skipped; CodeRabbit’s high-level summary is enabled.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (5)
modelopt/torch/distill/plugins/megatron.py (5)

62-62: Tighten the type for intermediate_layer_pairs to reflect the new shape.

Use a precise union instead of a variadic tuple for better readability and tooling. This also documents the accepted loss spec values.

Apply this diff:

-    intermediate_layer_pairs: list[tuple[str, ...]] = field(default_factory=list)
+    # Each entry is (student_layer, teacher_layer) or (student_layer, teacher_layer, loss_name)
+    intermediate_layer_pairs: list[tuple[str, str] | tuple[str, str, str]] = field(default_factory=list)

Add the Literal type to constrain the loss names (outside this hunk):

# near other imports
from typing import Literal

Optionally, further constrain to allowed loss specs:

IntermediatePairSpec = tuple[str, str] | tuple[str, str, Literal["cosine", "mse"]]
intermediate_layer_pairs: list[IntermediatePairSpec] = field(default_factory=list)

72-77: Normalize and validate entries in __post_init__ (optional safety).

Current asserts catch length but not type/content. Consider normalizing entries (convert lists to tuples) and early-validating loss spec strings to fail fast on bad YAML.

Example (outside this hunk):

def __post_init__(self):
    ...
    # Normalize potential list entries from YAML into tuples
    self.intermediate_layer_pairs = [tuple(p) for p in self.intermediate_layer_pairs]

78-93: Make loss spec case‑insensitive and fix return type to loss class type.

  • The function returns a class, not a plain callable, so type[BaseLoss] is the better annotation.
  • Lowercasing the spec avoids surprising user errors.

Apply this diff:

-    def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable]:
+    def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, type["BaseLoss"]]:
         """Parse an intermediate entry into a student layer, teacher layer, and loss function."""
         if len(entry) == 3:
-            student_layer, teacher_layer, loss_fn_name = entry
-            if loss_fn_name == "cosine":
+            student_layer, teacher_layer, loss_fn_name = entry
+            loss_fn_name = str(loss_fn_name).lower()
+            if loss_fn_name == "cosine":
                 loss_fn = HiddenStateCosineLoss
-            elif loss_fn_name == "mse":
+            elif loss_fn_name == "mse":
                 loss_fn = MSELoss
             else:
                 raise ValueError(f"Unknown intermediate loss function: {loss_fn_name}")
         else:
             student_layer, teacher_layer = entry
             loss_fn = HiddenStateCosineLoss  # default to cosine loss
         return student_layer, teacher_layer, loss_fn
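
A quick sketch of the parser's behavior with the suggestion applied (layer names are hypothetical; this assumes the function is callable as DistillationConfig.parse_intermediate_entry, per the file summary above):

entry = ("decoder.layers.0", "decoder.layers.3")  # two-element entry
_, _, loss_cls = DistillationConfig.parse_intermediate_entry(entry)
assert loss_cls is HiddenStateCosineLoss  # falls back to the cosine default

entry = ("decoder.layers.5", "decoder.layers.15", "MSE")  # mixed case accepted after lowercasing
_, _, loss_cls = DistillationConfig.parse_intermediate_entry(entry)
assert loss_cls is MSELoss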

124-137: Log the chosen loss per pair for traceability.

Including the loss type in the info log helps debugging/config audits.

Apply this diff:

-            if parallel_state.get_tensor_and_context_parallel_rank() == 0:
-                logger.info(
-                    "Distillation: Adding intermediate loss between"
-                    f" `{student_layer}` of student (hidden size {student_cfg.hidden_size}) and"
-                    f" `{teacher_layer}` of teacher (hidden size {teacher_cfg.hidden_size})."
-                )
+            if parallel_state.get_tensor_and_context_parallel_rank() == 0:
+                logger.info(
+                    "Distillation: Adding intermediate loss (%s) between"
+                    " `%s` of student (hidden size %d) and `%s` of teacher (hidden size %d).",
+                    loss_fn.__name__,
+                    student_layer,
+                    student_cfg.hidden_size,
+                    teacher_layer,
+                    teacher_cfg.hidden_size,
+                )

198-204: post_forward returns a tuple but is annotated as Tensor. Update the annotation.

This avoids confusion and type-checking noise.

Apply this diff:

-    def post_forward(
-        self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False
-    ) -> Tensor:
+    def post_forward(
+        self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False
+    ) -> tuple[Tensor, bool, bool]:
         """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking."""
         loss = loss.transpose(0, 1).contiguous()
         return (loss, tp_reduce, is_sequence_parallel)

Also consider adjusting the docstrings of concrete losses to mention the (loss, tp_reduce, is_sequence_parallel) tuple return.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4ff8fc9 and 281357d.

📒 Files selected for processing (1)
  • modelopt/torch/distill/plugins/megatron.py (5 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/distill/plugins/megatron.py (1)

222-225: MSE reduction changed to mean; confirm intended scale change.

Switching from sum to mean over hidden dim changes the magnitude of the intermediate loss. The dynamic scaling in the balancer may offset this, but it can affect warmup dynamics.

  • Please confirm this matches Megatron-LM behavior you’re aligning with.
  • If previous checkpoints/configs are expected to reproduce, consider bumping kd_loss_scale or noting this in release notes.
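
A rough sketch of the scale difference only; shapes and values below are illustrative and not taken from the plugin:

import torch

student = torch.randn(8, 2, 4096)  # illustrative [seq, batch, hidden] activations
teacher = torch.randn(8, 2, 4096)

per_token_sum = (student - teacher).pow(2).sum(dim=-1)    # previous reduction over hidden dim
per_token_mean = (student - teacher).pow(2).mean(dim=-1)  # new reduction over hidden dim

# Mean reduction shrinks the per-token loss by exactly hidden_size (4096 here),
# which is the factor kd_loss_scale or the loss balancer would need to absorb.
print((per_token_sum / per_token_mean).mean())  # ~4096.0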


codecov bot commented Sep 25, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.46%. Comparing base (26c203a) to head (281357d).
⚠️ Report is 2 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #372   +/-   ##
=======================================
  Coverage   73.46%   73.46%           
=======================================
  Files         172      172           
  Lines       17640    17640           
=======================================
+ Hits        12959    12960    +1     
+ Misses       4681     4680    -1     

☔ View full report in Codecov by Sentry.

@AAnoosheh AAnoosheh merged commit 0178562 into main Sep 25, 2025
27 checks passed
@AAnoosheh AAnoosheh deleted the aanoosheh/sharath-kd-additions branch September 25, 2025 14:51
kevalmorabia97 pushed a commit that referenced this pull request Sep 25, 2025
yeyu-nvidia pushed a commit that referenced this pull request Oct 1, 2025
Signed-off-by: Asha Anoosheh <[email protected]>
Signed-off-by: Ye Yu <[email protected]>