Copy changes made to Megatron-LM #372
Conversation
Signed-off-by: Asha Anoosheh <[email protected]>
Walkthrough

Broadened intermediate layer pairing to accept an optional per-pair loss spec. Added a parser to map entries to (student, teacher, loss). Configuration loading now uses the parser and applies cosine or MSE per pair. MSE reduction changed to mean. Final loss handling passes is_sequence_parallel through post_forward.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant User as Config
    participant DC as DistillationConfig
    participant D as Distiller
    participant L as LossFactory
    participant P as PostForward
    rect rgba(230,245,255,0.6)
        Note over User,DC: Load intermediate layer entries
        User->>DC: entries: [(s,t[,loss]), ...]
        loop for each entry
            DC->>DC: parse_intermediate_entry(entry)
            DC-->>D: (student_idx, teacher_idx, loss_kind)
            D->>L: build loss_fn(loss_kind)
            L-->>D: cosine or mse loss
            D->>D: register pair with loss_fn
        end
    end
    rect rgba(240,255,230,0.6)
        Note over D,P: Forward & loss aggregation
        D->>D: compute per-pair losses<br/>(MSE uses mean reduction)
        D->>P: post_forward(total_loss, is_sequence_parallel)
        P-->>D: final loss (seq-parallel aware)
    end
```
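For orientation, here is a minimal, hypothetical sketch of the entry parsing the diagram describes. The loss class names mirror the ones referenced later in this review but are stubbed so the example runs standalone, and the layer names are placeholders, not values from this PR.

```python
# Sketch only; mirrors the behavior described in the walkthrough, not the actual
# plugin code. HiddenStateCosineLoss/MSELoss are stubs standing in for the real
# classes in modelopt/torch/distill/plugins/megatron.py.
class HiddenStateCosineLoss: ...
class MSELoss: ...

_LOSSES = {"cosine": HiddenStateCosineLoss, "mse": MSELoss}


def parse_intermediate_entry(entry):
    """Map (student, teacher) or (student, teacher, loss_name) to (student, teacher, loss_cls)."""
    if len(entry) == 3:
        student, teacher, loss_name = entry
        if loss_name not in _LOSSES:
            raise ValueError(f"Unknown intermediate loss function: {loss_name}")
        loss_cls = _LOSSES[loss_name]
    else:
        student, teacher = entry
        loss_cls = HiddenStateCosineLoss  # default stays cosine
    return student, teacher, loss_cls


# Hypothetical layer names, for illustration only.
pairs = [
    ("decoder.layers.0", "decoder.layers.0"),                       # -> cosine (default)
    ("decoder.final_layernorm", "decoder.final_layernorm", "mse"),  # -> per-pair MSE
]
print([parse_intermediate_entry(p) for p in pairs])
```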
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Tip

👮 Agentic pre-merge checks are now available in preview! Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs. Please see the documentation for more information.

Example:

```yaml
reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).
```
Actionable comments posted: 0
🧹 Nitpick comments (5)
modelopt/torch/distill/plugins/megatron.py (5)
Line 62: Tighten the type for intermediate_layer_pairs to reflect the new shape.

Use a precise union instead of a variadic tuple for better readability and tooling. This also documents the accepted loss spec values.

Apply this diff:

```diff
-    intermediate_layer_pairs: list[tuple[str, ...]] = field(default_factory=list)
+    # Each entry is (student_layer, teacher_layer) or (student_layer, teacher_layer, loss_name)
+    intermediate_layer_pairs: list[tuple[str, str] | tuple[str, str, str]] = field(default_factory=list)
```

Add the Literal type to constrain the loss names (outside this hunk):

```python
# near other imports
from typing import Literal
```

Optionally, further constrain to allowed loss specs:

```python
IntermediatePairSpec = tuple[str, str] | tuple[str, str, Literal["cosine", "mse"]]
intermediate_layer_pairs: list[IntermediatePairSpec] = field(default_factory=list)
```
Lines 72-77: Normalize and validate entries in __post_init__ (optional safety).

The current asserts catch length but not type/content. Consider normalizing entries (convert lists to tuples) and early-validating loss spec strings to fail fast on bad YAML.

Example (outside this hunk):

```python
def __post_init__(self):
    ...
    # Normalize potential list entries from YAML into tuples
    self.intermediate_layer_pairs = [tuple(p) for p in self.intermediate_layer_pairs]
```
Lines 78-93: Make the loss spec case-insensitive and fix the return type to a loss class type.

- Returning a class, not a callable function, is better typed as type[BaseLoss].
- Lowercasing the spec avoids surprising user errors.

Apply this diff:

```diff
-def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable]:
+def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, type["BaseLoss"]]:
     """Parse an intermediate entry into a student layer, teacher layer, and loss function."""
     if len(entry) == 3:
-        student_layer, teacher_layer, loss_fn_name = entry
-        if loss_fn_name == "cosine":
+        student_layer, teacher_layer, loss_fn_name = entry
+        loss_fn_name = str(loss_fn_name).lower()
+        if loss_fn_name == "cosine":
             loss_fn = HiddenStateCosineLoss
-        elif loss_fn_name == "mse":
+        elif loss_fn_name == "mse":
             loss_fn = MSELoss
         else:
             raise ValueError(f"Unknown intermediate loss function: {loss_fn_name}")
     else:
         student_layer, teacher_layer = entry
         loss_fn = HiddenStateCosineLoss  # default to cosine loss
     return student_layer, teacher_layer, loss_fn
```
Lines 124-137: Log the chosen loss per pair for traceability.

Including the loss type in the info log helps debugging and config audits.

Apply this diff:

```diff
-        if parallel_state.get_tensor_and_context_parallel_rank() == 0:
-            logger.info(
-                "Distillation: Adding intermediate loss between"
-                f" `{student_layer}` of student (hidden size {student_cfg.hidden_size}) and"
-                f" `{teacher_layer}` of teacher (hidden size {teacher_cfg.hidden_size})."
-            )
+        if parallel_state.get_tensor_and_context_parallel_rank() == 0:
+            logger.info(
+                "Distillation: Adding intermediate loss (%s) between"
+                " `%s` of student (hidden size %d) and `%s` of teacher (hidden size %d).",
+                loss_fn.__name__,
+                student_layer,
+                student_cfg.hidden_size,
+                teacher_layer,
+                teacher_cfg.hidden_size,
+            )
```
Lines 198-204: post_forward returns a tuple but is annotated as Tensor. Update the annotation.

This avoids confusion and type-checking noise.

Apply this diff:

```diff
-    def post_forward(
-        self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False
-    ) -> Tensor:
+    def post_forward(
+        self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False
+    ) -> tuple[Tensor, bool, bool]:
         """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking."""
         loss = loss.transpose(0, 1).contiguous()
         return (loss, tp_reduce, is_sequence_parallel)
```

Also consider adjusting the docstrings of the concrete losses to mention the (loss, tp_reduce, is_sequence_parallel) tuple return.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/distill/plugins/megatron.py (5 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/distill/plugins/megatron.py (1)
Lines 222-225: MSE reduction changed to mean; confirm the intended scale change.

Switching from sum to mean over the hidden dim changes the magnitude of the intermediate loss. The dynamic scaling in the balancer may offset this, but it can affect warmup dynamics.

- Please confirm this matches the Megatron-LM behavior you're aligning with.
- If previous checkpoints/configs are expected to reproduce, consider bumping kd_loss_scale or noting this in release notes.
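To make the scale concern concrete, here is a small illustration (not from the PR) of how the reduction choice rescales the per-token loss by the hidden size:

```python
# Illustration only: mean reduction over the hidden dimension equals sum / hidden_size,
# so the intermediate MSE term shrinks by roughly the hidden size (e.g. 4096x here).
import torch
import torch.nn.functional as F

hidden_size = 4096
student = torch.randn(8, hidden_size)  # hypothetical [tokens, hidden] activations
teacher = torch.randn(8, hidden_size)

per_elem = F.mse_loss(student, teacher, reduction="none")
loss_sum = per_elem.sum(dim=-1)    # previous behavior per the review (sum over hidden dim)
loss_mean = per_elem.mean(dim=-1)  # new behavior (mean over hidden dim)

print(torch.allclose(loss_sum / hidden_size, loss_mean))  # True
```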
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #372   +/-   ##
=======================================
  Coverage   73.46%   73.46%
=======================================
  Files         172      172
  Lines       17640    17640
=======================================
+ Hits        12959    12960       +1
+ Misses       4681     4680       -1
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Asha Anoosheh <[email protected]>
Signed-off-by: Asha Anoosheh <[email protected]>
Signed-off-by: Asha Anoosheh <[email protected]>
Signed-off-by: Ye Yu <[email protected]>
What does this PR do?
Minor changes
Overview: Copy changes Sharath made to Megatron-LM's distillation code.
Usage
# Add a code snippet demonstrating how to use this
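A hypothetical usage sketch (not part of this PR) of the new per-pair loss spec, assuming DistillationConfig from modelopt/torch/distill/plugins/megatron.py accepts these entries directly, its other fields have defaults, and the layer names are placeholders:

```python
# Hypothetical example: entries may be 2-tuples (default cosine loss)
# or 3-tuples naming the loss ("cosine" or "mse") for that pair.
from modelopt.torch.distill.plugins.megatron import DistillationConfig

cfg = DistillationConfig(
    intermediate_layer_pairs=[
        ("decoder.layers.0", "decoder.layers.0"),                       # cosine by default
        ("decoder.final_layernorm", "decoder.final_layernorm", "mse"),  # per-pair MSE
    ],
)
```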
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit