@@ -2,7 +2,6 @@
 
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation import GenerationMixin
@@ -27,7 +26,7 @@
 
 from ....distributed.parallel_state import get_parallel_state
 from ....distributed.sequence_parallel import slice_position_embedding
-from ....ops.loss import causallm_loss_function, seqcls_token_loss_sp_aware
+from ....ops.loss import causallm_loss_function, seqcls_token_loss_function
 from ....utils import logging
 from ....utils.import_utils import is_liger_kernel_available
 from ...module_utils import GradientCheckpointingLayer
@@ -708,7 +707,6 @@ def forward(
             cache_position=cache_position,
             **kwargs,
         )
-
         hidden_states = outputs.last_hidden_state
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
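
Context for the `slice_indices` line above (not part of this change): in transformers-style causal LM heads, the slice restricts the output projection to the trailing `logits_to_keep` positions, so during generation only the logits that are actually needed get materialized. A minimal standalone sketch of the pattern, with all shapes assumed for illustration:

import torch
from torch import nn

# Assumed toy shapes: keep only the last `logits_to_keep` positions
# before the vocab projection, mirroring the slice in the hunk above.
hidden_states = torch.randn(2, 16, 64)            # [batch, seq_len, hidden]
lm_head = nn.Linear(64, 1000, bias=False)         # hidden -> vocab
logits_to_keep = 1
slice_indices = slice(-logits_to_keep, None)
logits = lm_head(hidden_states[:, slice_indices, :])  # [2, 1, 1000]
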
@@ -743,7 +741,7 @@ def __init__(self, config):
         self.num_labels = config.num_labels
         self.model = Qwen3Model(config)
         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
-        self.loss_fct = CrossEntropyLoss(ignore_index=-100, reduction="none")
+        self.loss_function = seqcls_token_loss_function
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -779,23 +777,7 @@ def forward(
         logits = self.score(hidden_states)
 
         loss = None
-        if labels is not None:
-            # labels are token-level now, shape must match logits tokens
-            if logits.dim() == 3:
-                # [B, L, C] -> [B*L, C]
-                B, L, C = logits.shape
-                logits_2d = logits.view(B * L, C)
-                labels_1d = labels.view(B * L).to(logits.device)
-            elif logits.dim() == 2:
-                # [T, C] -> [T, C]
-                logits_2d = logits
-                labels_1d = labels.view(-1).to(logits.device)
-            else:
-                raise ValueError(f"Unexpected logits shape: {logits.shape}")
-
-            ps = get_parallel_state()
-            sp_group = ps.sp_group if ps.sp_enabled else None
-            loss = seqcls_token_loss_sp_aware(logits_2d, labels_1d, self.loss_fct, sp_group)
+        loss, _ = self.loss_function(hidden_states, self.score.weight, labels)
 
         return SequenceClassifierOutputWithPast(
             loss=loss,
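
The diff does not show the body of `seqcls_token_loss_function`. Judging from the call site, `loss, _ = self.loss_function(hidden_states, self.score.weight, labels)`, it folds the score projection, the token-level cross-entropy, and the sequence-parallel handling that the deleted block did by hand into one op in `....ops.loss` (possibly fused via a Liger kernel, given the `is_liger_kernel_available` import). A plain-PyTorch sketch of that contract, with the shapes and the discarded second return value assumed:

import torch
import torch.nn.functional as F

def seqcls_token_loss_function(hidden_states, weight, labels):
    """Hypothetical reference implementation: project hidden states with
    the classifier weight, then take the mean token-level cross-entropy,
    ignoring positions labeled -100."""
    logits = hidden_states @ weight.t()           # [..., num_labels]
    num_labels = weight.size(0)
    loss = F.cross_entropy(
        logits.view(-1, num_labels).float(),      # flatten tokens: [B*L, C] or [T, C]
        labels.view(-1).to(logits.device),        # flatten labels to match
        ignore_index=-100,                        # skip padding / unlabeled tokens
    )
    return loss, logits

Passing the raw `hidden_states` and `self.score.weight` instead of precomputed logits is what makes fusion possible: a fused kernel can compute the projection and the loss without materializing the full [batch, seq_len, num_labels] logits tensor, which also sidesteps the dim-2/dim-3 reshaping the removed code needed.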