
Commit 528e0c6

[reranker] refactor reranker (#7334)
1 parent d31de6e commit 528e0c6

File tree: 6 files changed, +46 additions, −325 deletions


swift/llm/infer/infer_engine/pt_engine.py

Lines changed: 6 additions & 22 deletions
@@ -2,7 +2,6 @@
 import asyncio
 import hashlib
 import inspect
-import os
 import pickle
 import time
 from copy import deepcopy
@@ -21,7 +20,7 @@
 from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer, safe_snapshot_download, to_device
 from swift.plugin import Metric
 from swift.tuners import Swift
-from swift.utils import get_last_valid_indices
+from swift.utils import get_generative_reranker_logits
 from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
                         ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
                         EmbeddingResponseData, RequestConfig, random_uuid)
@@ -346,26 +345,11 @@ def _infer_forward(self, template: Template, inputs: Dict[str, Any], adapter_req
         elif template.task_type in ('reranker', 'generative_reranker'):
             if template.task_type == 'generative_reranker':
                 # Qwen3-reranker like
-                positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
-                negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
-                token_false_id = template.tokenizer.convert_tokens_to_ids(negative_token)
-                token_true_id = template.tokenizer.convert_tokens_to_ids(positive_token)
-                attention_mask = inputs.get('attention_mask')
-                if attention_mask is None:
-                    batch_scores = logits[:, -1, :]
-                else:
-                    last_valid_indices = get_last_valid_indices(attention_mask)
-                    batch_indices = torch.arange(attention_mask.shape[0], device=logits.device)
-                    batch_scores = logits[batch_indices, last_valid_indices, :]
-                true_vector = batch_scores[:, token_true_id]
-                false_vector = batch_scores[:, token_false_id]
-                batch_scores = torch.stack([false_vector, true_vector], dim=1).float()
-                batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
-                preds = batch_scores[:, 1].exp()
-            else:
-                preds = logits.float()
-                if self.reranker_use_activation:
-                    preds = F.sigmoid(preds)
+                logits = get_generative_reranker_logits(
+                    template.tokenizer, logits, attention_mask=inputs.get('attention_mask'))
+            preds = logits.float()
+            if self.reranker_use_activation:
+                preds = F.sigmoid(preds)
             preds = preds.tolist()
             logprobs = [None] * len(preds)
         else:
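
Note: the body of the new get_generative_reranker_logits helper lives in swift/utils and is not part of the hunks shown here. Below is a minimal sketch of what it plausibly does, reconstructed from the inline logic removed above; the actual signature and implementation may differ.

import os
import torch


def get_generative_reranker_logits(tokenizer, logits, attention_mask=None):
    # Token pair is configurable via env vars; defaults follow the Qwen3-reranker convention.
    positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
    negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
    token_true_id = tokenizer.convert_tokens_to_ids(positive_token)
    token_false_id = tokenizer.convert_tokens_to_ids(negative_token)

    # Select the logits at the last non-padding position of each sequence.
    if attention_mask is None:
        last_logits = logits[:, -1, :]
    else:
        # Inline stand-in for swift.utils.get_last_valid_indices.
        last_valid_indices = attention_mask.sum(dim=1) - 1
        batch_indices = torch.arange(logits.shape[0], device=logits.device)
        last_logits = logits[batch_indices, last_valid_indices, :]

    # Return the yes-vs-no margin as a single relevance logit per sample.
    return last_logits[:, token_true_id] - last_logits[:, token_false_id]

Returning a margin logit instead of a probability is what lets the caller fall through to the shared reranker path: sigmoid(pos − neg) equals softmax([neg, pos])[1], so the optional F.sigmoid(preds) reproduces the old log_softmax(...)[:, 1].exp() scores exactly.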

swift/plugin/loss.py

Lines changed: 3 additions & 189 deletions
@@ -538,71 +538,6 @@ def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, **k
     return loss


-def generative_reranker_loss(outputs,
-                             labels,
-                             loss_scale=None,
-                             num_items_in_batch=None,
-                             trainer=None,
-                             attention_mask=None,
-                             **kwargs) -> torch.Tensor:
-    """
-    Generative reranker loss function.
-
-    This loss function is designed for generative rerankers that use token probabilities
-    (e.g., "yes"/"no") to determine relevance scores. It only computes loss on the
-    last token position for specific tokens.
-
-    Args:
-        outputs: Model outputs containing logits
-        labels: Binary labels (0/1) for irrelevant/relevant pairs
-        loss_scale: Not used for generative reranker
-        num_items_in_batch: Not used for generative reranker
-        trainer: Trainer instance to access tokenizer
-
-    Returns:
-        torch.Tensor: Cross entropy loss for yes/no classification
-    """
-    if trainer is None:
-        raise ValueError('trainer is required for generative_reranker_loss to access tokenizer')
-
-    logits = outputs.logits
-    tokenizer = trainer.processing_class
-
-    # Get token IDs for positive and negative tokens
-    # Default to "yes"/"no", but can be configured via environment variables
-    positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
-    negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
-
-    try:
-        positive_token_id = tokenizer.convert_tokens_to_ids(positive_token)
-        negative_token_id = tokenizer.convert_tokens_to_ids(negative_token)
-    except Exception as e:
-        raise ValueError(f"Failed to convert tokens '{positive_token}'/'{negative_token}' to IDs. "
-                         f'Please check if these tokens exist in the tokenizer vocabulary. Error: {e}')
-
-    # Extract logits at the last valid (non-padding) token position for each sample
-    batch_size = logits.shape[0]
-    last_valid_indices = -1 if attention_mask is None else get_last_valid_indices(attention_mask)
-    batch_indices = torch.arange(batch_size, device=logits.device)
-    last_valid_logits = logits[batch_indices, last_valid_indices, :]
-
-    positive_logits = last_valid_logits[:, positive_token_id]  # [batch_size]
-    negative_logits = last_valid_logits[:, negative_token_id]  # [batch_size]
-
-    # Stack to create binary classification logits
-    # Shape: [batch_size, 2] where dim=1 represents [negative, positive]
-    binary_logits = torch.stack([negative_logits, positive_logits], dim=1)
-
-    # Convert labels to the correct device and type
-    binary_labels = labels.to(binary_logits.device).long()
-
-    # Compute cross entropy loss
-    loss_fct = CrossEntropyLoss()
-    loss = loss_fct(binary_logits, binary_labels)
-
-    return loss
-
-
 def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, **kwargs) -> torch.Tensor:
     """
     List-wise reranker loss function.
@@ -692,128 +627,6 @@ def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=
     return total_loss / num_groups


-def listwise_generative_reranker_loss(outputs,
-                                      labels,
-                                      loss_scale=None,
-                                      num_items_in_batch=None,
-                                      trainer=None,
-                                      attention_mask=None,
-                                      **kwargs) -> torch.Tensor:
-    """
-    List-wise generative reranker loss function.
-
-    This loss function combines the generative reranker approach (using token probabilities)
-    with list-wise ranking. It groups samples by query based on the pattern where each group
-    consists of 1 positive document followed by n negative documents, then uses the
-    probabilities of specific tokens (e.g., "yes"/"no") to perform ranking within each group.
-
-    Data format expected:
-    - labels: [1, 0, 0, 0, 1, 0, 0, ...] where 1 indicates positive, 0 indicates negative
-    - Each 1 is followed by its corresponding negative documents until the next 1
-
-    Environment variables for configuration:
-    - GENERATIVE_RERANKER_POSITIVE_TOKEN: Token for positive relevance (default: "yes")
-    - GENERATIVE_RERANKER_NEGATIVE_TOKEN: Token for negative relevance (default: "no")
-    - LISTWISE_RERANKER_TEMPERATURE: Temperature for softmax (default: 1.0)
-    - LISTWISE_RERANKER_MIN_GROUP_SIZE: Minimum group size to include (default: 2)
-
-    Args:
-        outputs: Model outputs containing logits [batch_size, seq_len, vocab_size]
-        labels: Binary labels (1 for positive, 0 for negative) [batch_size]
-        loss_scale: Not used for listwise generative reranker
-        num_items_in_batch: Not used for listwise generative reranker
-        trainer: Trainer instance to access tokenizer
-
-    Returns:
-        torch.Tensor: Cross entropy loss for ranking classification based on token probabilities
-    """
-    if trainer is None:
-        raise ValueError('trainer is required for listwise_generative_reranker_loss to access tokenizer')
-
-    logits = outputs.logits
-    tokenizer = trainer.processing_class
-    labels = labels.float()
-
-    # Configuration from environment variables
-    positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
-    negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
-    temperature = float(os.environ.get('LISTWISE_RERANKER_TEMPERATURE', '1.0'))
-    min_group_size = int(os.environ.get('LISTWISE_RERANKER_MIN_GROUP_SIZE', '2'))
-
-    # Get token IDs for positive and negative tokens
-    try:
-        positive_token_id = tokenizer.convert_tokens_to_ids(positive_token)
-        negative_token_id = tokenizer.convert_tokens_to_ids(negative_token)
-    except Exception as e:
-        raise ValueError(f"Failed to convert tokens '{positive_token}'/'{negative_token}' to IDs. "
-                         f'Please check if these tokens exist in the tokenizer vocabulary. Error: {e}')
-
-    # Extract logits at the last valid (non-padding) token position for each sample
-    batch_size = logits.shape[0]
-    last_valid_indices = -1 if attention_mask is None else get_last_valid_indices(attention_mask)
-    batch_indices = torch.arange(batch_size, device=logits.device)
-    last_valid_logits = logits[batch_indices, last_valid_indices, :]
-
-    positive_logits = last_valid_logits[:, positive_token_id]  # [batch_size]
-    negative_logits = last_valid_logits[:, negative_token_id]  # [batch_size]
-
-    logits = F.logsigmoid(positive_logits - negative_logits)
-
-    # Find positive sample indices to determine group boundaries
-    positive_indices = torch.nonzero(labels == 1, as_tuple=False).squeeze(-1)
-
-    if len(positive_indices) == 0:
-        # No positive samples in this batch, return zero loss
-        return torch.tensor(0.0, device=logits.device, requires_grad=True)
-
-    # Ensure positive_indices is 1D
-    if positive_indices.dim() == 0:
-        positive_indices = positive_indices.unsqueeze(0)
-
-    total_loss = 0.0
-    num_groups = 0
-
-    for i, pos_idx in enumerate(positive_indices):
-        # Determine group boundaries
-        group_start = pos_idx.item()
-
-        # Find the end of current group (start of next group or end of batch)
-        if i + 1 < len(positive_indices):
-            group_end = positive_indices[i + 1].item()
-        else:
-            group_end = len(labels)
-
-        # Extract group relevance scores and labels
-        group_scores = logits[group_start:group_end]  # [group_size]
-        group_labels = labels[group_start:group_end]  # [group_size]
-
-        # Skip groups that are too small
-        if len(group_scores) < min_group_size:
-            continue
-
-        # Verify that the first sample in the group is positive
-        if group_labels[0] != 1:
-            continue  # Skip malformed groups
-
-        group_logits = group_scores / temperature
-
-        # The positive document is always at index 0 within the group
-        target = torch.tensor(0, dtype=torch.long, device=logits.device)
-
-        # Apply cross-entropy loss: positive document should have highest relevance score
-        loss_fct = CrossEntropyLoss()
-        group_loss = loss_fct(group_logits.unsqueeze(0), target.unsqueeze(0))
-
-        total_loss += group_loss
-        num_groups += 1
-
-    if num_groups == 0:
-        return torch.tensor(0.0, device=logits.device, requires_grad=True)
-
-    # Return average loss across all groups
-    return total_loss / num_groups
-
-
 loss_mapping = {
     'cross_entropy': cross_entropy_loss_func,  # examples
     # embedding
@@ -823,9 +636,10 @@ def listwise_generative_reranker_loss(outputs,
     'infonce': infonce_loss,
     # reranker
     'reranker': reranker_loss,
-    'generative_reranker': generative_reranker_loss,
+    'generative_reranker': reranker_loss,
+    # Deprecated for compatibility; scheduled for removal
     'listwise_reranker': listwise_reranker_loss,
-    'listwise_generative_reranker': listwise_generative_reranker_loss,
+    'listwise_generative_reranker': listwise_reranker_loss,
 }
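
Why remapping 'generative_reranker' to the plain reranker_loss is sound: the mapping change implies that, by the time the loss runs, generative-reranker outputs have already been reduced to a single yes-vs-no margin logit per sample (mirroring what the pt_engine.py change above does at inference), and two-way cross-entropy over the stacked [negative, positive] token logits is mathematically identical to binary cross-entropy on that margin. An illustrative check, not taken from the commit, assuming reranker_loss is a BCE-with-logits-style pointwise loss:

import torch
import torch.nn.functional as F

pos = torch.randn(4)  # logit of the positive ("yes") token per sample
neg = torch.randn(4)  # logit of the negative ("no") token per sample
labels = torch.tensor([1, 0, 1, 0])

# What the removed generative_reranker_loss computed: 2-way CE over [neg, pos].
ce = F.cross_entropy(torch.stack([neg, pos], dim=1), labels)

# Pointwise loss on the margin logit, as a reranker_loss-style path would see it.
bce = F.binary_cross_entropy_with_logits(pos - neg, labels.float())

assert torch.allclose(ce, bce)  # identical up to floating-point error

The same argument covers the deprecated listwise alias: listwise_reranker_loss already groups and softmaxes per-sample scores, so it applies unchanged once each sample carries a single margin logit.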

swift/trainers/mixin.py

Lines changed: 1 addition & 45 deletions
@@ -1030,51 +1030,7 @@ def _compute_acc(self, outputs, labels, cu_seqlens=None, attention_mask=None) ->
                 acc_strategy=args.acc_strategy,
                 is_encoder_decoder=self.template.is_encoder_decoder,
                 cu_seqlens=cu_seqlens)
-        elif task_type == 'generative_reranker':
-            tokenizer = getattr(self, 'processing_class', None)
-            if tokenizer is None and getattr(self, 'template', None) is not None:
-                tokenizer = self.template.tokenizer
-            if tokenizer is None:
-                raise RuntimeError('tokenizer not available for generative_reranker acc')
-
-            positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
-            negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
-
-            try:
-                positive_token_id = tokenizer.convert_tokens_to_ids(positive_token)
-                negative_token_id = tokenizer.convert_tokens_to_ids(negative_token)
-            except Exception as e:
-                logger.warning(f'Failed to convert reranker tokens to ids: {e}')
-                positive_token_id = None
-                negative_token_id = None
-
-            if isinstance(positive_token_id, int) and isinstance(negative_token_id, int) \
-                    and positive_token_id >= 0 and negative_token_id >= 0:
-                # Handle right padding by finding the last valid token position
-                if attention_mask is not None:
-                    # Extract logits at the last valid (non-padding) token position for each sample
-                    batch_size = logits.shape[0]
-                    last_valid_indices = get_last_valid_indices(attention_mask)
-                    batch_indices = torch.arange(batch_size, device=logits.device)
-                    last_valid_logits = logits[batch_indices, last_valid_indices, :]
-                    positive_logits = last_valid_logits[:, positive_token_id]
-                    negative_logits = last_valid_logits[:, negative_token_id]
-                else:
-                    # Fallback to original behavior if attention_mask is not available
-                    positive_logits = logits[:, -1, positive_token_id]
-                    negative_logits = logits[:, -1, negative_token_id]
-                if args.loss_type == 'listwise_generative_reranker':
-                    logits = F.logsigmoid(positive_logits - negative_logits)
-                    preds, labels = self._get_listwise_reranker_preds(logits, labels)
-                else:
-                    preds = (positive_logits > negative_logits).long()
-                metrics = compute_acc(
-                    preds,
-                    labels.long(),
-                    acc_strategy=args.acc_strategy,
-                    is_encoder_decoder=self.template.is_encoder_decoder,
-                    cu_seqlens=cu_seqlens)
-        elif task_type == 'reranker':
+        elif task_type in {'generative_reranker', 'reranker'}:
             if logits.dim() == 2:
                 logits = logits.squeeze(-1)
             if args.loss_type == 'listwise_reranker':
0 commit comments