Conversation

@HIT-cwh (Collaborator) commented Jan 4, 2026

Concat the micro-batch inputs to ensure the all-gather operation is executed only once.
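For illustration only, here is a minimal PyTorch sketch of the idea; the helper name and shapes are hypothetical, not xtuner's actual API. Concatenating the micro-batch token ids along the sequence dimension means that whatever all-gather the embedding (or, symmetrically, the lm_head) triggers runs once instead of once per micro-batch:

```python
import torch

def embed_micro_batches(embed: torch.nn.Embedding, input_ids_list: list[torch.Tensor]) -> list[torch.Tensor]:
    """Hypothetical helper: one embedding lookup for all micro-batches."""
    sizes = [ids.size(1) for ids in input_ids_list]   # per-micro-batch sequence lengths
    cat_ids = torch.cat(input_ids_list, dim=1)        # (1, total_len)
    cat_emb = embed(cat_ids)                          # single forward pass -> single all-gather
    return list(torch.split(cat_emb, sizes, dim=1))   # back to per-micro-batch chunks
```

The PR applies the same pattern at the output end by concatenating hidden states before the lm_head.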

Copilot AI (Contributor) left a comment

Pull request overview

This PR enhances performance by concatenating micro-batch inputs before passing them through embedding layers and the language model head, rather than processing each micro-batch separately. This optimization reduces the number of forward passes through these operations.

Key changes:

  • Renamed SequenceContext.pack() to SequenceContext.cat() to better reflect the concatenation operation
  • Added cat() class methods to BaseLossKwargs, BaseLossContext, and CELossContext to support concatenating loss contexts (a rough sketch follows this list)
  • Refactored _micro_batch_forward() in the MoE model to concatenate inputs before embedding and lm_head operations
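
As a rough sketch of what the new cat() class methods might look like (a toy stand-in, not the actual BaseLossKwargs implementation; field names are made up), the core idea is to concatenate every tensor field along dim=1, mirroring how chunk() split them:

```python
from dataclasses import dataclass, fields

import torch

@dataclass
class ToyLossKwargs:  # stand-in for BaseLossKwargs; field names are made up
    labels: torch.Tensor
    loss_weight: torch.Tensor

    @classmethod
    def cat(cls, kwargs_list: list["ToyLossKwargs"]) -> "ToyLossKwargs":
        assert len(kwargs_list) > 0, "need at least one micro-batch"
        # Concatenate each tensor field along dim=1, reversing the chunk() split.
        merged = {
            f.name: torch.cat([getattr(k, f.name) for k in kwargs_list], dim=1)
            for f in fields(cls)
        }
        return cls(**merged)
```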

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 10 comments.

| File | Description |
| --- | --- |
| xtuner/v1/data_proto/sequence_context.py | Renamed method from pack to cat for clarity |
| xtuner/v1/loss/base_loss_ctx.py | Added cat() class methods to support concatenating loss contexts across micro-batches |
| xtuner/v1/rl/base/controller.py | Updated method call from pack to cat following the rename |
| xtuner/v1/model/moe/moe.py | Refactored micro-batch forward to concatenate inputs before embeddings and lm_head, reducing forward pass overhead |


cat_loss_ctx = CELossContext.cat(loss_ctx_list)
loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cat_loss_ctx) # type: ignore

# Aggregate losses (mean across micro-batches)
Copilot AI commented Jan 4, 2026

The comment says "Aggregate losses (mean across micro-batches)" but the code calls loss.sum(). If the loss returned by lm_head after concatenating all micro-batches is already a scalar (which is typical), calling .sum() on a scalar tensor is redundant and may be misleading. The comment also says "mean" but the code uses "sum". Consider clarifying whether the loss should be summed or averaged, and update either the code or the comment accordingly.

Suggested change
# Aggregate losses (mean across micro-batches)
# Aggregate loss value (using sum across micro-batches or scalar loss as returned)
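
For context, a tiny illustration with made-up numbers of why the sum-vs-mean distinction can matter when micro-batches contain different numbers of tokens:

```python
import torch

# Per-micro-batch mean losses and token counts (hypothetical values).
losses = torch.tensor([2.0, 1.0])
tokens = torch.tensor([100.0, 300.0])

summed = losses.sum()                                         # 3.0
token_weighted_mean = (losses * tokens).sum() / tokens.sum()  # 1.25
print(summed.item(), token_weighted_mean.item())
```

Whether the concatenated lm_head call already performs the right token-level reduction depends on how CELossContext reduces the loss, which is what the comment asks to clarify.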


# Aggregate losses (mean across micro-batches)
output["loss"] = torch.stack(loss_list).sum() if loss_list else None
loss: torch.Tensor
Copilot AI commented Jan 4, 2026

The type annotation loss: torch.Tensor on line 415 is redundant since loss was already assigned on line 412. This type annotation doesn't provide any new information and could be confusing as it appears between the assignment and usage of the variable. Consider removing this redundant annotation.

Suggested change
loss: torch.Tensor

Collaborator

Please modify the type annotation of LMHead.__call__ so that the type of loss can be inferred as torch.Tensor automatically.

Collaborator (Author)

If loss_ctx is not None, the call method will return a torch.Tensor; otherwise, it will return None. How should I modify the type hint? @HAOCHENYE
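
One common way to express that relationship to a type checker is typing.overload. A minimal sketch, with hypothetical signatures that only mimic the (loss, (logits, extra_info)) return structure:

```python
from typing import Optional, overload

import torch

class CELossContext:  # placeholder for the real xtuner class
    ...

class LMHead:
    @overload
    def __call__(self, hidden_states: torch.Tensor, loss_ctx: CELossContext) -> tuple[torch.Tensor, tuple]: ...
    @overload
    def __call__(self, hidden_states: torch.Tensor, loss_ctx: None = ...) -> tuple[None, tuple]: ...
    def __call__(self, hidden_states: torch.Tensor, loss_ctx: Optional[CELossContext] = None) -> tuple[Optional[torch.Tensor], tuple]:
        # The real implementation lives in xtuner; this stub only shows the typing.
        raise NotImplementedError
```

With overloads like these, loss is inferred as torch.Tensor at call sites that pass a loss context and as None otherwise.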

@HAOCHENYE (Collaborator) commented Jan 6, 2026

Overview (Reviewed by Claude Code)

This PR refactors the MoE (Mixture of Experts) model's micro-batch forward pass to improve efficiency by:

  1. Renaming pack() → cat() for SequenceContext (more accurate naming)
  2. Adding a cat() method to BaseLossKwargs and BaseLossContext classes
  3. Major optimization: Consolidating tensor operations in _micro_batch_forward() to reduce redundant computations

Analysis

✅ Positive Changes

  1. Method Naming Improvement (sequence_context.py)
    - Renaming pack → cat is more semantically accurate since the method concatenates tensors using torch.cat
  2. New cat() Methods (base_loss_ctx.py)
    - Well-structured implementation that mirrors the existing chunk() method
    - Correctly uses dim=1 to reverse the chunk operation (see the round-trip check after this list)
    - Proper assertions for input validation
  3. Performance Optimization (moe.py)
    - The old code computed embeddings per micro-batch in a loop, then concatenated later
    - New code concatenates input tensors first, then does a single embedding lookup
    - This reduces kernel launch overhead and improves GPU utilization
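
A quick way to see the dim=1 reversal mentioned in item 2 (toy shapes, not the real tensors):

```python
import torch

x = torch.randn(1, 12, 8)                        # (batch, seq, hidden), made-up shape
chunks = list(torch.chunk(x, 3, dim=1))          # what a chunk() split hands to micro-batches
assert torch.equal(torch.cat(chunks, dim=1), x)  # cat() along dim=1 restores the original
```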

⚠️ Potential Issues & Suggestions

  1. Variable Shadowing (moe.py:419)
     loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cat_loss_ctx)
     ...
     loss: torch.Tensor  # This type annotation shadows the variable
     output["loss"] = loss.sum()
    - The loss: torch.Tensor annotation is redundant and potentially confusing
    - Suggestion: Remove the type annotation line
  2. Unused Variable (moe.py:427-428)
     final_logits = logits
    - final_logits is assigned but the logic for handling it was removed. Need to verify it's still used downstream.
  3. Logic Change Risk (moe.py:406-408)
     cat_hidden_states = torch.cat(hidden_states_list, dim=1)
     cat_hidden_states = self.norm(cat_hidden_states)
    - This concatenates hidden_states_list again after it was already modified in the MoE branch
    - The old code normalized per micro-batch; the new code normalizes the concatenated tensor
    - Question: Are these mathematically equivalent? LayerNorm across longer sequences vs. per micro-batch may differ slightly (a quick numerical check follows this list)
  4. Missing Assertion (moe.py:350-351)
     if seq_ctx_list[0].input_ids is None:
    - No check that seq_ctx_list is non-empty before accessing [0]
    - Suggestion: Add an assertion or guard
  5. Comment Language Consistency (base_loss_ctx.py)
    - Comments are in Chinese (收集所有 tensor 字段名..., i.e. "collect all tensor field names...")
    - Consider using English for consistency with the rest of the codebase
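
On the question in item 3: LayerNorm and RMSNorm normalize each token over the hidden dimension only, so normalizing the concatenated sequence should be numerically equivalent to normalizing each micro-batch and then concatenating. A quick check with torch.nn.LayerNorm as a stand-in (toy shapes):

```python
import torch

torch.manual_seed(0)
norm = torch.nn.LayerNorm(64)
a, b = torch.randn(1, 5, 64), torch.randn(1, 7, 64)   # two toy micro-batches

per_micro_batch = torch.cat([norm(a), norm(b)], dim=1)
on_concatenated = norm(torch.cat([a, b], dim=1))
print(torch.allclose(per_micro_batch, on_concatenated, atol=1e-6))  # True
```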

🔍 Questions for Author

  1. Has this been benchmarked to confirm the performance improvement?
  2. Are there unit tests covering the new cat() methods?
  3. Has the numerical equivalence been verified (especially for the LayerNorm change)?

Summary

| Aspect | Assessment |
| --- | --- |
| Code Correctness | ⚠️ Minor concerns (variable shadowing, potential semantic difference in norm) |
| Code Style | ✅ Good overall, minor comment language inconsistency |
| Performance | ✅ Should improve through batched operations |
| Test Coverage | ❓ Unknown - tests not visible in diff |
| Security | ✅ No concerns |

Recommendation: Approve with minor revisions - address the variable shadowing and verify numerical equivalence of the LayerNorm change.



output["extra_info"] = moe_extra_info

# Return logits for all micro-batches
final_logits = logits
Collaborator

remove the variable final_logits

Collaborator (Author)

When constructing MoEModelOutputs, final_logits will be utilized.

@HIT-cwh HIT-cwh merged commit 2aa70dd into InternLM:main Jan 7, 2026
3 of 4 checks passed