2 changes: 1 addition & 1 deletion xtuner/v1/data_proto/sequence_context.py
@@ -199,7 +199,7 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self:
return self

@classmethod
def pack(cls, sequence_context_list: list["SequenceContext"]):
def cat(cls, sequence_context_list: list["SequenceContext"]) -> "SequenceContext":
packed_input_ids: list[torch.Tensor] = []
cu_seq_lens_q: list[torch.IntTensor] = []
cu_seq_lens_k: list[torch.IntTensor] = []
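Note: this hunk only renames SequenceContext.pack to SequenceContext.cat and adds the return annotation; the argument (a list of SequenceContext objects) is unchanged. A minimal call-site sketch, where seq_ctx_list is an illustrative variable name:

# old: packed = SequenceContext.pack(seq_ctx_list)
packed = SequenceContext.cat(seq_ctx_list)  # concatenates the contexts into one packed SequenceContext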
34 changes: 34 additions & 0 deletions xtuner/v1/loss/base_loss_ctx.py
@@ -65,6 +65,26 @@ def chunk(self, chunk_size) -> list["BaseLossKwargs"]:
chunks.append(type(self)(**chunk_dict))
return chunks

@classmethod
def cat(cls, chunks: list["BaseLossKwargs"]) -> "BaseLossKwargs":
assert len(chunks) > 0, "chunks must not be empty."

# Collect all tensor field names (based on chunk[0]'s fields; pydantic extra=forbid also requires fields to be consistent)
first = chunks[0]
tensor_field_names: list[str] = []
for field_name, field_value in first.__dict__.items():
if isinstance(field_value, torch.Tensor):
tensor_field_names.append(field_name)

assert len(tensor_field_names) > 0, "At least one field should be a tensor to cat."

cat_dict: dict[str, torch.Tensor] = {}
for field_name in tensor_field_names:
tensors = [getattr(c, field_name) for c in chunks]
cat_dict[field_name] = torch.cat(tensors, dim=1)

return cls(**cat_dict)


class BaseLossConfig(BaseModel):
model_config = ConfigDict(title="BaseLossConfig", extra="forbid", arbitrary_types_allowed=True)
@@ -79,6 +99,9 @@ def loss_ctx_cls(self) -> type["BaseLossContext"]:

LossContextInputItem = TypeVar("LossContextInputItem")

# NOTE: Self type for BaseLossContext subclasses (F-bounded polymorphism)
_BaseLossContextT = TypeVar("_BaseLossContextT", bound="BaseLossContext[Any]")


class BaseLossContext(nn.Module, ABC, Generic[LossContextInputItem]):
def __init__(self, loss_cfg: BaseLossConfig, loss_kwargs: BaseLossKwargs):
@@ -156,3 +179,14 @@ def forward(
loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD)

return loss, (logits, extra_info)

@classmethod
def cat(cls: type[_BaseLossContextT], chunks: list[_BaseLossContextT]) -> _BaseLossContextT:
assert len(chunks) > 0, "chunks must not be empty."

first = chunks[0]
loss_cfg = first.loss_cfg
loss_kwargs_chunks = [c.loss_kwargs for c in chunks]
loss_kwargs = type(first.loss_kwargs).cat(loss_kwargs_chunks)

return cls(loss_cfg, loss_kwargs)
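To illustrate how the two new classmethods are meant to be used together, here is a minimal round-trip sketch. CELossKwargs and the shifted_labels field are assumptions for illustration (not names introduced by this file), and the sketch assumes chunk() splits along the same dim=1 that cat() concatenates:

import torch

# Hypothetical BaseLossKwargs subclass with one tensor field of shape (1, seq_len).
kwargs = CELossKwargs(shifted_labels=torch.arange(8).view(1, 8))
chunks = kwargs.chunk(chunk_size=4)   # existing chunk() splits the tensor fields
merged = type(kwargs).cat(chunks)     # new cat() concatenates them back along dim=1
assert torch.equal(merged.shifted_labels, kwargs.shifted_labels)

BaseLossContext.cat applies the same idea one level up: it keeps the loss_cfg of the first chunk and delegates the tensor concatenation to the loss_kwargs class.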
74 changes: 26 additions & 48 deletions xtuner/v1/model/moe/moe.py
@@ -326,24 +326,19 @@ def _micro_batch_forward(
assert len(seq_ctx_list) == len(loss_ctx_list), "seq_ctx and loss_ctx must have same length"

# Prepare input embeddings for all micro-batches
hidden_states_list: list[torch.Tensor] = []
position_embeddings_list = []

for ctx in seq_ctx_list:
input_ids = ctx.input_ids
position_ids = ctx.position_ids

if input_ids is not None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = ctx.inputs_embeds

# create position embeddings to be shared across the decoder layers
assert position_ids is not None
position_embeddings = self.rotary_emb(hidden_states, position_ids)

hidden_states_list.append(hidden_states)
position_embeddings_list.append(position_embeddings)
if seq_ctx_list[0].input_ids is None:
cat_hidden_states = torch.cat([ctx.inputs_embeds for ctx in seq_ctx_list], dim=1) # type: ignore
else:
cat_input_ids = torch.cat([ctx.input_ids for ctx in seq_ctx_list], dim=1) # type: ignore
cat_hidden_states = self.embed_tokens(cat_input_ids)
cat_position_ids = torch.cat([ctx.position_ids for ctx in seq_ctx_list], dim=1) # type: ignore
cat_position_embeddings = self.rotary_emb(cat_hidden_states, cat_position_ids) # type: ignore
position_embeddings_list = list(
zip(
cat_position_embeddings[0].chunk(len(seq_ctx_list), dim=1),
cat_position_embeddings[1].chunk(len(seq_ctx_list), dim=1),
)
)

# Initialize output containers
output: dict = {}
@@ -353,35 +348,29 @@ def _micro_batch_forward(

# Process through layers
cat_seq_ctx: SequenceContext | None = None
cat_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None
cat_hidden_states: torch.Tensor | None = None

moe_forawrd = False
moe_forward = False
for idx, decoder_layer in self.layers.items():
layer_idx = int(idx)

if layer_idx < self.config.first_k_dense_replace:
if cat_seq_ctx is None:
cat_seq_ctx = SequenceContext.pack(seq_ctx_list)
cos = torch.cat([pe[0] for pe in position_embeddings_list], dim=1)
sin = torch.cat([pe[1] for pe in position_embeddings_list], dim=1)
cat_position_embeddings = (cos, sin)
cat_hidden_states = torch.cat(hidden_states_list, dim=1)
cat_seq_ctx = SequenceContext.cat(seq_ctx_list)
# Dense decoder layer - process concated hidden states
cat_hidden_states = decoder_layer(
cat_hidden_states,
position_embeddings=cat_position_embeddings,
seq_ctx=cat_seq_ctx,
)
else:
if cat_hidden_states is not None and not moe_forawrd:
if not moe_forward:
# TODO: `i.clone()` here is weird. However, the current Implementation of
# `async_save_on_cpu` is not friendly with `chunk` op (maybe caused by shared storage? not sure),
# resulting in nan grad norm. So we have to clone the chunked tensors here to make sure each
# hidden state has its own storage. This workaround may introduce extra memory and time cost, and
# should be optimized in the future.
hidden_states_list = [i.clone() for i in cat_hidden_states.chunk(len(seq_ctx_list), dim=1)]
moe_forawrd = True
moe_forward = True

if int(os.getenv("XTUNER_ACTIVATION_OFFLOAD", "0")) == 1:
with async_save_on_cpu(
@@ -415,23 +404,18 @@ def _micro_batch_forward(
router_weights_list[i][f"layer{idx}"] = router_weights[i]

# Apply final norm to all micro-batches
for i, hidden_states in enumerate(hidden_states_list):
hidden_states_list[i] = self.norm(hidden_states)
cat_hidden_states = torch.cat(hidden_states_list, dim=1)
cat_hidden_states = self.norm(cat_hidden_states)

# Process final outputs for each micro-batch
loss_list: list[torch.Tensor] = []
logits_list: list[torch.Tensor] = []
moe_extra_info = ModelForwardExtraLogInfo()
for hidden_states, loss_ctx_single in zip(hidden_states_list, loss_ctx_list):
loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx_single) # type: ignore
loss_list.append(loss)
if logits is not None:
logits_list.append(logits)
if extra_info:
moe_extra_info.append(extra_info)
cat_loss_ctx = CELossContext.cat(loss_ctx_list)
loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cat_loss_ctx)

# Aggregate losses (mean across micro-batches)
Review comment (Copilot AI, Jan 4, 2026): The comment says "Aggregate losses (mean across micro-batches)" but the code calls loss.sum(). If the loss returned by lm_head after concatenating all micro-batches is already a scalar (which is typical), calling .sum() on a scalar tensor is redundant and may be misleading. The comment also says "mean" but the code uses "sum". Consider clarifying whether the loss should be summed or averaged, and update either the code or the comment accordingly.

Suggested change (current comment → proposed comment):
# Aggregate losses (mean across micro-batches)
# Aggregate loss value (using sum across micro-batches or scalar loss as returned)
output["loss"] = torch.stack(loss_list).sum() if loss_list else None
output["loss"] = loss.sum()
moe_extra_info = ModelForwardExtraLogInfo()
if extra_info:
moe_extra_info.append(extra_info)
output["extra_info"] = moe_extra_info

# Handle router results for all micro-batches
@@ -476,12 +460,6 @@ def _micro_batch_forward(

del combined_router_logits

# Return logits for all micro-batches
if all(logits is not None for logits in logits_list):
final_logits = torch.cat(logits_list, dim=0) if logits_list else None
else:
final_logits = None

if self.config.return_router_results or return_router_logits:
# raise NotImplementedError

@@ -499,7 +477,7 @@ def _micro_batch_forward(

output["router_logits"] = router_logits_dict

return MoEModelOutputs(**output, logits=final_logits) # type: ignore[typeddict-item]
return MoEModelOutputs(**output, logits=logits) # type: ignore[typeddict-item]

def _forward(
self,
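In short, the refactor above concatenates the micro-batches along the sequence dimension once, runs the dense prefix layers, the final norm, and the lm_head on the single concatenated tensor, and only splits back into per-micro-batch chunks at the MoE layers. A standalone sketch of that cat/chunk pattern, with illustrative shapes and the dense-layer calls omitted:

import torch

micro_batches = [torch.randn(1, 4, 16), torch.randn(1, 4, 16)]  # (batch=1, seq, hidden)
cat_hidden_states = torch.cat(micro_batches, dim=1)  # dense layers, norm, and lm_head see one long sequence
# ... dense decoder layers process cat_hidden_states here ...
hidden_states_list = [h.clone() for h in cat_hidden_states.chunk(len(micro_batches), dim=1)]
# clone() gives each chunk its own storage; per the TODO in the diff, this avoids NaN grad norms
# when async_save_on_cpu activation offloading is enabled.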
27 changes: 23 additions & 4 deletions xtuner/v1/module/lm_head/lm_head.py
@@ -1,9 +1,10 @@
from typing import Any, Callable, TypeAlias
from typing import Any, TypeAlias

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.tensor import DTensor
from typing_extensions import overload

from xtuner.v1.loss import CELossContext

@@ -17,6 +18,16 @@


class LMHead(nn.Linear):
@overload # type: ignore[override]
def forward(
self, hidden_states: HiddenStates, loss_ctx: None = None
) -> tuple[None, tuple[Logits | None, dict[str, Any]]]: ...

@overload # type: ignore[override]
def forward(
self, hidden_states: HiddenStates, loss_ctx: CELossContext
) -> tuple[Loss, tuple[Logits | None, dict[str, Any]]]: ...

def forward( # type: ignore[override]
self, hidden_states: torch.Tensor, loss_ctx: CELossContext | None = None
) -> tuple[Loss | None, tuple[Logits | None, dict[str, Any]]]:
@@ -37,6 +48,14 @@ def forward( # type: ignore[override]
else:
return loss_ctx.forward(hidden_states, w, b)

__call__: Callable[
["LMHead", HiddenStates, CELossContext | None], tuple[Loss, tuple[Logits | None, dict[str, Any]]]
]
@overload # type: ignore
def __call__(
self, hidden_states: HiddenStates, loss_ctx: None = None
) -> tuple[None, tuple[Logits | None, dict[str, Any]]]: ...

@overload # type: ignore
def __call__(
self, hidden_states: HiddenStates, loss_ctx: CELossContext
) -> tuple[Loss, tuple[Logits | None, dict[str, Any]]]: ...

__call__ = nn.Module.__call__
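The change above replaces the single __call__ type annotation with typed @overload stubs plus __call__ = nn.Module.__call__, so static checkers infer that calling the head without a loss_ctx yields loss=None while passing a loss context yields a tensor loss, without changing runtime behavior. A simplified standalone sketch of the same pattern; Head and LossCtx are placeholder names, not the real classes, and the return types are reduced to a bare tensor:

from typing import overload

import torch
import torch.nn as nn


class LossCtx:  # placeholder standing in for CELossContext
    pass


class Head(nn.Module):
    @overload  # type: ignore
    def __call__(self, x: torch.Tensor, ctx: None = None) -> None: ...

    @overload  # type: ignore
    def __call__(self, x: torch.Tensor, ctx: LossCtx) -> torch.Tensor: ...

    __call__ = nn.Module.__call__  # runtime dispatch unchanged; the overloads only refine static types

    def forward(self, x: torch.Tensor, ctx: LossCtx | None = None) -> torch.Tensor | None:
        return None if ctx is None else x.sum()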
2 changes: 1 addition & 1 deletion xtuner/v1/rl/base/controller.py
@@ -137,7 +137,7 @@ def _packing(self, data_batches, pack_max_length, language_cfg):
)
rollout_logprobs_list.append(pad_rollout_logprobs)

seq_ctx = SequenceContext.pack(seq_ctx_list)
seq_ctx = SequenceContext.cat(seq_ctx_list)
shifted_labels = torch.cat(label_list, dim=1) # (1, max_len)
advantages = torch.tensor(advantage_list).float().unsqueeze(0) # (1, num_samples)
cu_seq_lens_q = seq_ctx.cu_seq_lens_q