[Enhance] concat micro batch inputs before emb's and lm_head's forward #1409
Conversation
Pull request overview
This PR enhances performance by concatenating micro-batch inputs before passing them through embedding layers and the language model head, rather than processing each micro-batch separately. This optimization reduces the number of forward passes through these operations.
Key changes:
- Renamed `SequenceContext.pack()` to `SequenceContext.cat()` to better reflect the concatenation operation
- Added `cat()` class methods to `BaseLossKwargs`, `BaseLossContext`, and `CELossContext` to support concatenating loss contexts
- Refactored `_micro_batch_forward()` in the MoE model to concatenate inputs before embedding and lm_head operations
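The core idea can be sketched as follows. The names and shapes below (`lm_head`, `micro_batch_hidden_states`) are illustrative stand-ins, not the actual xtuner code:

```python
import torch
import torch.nn as nn

# Illustrative stand-ins; the real xtuner modules differ.
vocab_size, hidden_size = 1000, 64
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Hidden states produced by two micro-batches of shape (1, seq_len, hidden).
micro_batch_hidden_states = [
    torch.randn(1, 8, hidden_size),
    torch.randn(1, 8, hidden_size),
]

# Before: one lm_head forward per micro-batch.
logits_per_mb = [lm_head(h) for h in micro_batch_hidden_states]

# After: concatenate along the sequence dimension and run lm_head once.
cat_hidden_states = torch.cat(micro_batch_hidden_states, dim=1)
cat_logits = lm_head(cat_hidden_states)

# The concatenated result matches the per-micro-batch results.
assert torch.allclose(torch.cat(logits_per_mb, dim=1), cat_logits)
```

Because the head acts on each token independently, running it once on the concatenated sequence produces the same logits as the per-micro-batch loop, with fewer forward calls and, in the distributed case, fewer collective operations.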
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| xtuner/v1/data_proto/sequence_context.py | Renamed method from pack to cat for clarity |
| xtuner/v1/loss/base_loss_ctx.py | Added cat() class methods to support concatenating loss contexts across micro-batches |
| xtuner/v1/rl/base/controller.py | Updated method call from pack to cat following the rename |
| xtuner/v1/model/moe/moe.py | Refactored micro-batch forward to concatenate inputs before embeddings and lm_head, reducing forward pass overhead |
```python
cat_loss_ctx = CELossContext.cat(loss_ctx_list)
loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cat_loss_ctx)  # type: ignore

# Aggregate losses (mean across micro-batches)
```
Copilot AI commented on Jan 4, 2026:
The comment says "Aggregate losses (mean across micro-batches)" but the code calls loss.sum(). If the loss returned by lm_head after concatenating all micro-batches is already a scalar (which is typical), calling .sum() on a scalar tensor is redundant and may be misleading. The comment also says "mean" but the code uses "sum". Consider clarifying whether the loss should be summed or averaged, and update either the code or the comment accordingly.
Suggested change:

```diff
- # Aggregate losses (mean across micro-batches)
+ # Aggregate loss value (using sum across micro-batches or scalar loss as returned)
```
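For context, a toy illustration of the difference the reviewer is pointing at; the numbers are hypothetical and not taken from this PR:

```python
import torch

# Hypothetical per-micro-batch scalar losses.
loss_list = [torch.tensor(2.0), torch.tensor(4.0)]

summed = torch.stack(loss_list).sum()    # tensor(6.)
averaged = torch.stack(loss_list).mean()  # tensor(3.)

# If lm_head already returns a single scalar for the concatenated batch,
# calling .sum() on that scalar is a no-op:
scalar_loss = torch.tensor(3.0)
assert scalar_loss.sum() == scalar_loss
```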
xtuner/v1/model/moe/moe.py (Outdated)
```python
# Aggregate losses (mean across micro-batches)
output["loss"] = torch.stack(loss_list).sum() if loss_list else None
loss: torch.Tensor
```
Copilot AI commented on Jan 4, 2026:
The type annotation loss: torch.Tensor on line 415 is redundant since loss was already assigned on line 412. This type annotation doesn't provide any new information and could be confusing as it appears between the assignment and usage of the variable. Consider removing this redundant annotation.
Suggested change:

```diff
- loss: torch.Tensor
```
Please modify the type annotation of LMHead.__call__; the type of loss can then be inferred as torch.Tensor automatically.
If loss_ctx is not None, the call method will return a torch.Tensor; otherwise, it will return None. How should I modify the type hint? @HAOCHENYE
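One way to express that relationship in the type system is `typing.overload`. The sketch below is not xtuner's actual LMHead; the call signature and the `Tuple[Tensor, dict]` payload type are assumptions for illustration:

```python
from typing import Tuple, overload

import torch


class CELossContext:
    """Placeholder for xtuner's real loss-context class."""


class LMHead:
    @overload
    def __call__(
        self, hidden_states: torch.Tensor, loss_ctx: CELossContext
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, dict]]: ...

    @overload
    def __call__(
        self, hidden_states: torch.Tensor, loss_ctx: None
    ) -> Tuple[None, Tuple[torch.Tensor, dict]]: ...

    def __call__(self, hidden_states, loss_ctx=None):
        # Real forward logic omitted; only the typing pattern matters here.
        raise NotImplementedError
```

With overloads like these, a type checker infers loss as torch.Tensor whenever a non-None loss_ctx is passed, which makes the explicit `loss: torch.Tensor` annotation at the call site unnecessary.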
Overview (Reviewed by Claude Code)

This PR refactors the MoE (Mixture of Experts) model's micro-batch forward pass to improve efficiency by concatenating micro-batch inputs before the embedding and lm_head forward passes.

Analysis

✅ Positive Changes

| Aspect | Assessment |
|---|---|
| Code Correctness | |
| Code Style | ✅ Good overall, minor comment language inconsistency |
| Performance | ✅ Should improve through batched operations |
| Test Coverage | ❓ Unknown - tests not visible in diff |
| Security | ✅ No concerns |
Recommendation: Approve with minor revisions - address the variable shadowing and verify numerical equivalence of the LayerNorm change.
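On the last point, an equivalence check of the kind the recommendation asks for might look like the following; it assumes the norm in question behaves like a standard per-token nn.LayerNorm, which is an assumption about the actual xtuner module:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
norm = nn.LayerNorm(64)

micro_batches = [torch.randn(1, 8, 64), torch.randn(1, 8, 64)]

# Apply the norm per micro-batch, then concatenate.
per_mb = torch.cat([norm(x) for x in micro_batches], dim=1)

# Apply the norm once to the concatenated input.
concatenated = norm(torch.cat(micro_batches, dim=1))

# LayerNorm normalizes each token independently, so both paths should match.
torch.testing.assert_close(per_mb, concatenated)
```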
xtuner/v1/model/moe/moe.py (Outdated)
| output["extra_info"] = moe_extra_info | ||
|
|
||
| # Return logits for all micro-batches | ||
| final_logits = logits |
remove the variable final_logits
final_logits is used when constructing MoEModelOutputs.
Concatenate the micro-batch inputs so that the all-gather operation is executed only once.
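To illustrate that point, here is a self-contained sketch using a single-process gloo group; the shapes, the number of micro-batches, and the placement of the all-gather are illustrative assumptions, since the exact collective inside xtuner's lm_head path is not shown in this thread:

```python
import os

import torch
import torch.distributed as dist

# Single-process group so the collective calls are runnable locally.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

world_size = dist.get_world_size()
micro_batch_hidden = [torch.randn(8, 64) for _ in range(4)]

# Per-micro-batch path: one all-gather per micro-batch (4 collectives here).
for h in micro_batch_hidden:
    gathered = [torch.empty_like(h) for _ in range(world_size)]
    dist.all_gather(gathered, h)

# Concatenated path: a single all-gather covers every micro-batch at once.
cat_hidden = torch.cat(micro_batch_hidden, dim=0)
gathered = [torch.empty_like(cat_hidden) for _ in range(world_size)]
dist.all_gather(gathered, cat_hidden)

dist.destroy_process_group()
```

With N micro-batches the per-micro-batch path issues N collectives, while the concatenated path issues one over a proportionally larger tensor, which is usually cheaper because each collective carries fixed launch and synchronization latency.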