
Conversation

AAnoosheh
Contributor

@AAnoosheh AAnoosheh commented Sep 17, 2025

What does this PR do?

Type of change: New feature

Overview: Remove restrictions so that DistillationModel.compute_kd_loss() can be called during Megatron validation.

Usage

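A minimal sketch of computing the KD loss during evaluation, which this change enables. The validation-step structure, batch handling, and variable names are illustrative assumptions rather than Megatron-LM or ModelOpt API; only compute_kd_loss() itself comes from this PR.

import torch

def validation_step(model, batch):
    """Illustrative eval-mode step; the KD loss is now available here too."""
    model.eval()
    with torch.no_grad():
        output = model(**batch)            # teacher and student forwards both run
        kd_loss = model.compute_kd_loss()  # previously restricted to training mode
    return output, kd_loss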

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: N
  • Did you add or update any necessary documentation?: N
  • Did you update Changelog?: N

Additional Information

Summary by CodeRabbit

  • Bug Fixes

    • Ensures consistent distillation behavior across training and evaluation by always computing teacher and student outputs, improving stability with pipeline parallelism.
  • Refactor

    • Simplified distillation configuration: the adjustment helper now accepts flexible keyword arguments, reducing coupling to specific parameters and improving extensibility. This may require minor updates to custom integrations that called the previous interface.

Signed-off-by: Asha Anoosheh <[email protected]>
@AAnoosheh AAnoosheh requested a review from a team as a code owner September 17, 2025 19:54

coderabbitai bot commented Sep 17, 2025

Walkthrough

The Megatron distillation plugin’s forward path now always computes both teacher (no_grad) and student outputs, with concatenation on non-final pipeline stages. The distillation shape-adjustment helper switches to a **kwargs signature and simplified gating, forwarding config/group via **kwargs to shape utilities.
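
A minimal sketch of that control flow, assuming hypothetical attribute names (teacher_model, student_model, is_last_pipeline_stage); it approximates, rather than reproduces, the plugin's _forward:

import torch

def _forward_sketch(self, *args, **kwargs):
    # Teacher runs unconditionally (training and validation alike) under no_grad.
    with torch.no_grad():
        teacher_out = self.teacher_model(*args, **kwargs)
    # Student runs normally, with gradients when training.
    student_out = self.student_model(*args, **kwargs)

    if not self.is_last_pipeline_stage:
        # Intermediate pipeline stages ship both activations downstream.
        return torch.cat([teacher_out, student_out])
    # The final stage returns only the student output; the KD loss is computed
    # separately from the captured intermediate tensors.
    return student_out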

Changes

Cohort / File(s): Megatron distillation plugin (modelopt/torch/distill/plugins/megatron.py)
Summary of Changes:
  • _forward: removed the early non-training shortcut; always runs teacher and student; concatenates outputs except on the last pipeline stage.
  • get_tensor_shapes_adjust_fn_for_distillation: signature changed to (model, **kwargs); gating simplified (forward_only no longer consulted); **kwargs forwarded to the shape helpers (see the sketch below this list).
  • Internal calls to get_tensor_shapes(...) updated to pass config/group via **kwargs.
  • Duplicated legacy signature blocks consolidated into the new **kwargs approach.
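
A rough, self-contained sketch of the kwargs-forwarding pattern described above; the stub get_tensor_shapes, its defaults, and the list-mutation details are assumptions, with only the public name and the (model, **kwargs) signature taken from the change:

from collections.abc import Callable

import torch

def get_tensor_shapes(*, seq_length: int = 1, micro_batch_size: int = 1, hidden_size: int = 8, **_unused):
    """Stand-in for Megatron-Core's shape utility; signature and defaults are hypothetical."""
    return [(seq_length, micro_batch_size, hidden_size)]

def get_tensor_shapes_adjust_fn_for_distillation(
    model: torch.nn.Module | list[torch.nn.Module], **kwargs
) -> Callable | None:
    # The real helper returns None when no adjustment is needed (e.g. no
    # pipeline-parallel distillation); that gating is omitted from this sketch.
    def adjust_tensor_shapes(recv_tensor_shapes: list, send_tensor_shapes: list):
        # Whatever shape parameters the caller supplies (seq_length,
        # micro_batch_size, ...) are forwarded verbatim to the shape utility
        # instead of being named individually in the signature.
        teacher_shapes = get_tensor_shapes(**kwargs)
        recv_tensor_shapes += teacher_shapes
        send_tensor_shapes += teacher_shapes
        return recv_tensor_shapes, send_tensor_shapes

    return adjust_tensor_shapes

Forwarding **kwargs this way keeps the helper decoupled from the exact set of parameters Megatron-Core's scheduler passes, which is the extensibility point noted in the summary.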

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller
  participant Plugin as DistillPlugin
  participant Teacher
  participant Student
  participant Pipeline as PipelineStage

  Note over Plugin,Pipeline: Previous flow (before)
  Caller->>Plugin: _forward(inputs, training=False)
  alt not training
    Plugin->>Student: forward(inputs)
    Student-->>Plugin: student_out
    Plugin-->>Caller: student_out
  else training / other
    Plugin->>Teacher: forward(inputs) (no_grad)
    Teacher-->>Plugin: teacher_out
    Plugin->>Student: forward(inputs)
    Student-->>Plugin: student_out
    alt not last stage
      Plugin->>Pipeline: concat(teacher_out, student_out)
      Pipeline-->>Caller: combined_out
    else last stage
      Plugin-->>Caller: student_out
    end
  end
sequenceDiagram
  autonumber
  participant Caller
  participant Plugin as DistillPlugin
  participant Teacher
  participant Student
  participant Pipeline as PipelineStage

  Note over Plugin,Pipeline: New flow (after)
  Caller->>Plugin: _forward(inputs, training any)
  par Compute teacher (no_grad)
    Plugin->>Teacher: forward(inputs)
    Teacher-->>Plugin: teacher_out
  and Compute student
    Plugin->>Student: only_student_forward(inputs)
    Student-->>Plugin: student_out
  end
  alt not last pipeline stage
    Plugin->>Pipeline: concat(teacher_out, student_out)
    Pipeline-->>Caller: combined_out
  else last pipeline stage
    Plugin-->>Caller: student_out
  end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I thump my paws in tidy delight,
Teacher and student run side by side in flight—
No shortcut burrows, straight through the light.
Shapes pass softly through kwargs dew,
Pipelines braid their outputs anew.
Hop, hop—distill the night to chewier chew! 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 50.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title "Allow KD loss in val mode for MLM plugin" directly summarizes the primary change (enabling knowledge-distillation loss during validation) and names the affected component (MLM plugin). It aligns with the PR objectives and the changes in modelopt/torch/distill/plugins/megatron.py that remove the training-only shortcut and ensure teacher and student outputs are produced during validation. The title is concise, specific, and not vague or noisy.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch aanoosheh/enhance-megatron-kd

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Up to 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).


@AAnoosheh AAnoosheh self-assigned this Sep 17, 2025

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/torch/distill/plugins/megatron.py (1)

578-582: Whitelist kwargs passed to get_tensor_shapes; avoid accidental TypeError.

Passing raw **kwargs through can leak unexpected keys to get_tensor_shapes(...). Filter to the known set before forwarding. Also, both teacher shape calls are identical; compute once and reuse.

Example fix (insert before the calls):

allowed = {"seq_length", "micro_batch_size", "decoder_seq_length", "forward_only"}
_shape_kwargs = {k: v for k, v in kwargs.items() if k in allowed}

Then update the calls:

teacher_shapes = get_tensor_shapes(
    config=teacher_config, tp_group=tp_group, cp_group=cp_group, **_shape_kwargs
)
teacher_recv_tensor_shapes = teacher_shapes
teacher_send_tensor_shapes = teacher_shapes
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 682bf6d and b7d2576.

📒 Files selected for processing (1)
  • modelopt/torch/distill/plugins/megatron.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/distill/plugins/megatron.py (1)

558-561: Gating change LGTM; confirm assumptions for non‑interleaved PP.

Dropping forward_only from the gate matches the goal (enable KD in validation). Please confirm this is only exercised for non‑interleaved PP and that eval runs with PP>1 don’t hit unexpected shape paths with VPP configured.

Comment on lines +550 to 552
model: torch.nn.Module | list[torch.nn.Module], **kwargs
) -> Callable | None:
"""Return the function to adjust tensor shapes for Distillation in Megatron-Core's forward pass.

⚠️ Potential issue

Avoid breaking API: keep positional args (backward‑compat) and funnel into kwargs.

Switching to only **kwargs will break any existing positional call sites with TypeError: too many positional arguments. Keep the old parameters (defaulted/optional), merge them into kwargs, and deprecate later.

Apply this diff to the signature:

-def get_tensor_shapes_adjust_fn_for_distillation(
-    model: torch.nn.Module | list[torch.nn.Module], **kwargs
-) -> Callable | None:
+def get_tensor_shapes_adjust_fn_for_distillation(
+    model: torch.nn.Module | list[torch.nn.Module],
+    seq_length: int | None = None,
+    micro_batch_size: int | None = None,
+    decoder_seq_length: int | None = None,
+    forward_only: bool | None = None,
+    **kwargs,
+) -> Callable | None:

Add this merge shim at the top of the function body:

# Back‑compat: funnel explicit args into kwargs if provided.
if seq_length is not None:
    kwargs.setdefault("seq_length", seq_length)
if micro_batch_size is not None:
    kwargs.setdefault("micro_batch_size", micro_batch_size)
if decoder_seq_length is not None:
    kwargs.setdefault("decoder_seq_length", decoder_seq_length)
if forward_only is not None:
    kwargs.setdefault("forward_only", forward_only)
🤖 Prompt for AI Agents
In modelopt/torch/distill/plugins/megatron.py around lines 550-552, the function
signature was changed to only accept **kwargs which breaks backward
compatibility for callers using positional parameters; restore the original
explicit parameters (seq_length, micro_batch_size, decoder_seq_length,
forward_only) as optional/defaulted parameters in the signature, and at the top
of the function body add a back-compat shim that funnels any provided explicit
args into kwargs (using kwargs.setdefault) so existing positional call sites
continue to work; mark these explicit params as deprecated in a comment for
future removal.

@AAnoosheh AAnoosheh enabled auto-merge (squash) September 17, 2025 20:02

codecov bot commented Sep 17, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.82%. Comparing base (d406aa1) to head (b7d2576).
⚠️ Report is 8 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #331      +/-   ##
==========================================
- Coverage   73.88%   73.82%   -0.06%     
==========================================
  Files         172      172              
  Lines       17444    17438       -6     
==========================================
- Hits        12888    12874      -14     
- Misses       4556     4564       +8     

☔ View full report in Codecov by Sentry.

@AAnoosheh AAnoosheh merged commit 5db7169 into main Sep 17, 2025
22 of 23 checks passed
@AAnoosheh AAnoosheh deleted the aanoosheh/enhance-megatron-kd branch September 17, 2025 21:52
yeyu-nvidia pushed a commit that referenced this pull request Sep 18, 2025