
Commit bd6e36f

update loss function

1 parent e45423c
File tree

2 files changed: +42 -38 lines


veomni/models/transformers/qwen3/modeling_qwen3.py

Lines changed: 3 additions & 21 deletions
@@ -2,7 +2,6 @@
 
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation import GenerationMixin
@@ -27,7 +26,7 @@
 
 from ....distributed.parallel_state import get_parallel_state
 from ....distributed.sequence_parallel import slice_position_embedding
-from ....ops.loss import causallm_loss_function, seqcls_token_loss_sp_aware
+from ....ops.loss import causallm_loss_function, seqcls_token_loss_function
 from ....utils import logging
 from ....utils.import_utils import is_liger_kernel_available
 from ...module_utils import GradientCheckpointingLayer
@@ -708,7 +707,6 @@ def forward(
             cache_position=cache_position,
             **kwargs,
         )
-
         hidden_states = outputs.last_hidden_state
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
@@ -743,7 +741,7 @@ def __init__(self, config):
         self.num_labels = config.num_labels
         self.model = Qwen3Model(config)
         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
-        self.loss_fct = CrossEntropyLoss(ignore_index=-100, reduction="none")
+        self.loss_function = seqcls_token_loss_function
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -779,23 +777,7 @@ def forward(
         logits = self.score(hidden_states)
 
         loss = None
-        if labels is not None:
-            # labels are token-level now, shape must match logits tokens
-            if logits.dim() == 3:
-                # [B, L, C] -> [B*L, C]
-                B, L, C = logits.shape
-                logits_2d = logits.view(B * L, C)
-                labels_1d = labels.view(B * L).to(logits.device)
-            elif logits.dim() == 2:
-                # [T, C] -> [T, C]
-                logits_2d = logits
-                labels_1d = labels.view(-1).to(logits.device)
-            else:
-                raise ValueError(f"Unexpected logits shape: {logits.shape}")
-
-            ps = get_parallel_state()
-            sp_group = ps.sp_group if ps.sp_enabled else None
-            loss = seqcls_token_loss_sp_aware(logits_2d, labels_1d, self.loss_fct, sp_group)
+        loss, _ = self.loss_function(hidden_states, self.score.weight, labels)
 
         return SequenceClassifierOutputWithPast(
             loss=loss,
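
Note on the new call site: the per-token classification loss is now delegated to seqcls_token_loss_function, which takes the hidden states, the score head's weight, and the token-level labels, so the head can fuse the linear projection into the loss kernel. For intuition, the unfused math is roughly equivalent to the sketch below; the helper name is illustrative and not part of the repo, and it assumes labels use -100 for tokens that should not contribute to the loss.

    import torch
    import torch.nn.functional as F

    def seqcls_token_loss_reference(hidden_states, weight, labels, ignore_index=-100):
        # hidden_states: [B, L, H], weight: [C, H], labels: [B, L] with ignore_index on masked tokens
        logits = F.linear(hidden_states, weight).float()  # [B, L, C]
        return F.cross_entropy(
            logits.view(-1, logits.size(-1)),  # [B*L, C]
            labels.view(-1),                   # [B*L]
            ignore_index=ignore_index,         # masked tokens are excluded from the mean
        )

    # roughly mirrors: loss, _ = self.loss_function(hidden_states, self.score.weight, labels)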

veomni/ops/loss.py

Lines changed: 39 additions & 17 deletions
@@ -1,7 +1,6 @@
 from typing import Optional
 
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -93,21 +92,44 @@ def causallm_loss_function(
     return loss, logits
 
 
-def seqcls_token_loss_sp_aware(
-    logits: torch.Tensor,  # [N, C]
-    labels: torch.Tensor,  # [N]
-    loss_fct: nn.Module,
-    sp_group,
+def seqcls_token_loss_function(
+    hidden_states: torch.Tensor,
+    weight: torch.Tensor,
+    labels: torch.Tensor,
+    num_items_in_batch: Optional[int] = None,
     ignore_index: int = -100,
+    shift_labels: Optional[torch.Tensor] = None,
+    **kwargs,
 ) -> torch.Tensor:
-    # local sum loss
-    # CrossEntropyLoss(reduction="none") + mask + sum
-    per = loss_fct(logits, labels)  # [N] if reduction="none"
-    valid = labels != ignore_index
-    loss_sum = (per * valid).sum()
-    cnt = valid.sum().to(dtype=loss_sum.dtype)
-
-    if sp_group is not None:
-        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM, group=sp_group)
-        dist.all_reduce(cnt, op=dist.ReduceOp.SUM, group=sp_group)
-    return loss_sum / cnt.clamp_min(1.0)
+    # We don't use shift_labels
+    assert shift_labels is None
+
+    loss = None
+    logits = None
+
+    if labels is None:
+        logits = F.linear(hidden_states, weight)
+        return loss, logits
+
+    sp_enabled = get_parallel_state().sp_enabled
+
+    # Flatten the labels and hidden_states
+    labels = labels.view(-1)
+    hidden_states = hidden_states.view(-1, hidden_states.size(-1))
+
+    # Calculate loss
+    if fused_linear_cross_entropy is not None:  # use kernels
+        if is_seed_kernels_available():
+            loss = fused_linear_cross_entropy(hidden_states, weight, labels, ignore_index=ignore_index)
+        elif is_liger_kernel_available():
+            loss = fused_linear_cross_entropy(weight, hidden_states, labels)
+    else:
+        logits = F.linear(hidden_states, weight).float()
+        loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs)
+
+    # Reduce loss when using sp
+    if sp_enabled:
+        num_valid_tokens = (labels != ignore_index).sum()
+        loss = reduce_sequence_parallel_loss(loss, num_valid_tokens)
+
+    return loss, logits
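
Note on the sequence-parallel reduction: with sequence parallelism enabled, each rank sees only a slice of the sequence, so its loss is averaged over the locally valid tokens; reduce_sequence_parallel_loss combines these into the global token-averaged loss. Its body is not shown in this diff, so the following is only a sketch in the spirit of the removed seqcls_token_loss_sp_aware helper: the function name and the direct torch.distributed calls here are assumptions, and a real implementation would typically wrap the all-reduce in a custom autograd function to keep gradients correct.

    import torch
    import torch.distributed as dist

    def reduce_sequence_parallel_loss_sketch(local_mean_loss, local_num_valid, sp_group=None):
        # Recover the local loss sum, then all-reduce both the sum and the valid-token
        # count across the sequence-parallel group so every rank ends up with the same
        # globally token-averaged loss.
        loss_sum = local_mean_loss * local_num_valid
        count = local_num_valid.to(loss_sum.dtype).clone()
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM, group=sp_group)
        dist.all_reduce(count, op=dist.ReduceOp.SUM, group=sp_group)
        return loss_sum / count.clamp_min(1.0)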
