28 changes: 6 additions & 22 deletions swift/llm/infer/infer_engine/pt_engine.py
@@ -2,7 +2,6 @@
 import asyncio
 import hashlib
 import inspect
-import os
 import pickle
 import time
 from copy import deepcopy
@@ -21,7 +20,7 @@
 from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer, safe_snapshot_download, to_device
 from swift.plugin import Metric
 from swift.tuners import Swift
-from swift.utils import get_last_valid_indices
+from swift.utils import get_generative_reranker_logits
 from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
                         ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
                         EmbeddingResponseData, RequestConfig, random_uuid)
@@ -346,26 +345,11 @@ def _infer_forward(self, template: Template, inputs: Dict[str, Any], adapter_req
         elif template.task_type in ('reranker', 'generative_reranker'):
             if template.task_type == 'generative_reranker':
                 # Qwen3-reranker like
-                positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
-                negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
-                token_false_id = template.tokenizer.convert_tokens_to_ids(negative_token)
-                token_true_id = template.tokenizer.convert_tokens_to_ids(positive_token)
-                attention_mask = inputs.get('attention_mask')
-                if attention_mask is None:
-                    batch_scores = logits[:, -1, :]
-                else:
-                    last_valid_indices = get_last_valid_indices(attention_mask)
-                    batch_indices = torch.arange(attention_mask.shape[0], device=logits.device)
-                    batch_scores = logits[batch_indices, last_valid_indices, :]
-                true_vector = batch_scores[:, token_true_id]
-                false_vector = batch_scores[:, token_false_id]
-                batch_scores = torch.stack([false_vector, true_vector], dim=1).float()
-                batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
-                preds = batch_scores[:, 1].exp()
-            else:
-                preds = logits.float()
-                if self.reranker_use_activation:
-                    preds = F.sigmoid(preds)
+                logits = get_generative_reranker_logits(
+                    template.tokenizer, logits, attention_mask=inputs.get('attention_mask'))
+            preds = logits.float()
+            if self.reranker_use_activation:
+                preds = F.sigmoid(preds)
             preds = preds.tolist()
             logprobs = [None] * len(preds)
         else:
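Note: the new get_generative_reranker_logits helper is imported from swift.utils, and its implementation is not part of this diff. Below is a minimal sketch of what it presumably does, inferred from the removed inline code and the call site above; the real helper's signature and its handling of the GENERATIVE_RERANKER_POSITIVE_TOKEN / GENERATIVE_RERANKER_NEGATIVE_TOKEN environment variables may differ.

    import torch


    def get_generative_reranker_logits(tokenizer, logits, attention_mask=None,
                                       positive_token='yes', negative_token='no'):
        # Locate the last non-padding position of each sample; fall back to the
        # final position when no attention mask is provided (assumes right padding).
        if attention_mask is None:
            last_logits = logits[:, -1, :]
        else:
            last_valid_indices = attention_mask.sum(dim=1) - 1
            batch_indices = torch.arange(logits.shape[0], device=logits.device)
            last_logits = logits[batch_indices, last_valid_indices, :]
        token_true_id = tokenizer.convert_tokens_to_ids(positive_token)
        token_false_id = tokenizer.convert_tokens_to_ids(negative_token)
        # Collapse the yes/no pair into one relevance logit per sample, so callers
        # can treat generative rerankers like ordinary pointwise rerankers.
        return last_logits[:, token_true_id] - last_logits[:, token_false_id]

Returning the difference keeps the old scores intact: sigmoid(true - false) equals the softmax probability of the positive token over the stacked [no, yes] pair, so the shared F.sigmoid(preds) path above reproduces the removed batch_scores[:, 1].exp().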
192 changes: 3 additions & 189 deletions swift/plugin/loss.py
@@ -538,71 +538,6 @@ def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, **k
     return loss


-def generative_reranker_loss(outputs,
-                             labels,
-                             loss_scale=None,
-                             num_items_in_batch=None,
-                             trainer=None,
-                             attention_mask=None,
-                             **kwargs) -> torch.Tensor:
-    """
-    Generative reranker loss function.
-
-    This loss function is designed for generative rerankers that use token probabilities
-    (e.g., "yes"/"no") to determine relevance scores. It only computes loss on the
-    last token position for specific tokens.
-
-    Args:
-        outputs: Model outputs containing logits
-        labels: Binary labels (0/1) for irrelevant/relevant pairs
-        loss_scale: Not used for generative reranker
-        num_items_in_batch: Not used for generative reranker
-        trainer: Trainer instance to access tokenizer
-
-    Returns:
-        torch.Tensor: Cross entropy loss for yes/no classification
-    """
-    if trainer is None:
-        raise ValueError('trainer is required for generative_reranker_loss to access tokenizer')
-
-    logits = outputs.logits
-    tokenizer = trainer.processing_class
-
-    # Get token IDs for positive and negative tokens
-    # Default to "yes"/"no", but can be configured via environment variables
-    positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
-    negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
-
-    try:
-        positive_token_id = tokenizer.convert_tokens_to_ids(positive_token)
-        negative_token_id = tokenizer.convert_tokens_to_ids(negative_token)
-    except Exception as e:
-        raise ValueError(f"Failed to convert tokens '{positive_token}'/'{negative_token}' to IDs. "
-                         f'Please check if these tokens exist in the tokenizer vocabulary. Error: {e}')
-
-    # Extract logits at the last valid (non-padding) token position for each sample
-    batch_size = logits.shape[0]
-    last_valid_indices = -1 if attention_mask is None else get_last_valid_indices(attention_mask)
-    batch_indices = torch.arange(batch_size, device=logits.device)
-    last_valid_logits = logits[batch_indices, last_valid_indices, :]
-
-    positive_logits = last_valid_logits[:, positive_token_id]  # [batch_size]
-    negative_logits = last_valid_logits[:, negative_token_id]  # [batch_size]
-
-    # Stack to create binary classification logits
-    # Shape: [batch_size, 2] where dim=1 represents [negative, positive]
-    binary_logits = torch.stack([negative_logits, positive_logits], dim=1)
-
-    # Convert labels to the correct device and type
-    binary_labels = labels.to(binary_logits.device).long()
-
-    # Compute cross entropy loss
-    loss_fct = CrossEntropyLoss()
-    loss = loss_fct(binary_logits, binary_labels)
-
-    return loss
-
-
 def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, **kwargs) -> torch.Tensor:
     """
     List-wise reranker loss function.
@@ -692,128 +627,6 @@ def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=
     return total_loss / num_groups


-def listwise_generative_reranker_loss(outputs,
-                                      labels,
-                                      loss_scale=None,
-                                      num_items_in_batch=None,
-                                      trainer=None,
-                                      attention_mask=None,
-                                      **kwargs) -> torch.Tensor:
-    """
-    List-wise generative reranker loss function.
-
-    This loss function combines the generative reranker approach (using token probabilities)
-    with list-wise ranking. It groups samples by query based on the pattern where each group
-    consists of 1 positive document followed by n negative documents, then uses the
-    probabilities of specific tokens (e.g., "yes"/"no") to perform ranking within each group.
-
-    Data format expected:
-        - labels: [1, 0, 0, 0, 1, 0, 0, ...] where 1 indicates positive, 0 indicates negative
-        - Each 1 is followed by its corresponding negative documents until the next 1
-
-    Environment variables for configuration:
-        - GENERATIVE_RERANKER_POSITIVE_TOKEN: Token for positive relevance (default: "yes")
-        - GENERATIVE_RERANKER_NEGATIVE_TOKEN: Token for negative relevance (default: "no")
-        - LISTWISE_RERANKER_TEMPERATURE: Temperature for softmax (default: 1.0)
-        - LISTWISE_RERANKER_MIN_GROUP_SIZE: Minimum group size to include (default: 2)
-
-    Args:
-        outputs: Model outputs containing logits [batch_size, seq_len, vocab_size]
-        labels: Binary labels (1 for positive, 0 for negative) [batch_size]
-        loss_scale: Not used for listwise generative reranker
-        num_items_in_batch: Not used for listwise generative reranker
-        trainer: Trainer instance to access tokenizer
-
-    Returns:
-        torch.Tensor: Cross entropy loss for ranking classification based on token probabilities
-    """
-    if trainer is None:
-        raise ValueError('trainer is required for listwise_generative_reranker_loss to access tokenizer')
-
-    logits = outputs.logits
-    tokenizer = trainer.processing_class
-    labels = labels.float()
-
-    # Configuration from environment variables
-    positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
-    negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
-    temperature = float(os.environ.get('LISTWISE_RERANKER_TEMPERATURE', '1.0'))
-    min_group_size = int(os.environ.get('LISTWISE_RERANKER_MIN_GROUP_SIZE', '2'))
-
-    # Get token IDs for positive and negative tokens
-    try:
-        positive_token_id = tokenizer.convert_tokens_to_ids(positive_token)
-        negative_token_id = tokenizer.convert_tokens_to_ids(negative_token)
-    except Exception as e:
-        raise ValueError(f"Failed to convert tokens '{positive_token}'/'{negative_token}' to IDs. "
-                         f'Please check if these tokens exist in the tokenizer vocabulary. Error: {e}')
-
-    # Extract logits at the last valid (non-padding) token position for each sample
-    batch_size = logits.shape[0]
-    last_valid_indices = -1 if attention_mask is None else get_last_valid_indices(attention_mask)
-    batch_indices = torch.arange(batch_size, device=logits.device)
-    last_valid_logits = logits[batch_indices, last_valid_indices, :]
-
-    positive_logits = last_valid_logits[:, positive_token_id]  # [batch_size]
-    negative_logits = last_valid_logits[:, negative_token_id]  # [batch_size]
-
-    logits = F.logsigmoid(positive_logits - negative_logits)
-
-    # Find positive sample indices to determine group boundaries
-    positive_indices = torch.nonzero(labels == 1, as_tuple=False).squeeze(-1)
-
-    if len(positive_indices) == 0:
-        # No positive samples in this batch, return zero loss
-        return torch.tensor(0.0, device=logits.device, requires_grad=True)
-
-    # Ensure positive_indices is 1D
-    if positive_indices.dim() == 0:
-        positive_indices = positive_indices.unsqueeze(0)
-
-    total_loss = 0.0
-    num_groups = 0
-
-    for i, pos_idx in enumerate(positive_indices):
-        # Determine group boundaries
-        group_start = pos_idx.item()
-
-        # Find the end of current group (start of next group or end of batch)
-        if i + 1 < len(positive_indices):
-            group_end = positive_indices[i + 1].item()
-        else:
-            group_end = len(labels)
-
-        # Extract group relevance scores and labels
-        group_scores = logits[group_start:group_end]  # [group_size]
-        group_labels = labels[group_start:group_end]  # [group_size]
-
-        # Skip groups that are too small
-        if len(group_scores) < min_group_size:
-            continue
-
-        # Verify that the first sample in the group is positive
-        if group_labels[0] != 1:
-            continue  # Skip malformed groups
-
-        group_logits = group_scores / temperature
-
-        # The positive document is always at index 0 within the group
-        target = torch.tensor(0, dtype=torch.long, device=logits.device)
-
-        # Apply cross-entropy loss: positive document should have highest relevance score
-        loss_fct = CrossEntropyLoss()
-        group_loss = loss_fct(group_logits.unsqueeze(0), target.unsqueeze(0))
-
-        total_loss += group_loss
-        num_groups += 1
-
-    if num_groups == 0:
-        return torch.tensor(0.0, device=logits.device, requires_grad=True)
-
-    # Return average loss across all groups
-    return total_loss / num_groups
-
-
 loss_mapping = {
     'cross_entropy': cross_entropy_loss_func,  # examples
     # embedding
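For reference, the grouping convention that both listwise losses rely on: each query's group starts at a label of 1 and runs until the next 1, i.e. one positive document followed by its negatives. A tiny self-contained illustration of how group boundaries fall out of the labels, with toy values:

    import torch

    labels = torch.tensor([1, 0, 0, 1, 0])  # two queries: group sizes 3 and 2
    starts = torch.nonzero(labels == 1, as_tuple=False).squeeze(-1).tolist()
    bounds = list(zip(starts, starts[1:] + [len(labels)]))
    assert bounds == [(0, 3), (3, 5)]  # each group starts with its positive document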
@@ -823,9 +636,10 @@ def listwise_generative_reranker_loss(outputs,
     'infonce': infonce_loss,
     # reranker
     'reranker': reranker_loss,
-    'generative_reranker': generative_reranker_loss,
+    'generative_reranker': reranker_loss,
+    # Deprecated for compatibility; scheduled for removal
     'listwise_reranker': listwise_reranker_loss,
-    'listwise_generative_reranker': listwise_generative_reranker_loss,
+    'listwise_generative_reranker': listwise_reranker_loss,
}
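Routing 'generative_reranker' to the plain reranker_loss works because the yes/no token extraction now happens before the loss is computed: each sample is reduced upstream to a single yes-minus-no logit, and two-class cross-entropy over the stacked [no, yes] pair is mathematically identical to a binary sigmoid loss on the difference (assuming reranker_loss applies a binary cross-entropy with logits, as its name and the inference-time F.sigmoid suggest). A quick runnable check of the score identity, with made-up logit values:

    import torch
    import torch.nn.functional as F

    true_logit = torch.tensor(2.3)    # logit of the positive token, e.g. "yes"
    false_logit = torch.tensor(-0.7)  # logit of the negative token, e.g. "no"

    # Removed path: log_softmax over the stacked [negative, positive] pair, then exp.
    old_score = F.log_softmax(torch.stack([false_logit, true_logit]), dim=0)[1].exp()

    # New path: one difference logit through a sigmoid.
    new_score = torch.sigmoid(true_logit - false_logit)

    assert torch.allclose(old_score, new_score)  # softmax([f, t])[1] == sigmoid(t - f)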


46 changes: 1 addition & 45 deletions swift/trainers/mixin.py
@@ -1030,51 +1030,7 @@ def _compute_acc(self, outputs, labels, cu_seqlens=None, attention_mask=None) ->
                 acc_strategy=args.acc_strategy,
                 is_encoder_decoder=self.template.is_encoder_decoder,
                 cu_seqlens=cu_seqlens)
-        elif task_type == 'generative_reranker':
-            tokenizer = getattr(self, 'processing_class', None)
-            if tokenizer is None and getattr(self, 'template', None) is not None:
-                tokenizer = self.template.tokenizer
-            if tokenizer is None:
-                raise RuntimeError('tokenizer not available for generative_reranker acc')
-
-            positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
-            negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
-
-            try:
-                positive_token_id = tokenizer.convert_tokens_to_ids(positive_token)
-                negative_token_id = tokenizer.convert_tokens_to_ids(negative_token)
-            except Exception as e:
-                logger.warning(f'Failed to convert reranker tokens to ids: {e}')
-                positive_token_id = None
-                negative_token_id = None
-
-            if isinstance(positive_token_id, int) and isinstance(negative_token_id, int) \
-                    and positive_token_id >= 0 and negative_token_id >= 0:
-                # Handle right padding by finding the last valid token position
-                if attention_mask is not None:
-                    # Extract logits at the last valid (non-padding) token position for each sample
-                    batch_size = logits.shape[0]
-                    last_valid_indices = get_last_valid_indices(attention_mask)
-                    batch_indices = torch.arange(batch_size, device=logits.device)
-                    last_valid_logits = logits[batch_indices, last_valid_indices, :]
-                    positive_logits = last_valid_logits[:, positive_token_id]
-                    negative_logits = last_valid_logits[:, negative_token_id]
-                else:
-                    # Fallback to original behavior if attention_mask is not available
-                    positive_logits = logits[:, -1, positive_token_id]
-                    negative_logits = logits[:, -1, negative_token_id]
-                if args.loss_type == 'listwise_generative_reranker':
-                    logits = F.logsigmoid(positive_logits - negative_logits)
-                    preds, labels = self._get_listwise_reranker_preds(logits, labels)
-                else:
-                    preds = (positive_logits > negative_logits).long()
-                metrics = compute_acc(
-                    preds,
-                    labels.long(),
-                    acc_strategy=args.acc_strategy,
-                    is_encoder_decoder=self.template.is_encoder_decoder,
-                    cu_seqlens=cu_seqlens)
-        elif task_type == 'reranker':
+        elif task_type in {'generative_reranker', 'reranker'}:
             if logits.dim() == 2:
                 logits = logits.squeeze(-1)
             if args.loss_type == 'listwise_reranker':
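The merged accuracy branch is sound for the same reason: once the logits are the extracted yes-minus-no differences, the removed comparison positive_logits > negative_logits reduces to a zero threshold. A minimal illustration with made-up values:

    import torch

    positive_logits = torch.tensor([1.2, -0.3])  # "yes" logits per sample
    negative_logits = torch.tensor([0.1, 0.4])   # "no" logits per sample

    old_preds = (positive_logits > negative_logits).long()        # removed branch
    new_preds = ((positive_logits - negative_logits) > 0).long()  # zero threshold on difference logits
    assert torch.equal(old_preds, new_preds)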