2 changes: 1 addition & 1 deletion xtuner/v1/data_proto/sequence_context.py
@@ -199,7 +199,7 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self:
return self

@classmethod
- def pack(cls, sequence_context_list: list["SequenceContext"]):
+ def cat(cls, sequence_context_list: list["SequenceContext"]):
packed_input_ids: list[torch.Tensor] = []
cu_seq_lens_q: list[torch.IntTensor] = []
cu_seq_lens_k: list[torch.IntTensor] = []
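The collapsed body of the renamed `cat` builds `packed_input_ids`, `cu_seq_lens_q`, and `cu_seq_lens_k` from the individual contexts. As a rough illustration of what concatenating packed (varlen) sequences involves, the sketch below shows how per-context cumulative sequence lengths can be merged by offsetting each context by the running token count. This is a simplified assumption about the layout, not the actual xtuner implementation.

```python
import torch

# Hypothetical helper: merge cu_seq_lens of several packed contexts.
# Each cu_seq_lens starts at 0 and ends at that context's total token count.
def merge_cu_seq_lens(cu_seq_lens_list: list[torch.IntTensor]) -> torch.IntTensor:
    merged = [torch.zeros(1, dtype=torch.int32)]
    offset = 0
    for cu in cu_seq_lens_list:
        # Drop the leading 0 of every context and apply the running offset.
        merged.append(cu[1:].to(torch.int32) + offset)
        offset += int(cu[-1])
    return torch.cat(merged)

# Example: context A packs sequences of length 3 and 5, context B packs one of length 4.
cu_a = torch.tensor([0, 3, 8], dtype=torch.int32)
cu_b = torch.tensor([0, 4], dtype=torch.int32)
print(merge_cu_seq_lens([cu_a, cu_b]))  # tensor([ 0,  3,  8, 12], dtype=torch.int32)
```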
32 changes: 32 additions & 0 deletions xtuner/v1/loss/base_loss_ctx.py
@@ -65,6 +65,27 @@ def chunk(self, chunk_size) -> list["BaseLossKwargs"]:
chunks.append(type(self)(**chunk_dict))
return chunks

@classmethod
def cat(cls, chunks: list["BaseLossKwargs"]) -> "BaseLossKwargs":
assert len(chunks) > 0, "chunks must not be empty."

# Collect all tensor field names (taking the fields of chunks[0] as the reference; pydantic extra=forbid also requires the fields to match)
first = chunks[0]
tensor_field_names: list[str] = []
for field_name, field_value in first.__dict__.items():
if isinstance(field_value, torch.Tensor):
tensor_field_names.append(field_name)

assert len(tensor_field_names) > 0, "At least one field should be a tensor to cat."

cat_dict: dict[str, torch.Tensor] = {}
for field_name in tensor_field_names:
tensors = [getattr(c, field_name) for c in chunks]
# Mirrors chunk(): concatenate back along dim=1
cat_dict[field_name] = torch.cat(tensors, dim=1)

return cls(**cat_dict)


class BaseLossConfig(BaseModel):
model_config = ConfigDict(title="BaseLossConfig", extra="forbid", arbitrary_types_allowed=True)
@@ -156,3 +177,14 @@ def forward(
loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD)

return loss, (logits, extra_info)

@classmethod
def cat(cls, chunks: list["BaseLossContext"]) -> "BaseLossContext":
assert len(chunks) > 0, "chunks must not be empty."

first = chunks[0]
loss_cfg = first.loss_cfg
loss_kwargs_chunks = [c.loss_kwargs for c in chunks]
loss_kwargs = type(first.loss_kwargs).cat(loss_kwargs_chunks)

return cls(loss_cfg, loss_kwargs)
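The new `cat()` is intended as the inverse of the existing `chunk()`: every tensor field is concatenated back along dim=1, with the field set taken from chunks[0]. A minimal round-trip sketch, assuming `BaseLossKwargs` is a pydantic model that accepts tensor fields (as the extra=forbid note in the diff suggests); the subclass and the field name `shifted_labels` are illustrative, not taken from this diff.

```python
import torch

from xtuner.v1.loss.base_loss_ctx import BaseLossKwargs

# Hypothetical subclass standing in for a concrete BaseLossKwargs,
# e.g. one carrying labels of shape (1, seq_len).
class DummyLossKwargs(BaseLossKwargs):
    shifted_labels: torch.Tensor

a = DummyLossKwargs(shifted_labels=torch.tensor([[0, 1, 2, 3]]))
b = DummyLossKwargs(shifted_labels=torch.tensor([[4, 5, 6, 7]]))

merged = DummyLossKwargs.cat([a, b])
# Tensor fields are concatenated along dim=1, mirroring how chunk() splits them.
print(merged.shifted_labels)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
```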
72 changes: 27 additions & 45 deletions xtuner/v1/model/moe/moe.py
@@ -326,24 +326,19 @@ def _micro_batch_forward(
assert len(seq_ctx_list) == len(loss_ctx_list), "seq_ctx and loss_ctx must have same length"

# Prepare input embeddings for all micro-batches
- hidden_states_list: list[torch.Tensor] = []
- position_embeddings_list = []
-
- for ctx in seq_ctx_list:
- input_ids = ctx.input_ids
- position_ids = ctx.position_ids
-
- if input_ids is not None:
- hidden_states = self.embed_tokens(input_ids)
- else:
- hidden_states = ctx.inputs_embeds
-
- # create position embeddings to be shared across the decoder layers
- assert position_ids is not None
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
- hidden_states_list.append(hidden_states)
- position_embeddings_list.append(position_embeddings)
+ if seq_ctx_list[0].input_ids is None:
+ cat_hidden_states = torch.cat([ctx.inputs_embeds for ctx in seq_ctx_list], dim=1)
+ else:
+ cat_input_ids = torch.cat([ctx.input_ids for ctx in seq_ctx_list], dim=1)
+ cat_hidden_states = self.embed_tokens(cat_input_ids)
+ cat_position_ids = torch.cat([ctx.position_ids for ctx in seq_ctx_list], dim=1)
+ cat_position_embeddings = self.rotary_emb(cat_hidden_states, cat_position_ids)
+ position_embeddings_list = list(
+ zip(
+ cat_position_embeddings[0].chunk(len(seq_ctx_list), dim=1),
+ cat_position_embeddings[1].chunk(len(seq_ctx_list), dim=1),
+ )
+ )

# Initialize output containers
output: dict = {}
@@ -353,28 +348,22 @@

# Process through layers
cat_seq_ctx: SequenceContext | None = None
- cat_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None
- cat_hidden_states: torch.Tensor | None = None

moe_forawrd = False
for idx, decoder_layer in self.layers.items():
layer_idx = int(idx)

if layer_idx < self.config.first_k_dense_replace:
if cat_seq_ctx is None:
- cat_seq_ctx = SequenceContext.pack(seq_ctx_list)
- cos = torch.cat([pe[0] for pe in position_embeddings_list], dim=1)
- sin = torch.cat([pe[1] for pe in position_embeddings_list], dim=1)
- cat_position_embeddings = (cos, sin)
- cat_hidden_states = torch.cat(hidden_states_list, dim=1)
+ cat_seq_ctx = SequenceContext.cat(seq_ctx_list)
# Dense decoder layer - process concatenated hidden states
cat_hidden_states = decoder_layer(
cat_hidden_states,
position_embeddings=cat_position_embeddings,
seq_ctx=cat_seq_ctx,
)
else:
- if cat_hidden_states is not None and not moe_forawrd:
+ if not moe_forawrd:
# TODO: `i.clone()` here is weird. However, the current Implementation of
# `async_save_on_cpu` is not friendly with `chunk` op (maybe caused by shared storage? not sure),
# resulting in nan grad norm. So we have to clone the chunked tensors here to make sure each
@@ -415,25 +404,24 @@ def _micro_batch_forward(
router_weights_list[i][f"layer{idx}"] = router_weights[i]

# Apply final norm to all micro-batches
- for i, hidden_states in enumerate(hidden_states_list):
- hidden_states_list[i] = self.norm(hidden_states)
+ cat_hidden_states = torch.cat(hidden_states_list, dim=1)
+ cat_hidden_states = self.norm(cat_hidden_states)

# Process final outputs for each micro-batch
- loss_list: list[torch.Tensor] = []
- logits_list: list[torch.Tensor] = []
- moe_extra_info = ModelForwardExtraLogInfo()
- for hidden_states, loss_ctx_single in zip(hidden_states_list, loss_ctx_list):
- loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx_single) # type: ignore
- loss_list.append(loss)
- if logits is not None:
- logits_list.append(logits)
- if extra_info:
- moe_extra_info.append(extra_info)
+ cat_loss_ctx = CELossContext.cat(loss_ctx_list)
+ loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cat_loss_ctx) # type: ignore

# Aggregate losses (mean across micro-batches)
Review comment (Copilot AI, Jan 4, 2026):
The comment says "Aggregate losses (mean across micro-batches)" but the code calls loss.sum(). If the loss returned by lm_head after concatenating all micro-batches is already a scalar (which is typical), calling .sum() on a scalar tensor is redundant and may be misleading. The comment also says "mean" but the code uses "sum". Consider clarifying whether the loss should be summed or averaged, and update either the code or the comment accordingly.

Suggested change:
- # Aggregate losses (mean across micro-batches)
+ # Aggregate loss value (using sum across micro-batches or scalar loss as returned)
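For context on this point: whether summing per-micro-batch losses matches the loss of the concatenated batch depends on the reduction. With a token-level sum reduction the two agree exactly; with a per-chunk mean they generally do not. The reduction used by xtuner's loss context is not shown in this diff, so the sketch below uses plain PyTorch cross-entropy only to illustrate the distinction.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)           # 8 tokens, 10-class vocab
labels = torch.randint(0, 10, (8,))
chunks = [(logits[:4], labels[:4]), (logits[4:], labels[4:])]

# Sum reduction: per-chunk sums add up to the loss over the full batch.
sum_full = F.cross_entropy(logits, labels, reduction="sum")
sum_chunks = sum(F.cross_entropy(lg, lb, reduction="sum") for lg, lb in chunks)
assert torch.allclose(sum_full, sum_chunks)

# Mean reduction: summing per-chunk means does not reproduce the full-batch mean
# (and averaging them only matches when all chunks have the same token count).
mean_full = F.cross_entropy(logits, labels, reduction="mean")
mean_chunks = sum(F.cross_entropy(lg, lb, reduction="mean") for lg, lb in chunks)
assert not torch.allclose(mean_full, mean_chunks)
```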
output["loss"] = torch.stack(loss_list).sum() if loss_list else None
loss: torch.Tensor
Review comment (Copilot AI, Jan 4, 2026):
The type annotation loss: torch.Tensor on line 415 is redundant since loss was already assigned on line 412. This type annotation doesn't provide any new information and could be confusing as it appears between the assignment and usage of the variable. Consider removing this redundant annotation.

Suggested change:
- loss: torch.Tensor
Review comment (Collaborator):
Please modify the type annotation of LMHead.__call__ so that the type of loss can be inferred as torch.Tensor automatically.
Review comment (Collaborator, author):
If loss_ctx is not None, the call method will return a torch.Tensor; otherwise, it will return None. How should I modify the typehint? @HAOCHENYE
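One conventional way to express "Tensor if loss_ctx is not None, otherwise None" is typing.overload. The sketch below is illustrative only: the actual LMHead signature is not shown in this diff, so the parameter names, the return tuple shape, and the CELossContext reference are assumptions.

```python
from typing import overload

import torch

class LMHeadSketch(torch.nn.Module):
    @overload
    def __call__(self, hidden_states: torch.Tensor, loss_ctx: "CELossContext") -> tuple[torch.Tensor, tuple]: ...

    @overload
    def __call__(self, hidden_states: torch.Tensor, loss_ctx: None) -> tuple[None, tuple]: ...

    def __call__(self, hidden_states, loss_ctx=None):
        # nn.Module.__call__ dispatches to forward(); the overloads above only
        # narrow the static type seen by callers.
        return super().__call__(hidden_states, loss_ctx)
```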
output["loss"] = loss.sum()
moe_extra_info = ModelForwardExtraLogInfo()
if extra_info:
moe_extra_info.append(extra_info)
output["extra_info"] = moe_extra_info

+ # Return logits for all micro-batches
+ final_logits = logits
Review comment (Collaborator):
remove the variable final_logits

Review comment (Collaborator, author):
When constructing MoEModelOutputs, final_logits will be utilized.

# Handle router results for all micro-batches
all_router_logits = []
all_router_weights = []
@@ -476,12 +464,6 @@

del combined_router_logits

- # Return logits for all micro-batches
- if all(logits is not None for logits in logits_list):
- final_logits = torch.cat(logits_list, dim=0) if logits_list else None
- else:
- final_logits = None

if self.config.return_router_results or return_router_logits:
# raise NotImplementedError

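The core idea of the moe.py change above is that token embedding and rotary position embeddings are applied position-wise, so computing them once on the concatenated micro-batches and chunking the result back is equivalent to computing them per micro-batch. A small self-contained check of that equivalence, using a plain nn.Embedding as a stand-in (the real embed_tokens and rotary_emb modules are not reproduced here):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed = nn.Embedding(100, 16)

# Two "micro-batches" of token ids, each packed to the same length (1, 4),
# mirroring how the new code chunks the concatenated result with .chunk(n, dim=1).
ids_a = torch.randint(0, 100, (1, 4))
ids_b = torch.randint(0, 100, (1, 4))

# Old path: embed each micro-batch separately.
separate = [embed(ids_a), embed(ids_b)]

# New path: concatenate along dim=1, embed once, chunk back.
cat_hidden = embed(torch.cat([ids_a, ids_b], dim=1))
chunked = cat_hidden.chunk(2, dim=1)

for s, c in zip(separate, chunked):
    assert torch.allclose(s, c)
```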
2 changes: 1 addition & 1 deletion xtuner/v1/rl/base/controller.py
@@ -137,7 +137,7 @@ def _packing(self, data_batches, pack_max_length, language_cfg):
)
rollout_logprobs_list.append(pad_rollout_logprobs)

- seq_ctx = SequenceContext.pack(seq_ctx_list)
+ seq_ctx = SequenceContext.cat(seq_ctx_list)
shifted_labels = torch.cat(label_list, dim=1) # (1, max_len)
advantages = torch.tensor(advantage_list).float().unsqueeze(0) # (1, num_samples)
cu_seq_lens_q = seq_ctx.cu_seq_lens_q