diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py
index e5cea63da9..990135888a 100644
--- a/swift/llm/infer/infer_engine/pt_engine.py
+++ b/swift/llm/infer/infer_engine/pt_engine.py
@@ -2,7 +2,6 @@
 import asyncio
 import hashlib
 import inspect
-import os
 import pickle
 import time
 from copy import deepcopy
@@ -21,7 +20,7 @@ from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer,
                       safe_snapshot_download, to_device
 from swift.plugin import Metric
 from swift.tuners import Swift
-from swift.utils import get_last_valid_indices
+from swift.utils import get_generative_reranker_logits
 from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
                         ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
                         EmbeddingResponseData, RequestConfig, random_uuid)
@@ -346,26 +345,11 @@ def _infer_forward(self, template: Template, inputs: Dict[str, Any], adapter_req
         elif template.task_type in ('reranker', 'generative_reranker'):
             if template.task_type == 'generative_reranker':
                 # Qwen3-reranker like
-                positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
-                negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
-                token_false_id = template.tokenizer.convert_tokens_to_ids(negative_token)
-                token_true_id = template.tokenizer.convert_tokens_to_ids(positive_token)
-                attention_mask = inputs.get('attention_mask')
-                if attention_mask is None:
-                    batch_scores = logits[:, -1, :]
-                else:
-                    last_valid_indices = get_last_valid_indices(attention_mask)
-                    batch_indices = torch.arange(attention_mask.shape[0], device=logits.device)
-                    batch_scores = logits[batch_indices, last_valid_indices, :]
-                true_vector = batch_scores[:, token_true_id]
-                false_vector = batch_scores[:, token_false_id]
-                batch_scores = torch.stack([false_vector, true_vector], dim=1).float()
-                batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
-                preds = batch_scores[:, 1].exp()
-            else:
-                preds = logits.float()
-                if self.reranker_use_activation:
-                    preds = F.sigmoid(preds)
+                logits = get_generative_reranker_logits(
+                    template.tokenizer, logits, attention_mask=inputs.get('attention_mask'))
+            preds = logits.float()
+            if self.reranker_use_activation:
+                preds = F.sigmoid(preds)
             preds = preds.tolist()
             logprobs = [None] * len(preds)
         else:
diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py
index 5820f519b1..3c9265fbbb 100755
--- a/swift/plugin/loss.py
+++ b/swift/plugin/loss.py
@@ -538,71 +538,6 @@ def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, **k
     return loss
 
 
-def generative_reranker_loss(outputs,
-                             labels,
-                             loss_scale=None,
-                             num_items_in_batch=None,
-                             trainer=None,
-                             attention_mask=None,
-                             **kwargs) -> torch.Tensor:
-    """
-    Generative reranker loss function.
-
-    This loss function is designed for generative rerankers that use token probabilities
-    (e.g., "yes"/"no") to determine relevance scores. It only computes loss on the
-    last token position for specific tokens.
-
-    Args:
-        outputs: Model outputs containing logits
-        labels: Binary labels (0/1) for irrelevant/relevant pairs
-        loss_scale: Not used for generative reranker
-        num_items_in_batch: Not used for generative reranker
-        trainer: Trainer instance to access tokenizer
-
-    Returns:
-        torch.Tensor: Cross entropy loss for yes/no classification
-    """
-    if trainer is None:
-        raise ValueError('trainer is required for generative_reranker_loss to access tokenizer')
-
-    logits = outputs.logits
-    tokenizer = trainer.processing_class
-
-    # Get token IDs for positive and negative tokens
-    # Default to "yes"/"no", but can be configured via environment variables
-    positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
-    negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
-
-    try:
-        positive_token_id = tokenizer.convert_tokens_to_ids(positive_token)
-        negative_token_id = tokenizer.convert_tokens_to_ids(negative_token)
-    except Exception as e:
-        raise ValueError(f"Failed to convert tokens '{positive_token}'/'{negative_token}' to IDs. "
-                         f'Please check if these tokens exist in the tokenizer vocabulary. Error: {e}')
-
-    # Extract logits at the last valid (non-padding) token position for each sample
-    batch_size = logits.shape[0]
-    last_valid_indices = -1 if attention_mask is None else get_last_valid_indices(attention_mask)
-    batch_indices = torch.arange(batch_size, device=logits.device)
-    last_valid_logits = logits[batch_indices, last_valid_indices, :]
-
-    positive_logits = last_valid_logits[:, positive_token_id]  # [batch_size]
-    negative_logits = last_valid_logits[:, negative_token_id]  # [batch_size]
-
-    # Stack to create binary classification logits
-    # Shape: [batch_size, 2] where dim=1 represents [negative, positive]
-    binary_logits = torch.stack([negative_logits, positive_logits], dim=1)
-
-    # Convert labels to the correct device and type
-    binary_labels = labels.to(binary_logits.device).long()
-
-    # Compute cross entropy loss
-    loss_fct = CrossEntropyLoss()
-    loss = loss_fct(binary_logits, binary_labels)
-
-    return loss
-
-
 def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, **kwargs) -> torch.Tensor:
     """
     List-wise reranker loss function.
@@ -692,128 +627,6 @@ def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=
     return total_loss / num_groups
 
 
-def listwise_generative_reranker_loss(outputs,
-                                      labels,
-                                      loss_scale=None,
-                                      num_items_in_batch=None,
-                                      trainer=None,
-                                      attention_mask=None,
-                                      **kwargs) -> torch.Tensor:
-    """
-    List-wise generative reranker loss function.
-
-    This loss function combines the generative reranker approach (using token probabilities)
-    with list-wise ranking. It groups samples by query based on the pattern where each group
-    consists of 1 positive document followed by n negative documents, then uses the
-    probabilities of specific tokens (e.g., "yes"/"no") to perform ranking within each group.
-
-    Data format expected:
-    - labels: [1, 0, 0, 0, 1, 0, 0, ...] where 1 indicates positive, 0 indicates negative
-    - Each 1 is followed by its corresponding negative documents until the next 1
-
-    Environment variables for configuration:
-    - GENERATIVE_RERANKER_POSITIVE_TOKEN: Token for positive relevance (default: "yes")
-    - GENERATIVE_RERANKER_NEGATIVE_TOKEN: Token for negative relevance (default: "no")
-    - LISTWISE_RERANKER_TEMPERATURE: Temperature for softmax (default: 1.0)
-    - LISTWISE_RERANKER_MIN_GROUP_SIZE: Minimum group size to include (default: 2)
-
-    Args:
-        outputs: Model outputs containing logits [batch_size, seq_len, vocab_size]
-        labels: Binary labels (1 for positive, 0 for negative) [batch_size]
-        loss_scale: Not used for listwise generative reranker
-        num_items_in_batch: Not used for listwise generative reranker
-        trainer: Trainer instance to access tokenizer
-
-    Returns:
-        torch.Tensor: Cross entropy loss for ranking classification based on token probabilities
-    """
-    if trainer is None:
-        raise ValueError('trainer is required for listwise_generative_reranker_loss to access tokenizer')
-
-    logits = outputs.logits
-    tokenizer = trainer.processing_class
-    labels = labels.float()
-
-    # Configuration from environment variables
-    positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
-    negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
-    temperature = float(os.environ.get('LISTWISE_RERANKER_TEMPERATURE', '1.0'))
-    min_group_size = int(os.environ.get('LISTWISE_RERANKER_MIN_GROUP_SIZE', '2'))
-
-    # Get token IDs for positive and negative tokens
-    try:
-        positive_token_id = tokenizer.convert_tokens_to_ids(positive_token)
-        negative_token_id = tokenizer.convert_tokens_to_ids(negative_token)
-    except Exception as e:
-        raise ValueError(f"Failed to convert tokens '{positive_token}'/'{negative_token}' to IDs. "
-                         f'Please check if these tokens exist in the tokenizer vocabulary. Error: {e}')
-
-    # Extract logits at the last valid (non-padding) token position for each sample
-    batch_size = logits.shape[0]
-    last_valid_indices = -1 if attention_mask is None else get_last_valid_indices(attention_mask)
-    batch_indices = torch.arange(batch_size, device=logits.device)
-    last_valid_logits = logits[batch_indices, last_valid_indices, :]
-
-    positive_logits = last_valid_logits[:, positive_token_id]  # [batch_size]
-    negative_logits = last_valid_logits[:, negative_token_id]  # [batch_size]
-
-    logits = F.logsigmoid(positive_logits - negative_logits)
-
-    # Find positive sample indices to determine group boundaries
-    positive_indices = torch.nonzero(labels == 1, as_tuple=False).squeeze(-1)
-
-    if len(positive_indices) == 0:
-        # No positive samples in this batch, return zero loss
-        return torch.tensor(0.0, device=logits.device, requires_grad=True)
-
-    # Ensure positive_indices is 1D
-    if positive_indices.dim() == 0:
-        positive_indices = positive_indices.unsqueeze(0)
-
-    total_loss = 0.0
-    num_groups = 0
-
-    for i, pos_idx in enumerate(positive_indices):
-        # Determine group boundaries
-        group_start = pos_idx.item()
-
-        # Find the end of current group (start of next group or end of batch)
-        if i + 1 < len(positive_indices):
-            group_end = positive_indices[i + 1].item()
-        else:
-            group_end = len(labels)
-
-        # Extract group relevance scores and labels
-        group_scores = logits[group_start:group_end]  # [group_size]
-        group_labels = labels[group_start:group_end]  # [group_size]
-
-        # Skip groups that are too small
-        if len(group_scores) < min_group_size:
-            continue
-
-        # Verify that the first sample in the group is positive
-        if group_labels[0] != 1:
-            continue  # Skip malformed groups
-
-        group_logits = group_scores / temperature
-
-        # The positive document is always at index 0 within the group
-        target = torch.tensor(0, dtype=torch.long, device=logits.device)
-
-        # Apply cross-entropy loss: positive document should have highest relevance score
-        loss_fct = CrossEntropyLoss()
-        group_loss = loss_fct(group_logits.unsqueeze(0), target.unsqueeze(0))
-
-        total_loss += group_loss
-        num_groups += 1
-
-    if num_groups == 0:
-        return torch.tensor(0.0, device=logits.device, requires_grad=True)
-
-    # Return average loss across all groups
-    return total_loss / num_groups
-
-
 loss_mapping = {
     'cross_entropy': cross_entropy_loss_func,  # examples
     # embedding
@@ -823,9 +636,10 @@ def listwise_generative_reranker_loss(outputs,
     'infonce': infonce_loss,
     # reranker
     'reranker': reranker_loss,
-    'generative_reranker': generative_reranker_loss,
+    'generative_reranker': reranker_loss,
+    # Deprecated for compatibility; scheduled for removal
     'listwise_reranker': listwise_reranker_loss,
-    'listwise_generative_reranker': listwise_generative_reranker_loss,
+    'listwise_generative_reranker': listwise_reranker_loss,
 }
 
 
diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py
index fa2a3f4863..a8f8f2b3e1 100644
--- a/swift/trainers/mixin.py
+++ b/swift/trainers/mixin.py
@@ -1030,51 +1030,7 @@ def _compute_acc(self, outputs, labels, cu_seqlens=None, attention_mask=None) ->
                 acc_strategy=args.acc_strategy,
                 is_encoder_decoder=self.template.is_encoder_decoder,
                 cu_seqlens=cu_seqlens)
-        elif task_type == 'generative_reranker':
-            tokenizer = getattr(self, 'processing_class', None)
-            if tokenizer is None and getattr(self, 'template', None) is not None:
-                tokenizer = self.template.tokenizer
-            if tokenizer is None:
-                raise RuntimeError('tokenizer not available for generative_reranker acc')
-
-            positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
-            negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
-
-            try:
-                positive_token_id = tokenizer.convert_tokens_to_ids(positive_token)
-                negative_token_id = tokenizer.convert_tokens_to_ids(negative_token)
-            except Exception as e:
-                logger.warning(f'Failed to convert reranker tokens to ids: {e}')
-                positive_token_id = None
-                negative_token_id = None
-
-            if isinstance(positive_token_id, int) and isinstance(negative_token_id, int) \
-                    and positive_token_id >= 0 and negative_token_id >= 0:
-                # Handle right padding by finding the last valid token position
-                if attention_mask is not None:
-                    # Extract logits at the last valid (non-padding) token position for each sample
-                    batch_size = logits.shape[0]
-                    last_valid_indices = get_last_valid_indices(attention_mask)
-                    batch_indices = torch.arange(batch_size, device=logits.device)
-                    last_valid_logits = logits[batch_indices, last_valid_indices, :]
-                    positive_logits = last_valid_logits[:, positive_token_id]
-                    negative_logits = last_valid_logits[:, negative_token_id]
-                else:
-                    # Fallback to original behavior if attention_mask is not available
-                    positive_logits = logits[:, -1, positive_token_id]
-                    negative_logits = logits[:, -1, negative_token_id]
-                if args.loss_type == 'listwise_generative_reranker':
-                    logits = F.logsigmoid(positive_logits - negative_logits)
-                    preds, labels = self._get_listwise_reranker_preds(logits, labels)
-                else:
-                    preds = (positive_logits > negative_logits).long()
-                metrics = compute_acc(
-                    preds,
-                    labels.long(),
-                    acc_strategy=args.acc_strategy,
-                    is_encoder_decoder=self.template.is_encoder_decoder,
-                    cu_seqlens=cu_seqlens)
-        elif task_type == 'reranker':
+        elif task_type in {'generative_reranker', 'reranker'}:
             if logits.dim() == 2:
                 logits = logits.squeeze(-1)
             if args.loss_type == 'listwise_reranker':
diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py
index d1b383e582..00e72559cf 100644
--- a/swift/trainers/trainers.py
+++ b/swift/trainers/trainers.py
@@ -16,7 +16,8 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 from transformers.utils import is_peft_available
 
-from swift.utils import JsonlWriter, Serializer, gc_collect, get_logger, unwrap_model_for_generation
+from swift.utils import (JsonlWriter, Serializer, gc_collect, get_generative_reranker_logits, get_logger,
+                         unwrap_model_for_generation)
 from .arguments import Seq2SeqTrainingArguments, TrainingArguments
 from .mixin import DataLoaderMixin, SwiftMixin
 from .utils import per_token_loss_func, per_token_loss_func_sp
@@ -135,69 +136,17 @@ def __init__(self, *args, **kwargs):
         self.compute_metrics = self.calculate_metric
         self.label_names = ['labels']
-        # Set up preprocess_logits_for_metrics to reduce memory usage for generative reranker
-        if self.args.loss_type in {'generative_reranker', 'listwise_generative_reranker'}:
-            self.preprocess_logits_for_metrics = self._preprocess_generative_reranker_logits
-        else:
-            self.preprocess_logits_for_metrics = None
+        self.preprocess_logits_for_metrics = None
         self.gather_function = gather_for_unpadded_tensors
 
-    def _preprocess_generative_reranker_logits(self, logits, labels):
-        """
-        Preprocess logits for generative reranker to reduce memory usage.
-        Extract only the yes/no token logits at the last valid (non -100) timestep
-        for each sample, avoiding padded timesteps created by multi-GPU gather.
- """ - - # Get token IDs for positive and negative tokens - positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes') - negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no') - - tokenizer = getattr(self, 'processing_class', None) - if tokenizer is None: - # Fallback: return full logits if tokenizer not available - return logits - - try: - positive_token_id = tokenizer.convert_tokens_to_ids(positive_token) - negative_token_id = tokenizer.convert_tokens_to_ids(negative_token) - except Exception: - # Fallback: return full logits if token conversion fails - return logits - - # Extract only the yes/no token logits from the last non -100 position per sample - # Shapes: logits [batch, seq_len, vocab] - if len(logits.shape) == 3: - positive_logits = logits[:, :, positive_token_id] - negative_logits = logits[:, :, negative_token_id] - logits = positive_logits - negative_logits - return logits - else: - # Unexpected shape, return as-is - return logits - def evaluation_loop(self, *args, **kwargs): output = super().evaluation_loop(*args, **kwargs) self.gather_function = gather_for_unpadded_tensors return output def calculate_metric(self, eval_prediction: EvalPrediction) -> Dict[str, float]: - import numpy as np from swift.plugin.loss import calculate_reranker_metrics - input_ids = eval_prediction.inputs - logits = eval_prediction.predictions - labels = eval_prediction.label_ids - - if self.template.padding_free: - logits = logits[:, -1] - else: - if logits.ndim == 2 and logits.shape[1] > 1: - pad_token_id = self.tokenizer.pad_token_id - valid_mask = (input_ids != pad_token_id) & (input_ids != -100) - last_valid_indices = valid_mask[:, ::-1].argmax(axis=1) - last_valid_indices = input_ids.shape[1] - 1 - last_valid_indices - logits = logits[np.arange(logits.shape[0]), last_valid_indices] - return calculate_reranker_metrics(logits, labels) + return calculate_reranker_metrics(eval_prediction.predictions, eval_prediction.label_ids) def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if inputs.get('attention_mask') is None and self.template.padding_side != 'left': @@ -205,20 +154,15 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # Check if we have a custom loss function if self.compute_loss_func is not None: # Get labels and compute outputs - labels = inputs.get('labels') - if labels is not None: - labels = inputs.pop('labels') - + labels = inputs.pop('labels', None) outputs = model(**inputs) + if self.args.task_type == 'generative_reranker': + outputs.logits = get_generative_reranker_logits( + self.tokenizer, outputs.logits, attention_mask=inputs.get('attention_mask')) if labels is not None: # Call custom loss function - loss = self.compute_loss_func( - outputs, - labels, - num_items_in_batch=num_items_in_batch, - trainer=self, - attention_mask=inputs.get('attention_mask')) + loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch) else: # Fallback to model's loss loss = outputs.loss @@ -227,7 +171,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N loss = loss / self.args.gradient_accumulation_steps if labels is not None: - self._compute_acc(outputs, labels, attention_mask=inputs.get('attention_mask')) + self._compute_acc(outputs, labels) return (loss, outputs) if return_outputs else loss else: diff --git a/swift/utils/__init__.py b/swift/utils/__init__.py index 86ea0c4525..06b5f6e399 100644 --- a/swift/utils/__init__.py +++ 
b/swift/utils/__init__.py @@ -12,9 +12,10 @@ from .torch_utils import (Serializer, activate_parameters, check_shared_disk, disable_safe_ddp_context_use_barrier, empty_cache, find_all_linears, find_embedding, find_layers, find_norm, freeze_parameters, gc_collect, get_cu_seqlens_from_position_ids, get_current_device, get_device, - get_device_count, get_last_valid_indices, get_model_parameter_info, get_n_params_grads, - init_process_group, safe_ddp_context, seed_worker, set_default_ddp_config, set_device, - show_layers, time_synchronize, unwrap_model_for_generation) + get_device_count, get_generative_reranker_logits, get_last_valid_indices, + get_model_parameter_info, get_n_params_grads, init_process_group, safe_ddp_context, + seed_worker, set_default_ddp_config, set_device, show_layers, time_synchronize, + unwrap_model_for_generation) from .utils import (add_version_to_work_dir, check_json_format, copy_files_by_pattern, deep_getattr, find_free_port, format_time, get_env_args, get_modules_to_not_convert, import_external_file, json_parse_to_dict, lower_bound, parse_args, patch_getattr, read_multi_line, remove_response, retry_decorator, diff --git a/swift/utils/torch_utils.py b/swift/utils/torch_utils.py index 4b412e2dd9..936ff149fb 100644 --- a/swift/utils/torch_utils.py +++ b/swift/utils/torch_utils.py @@ -534,3 +534,25 @@ def unwrap_model_for_generation( add_hooks(model) else: yield unwrapped_model + + +def get_generative_reranker_logits(tokenizer, logits, attention_mask=None): + positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes') + negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no') + positive_token_id = tokenizer.convert_tokens_to_ids(positive_token) + negative_token_id = tokenizer.convert_tokens_to_ids(negative_token) + + # Handle right padding by finding the last valid token position + if attention_mask is not None: + # Extract logits at the last valid (non-padding) token position for each sample + batch_size = logits.shape[0] + last_valid_indices = get_last_valid_indices(attention_mask) + batch_indices = torch.arange(batch_size, device=logits.device) + last_valid_logits = logits[batch_indices, last_valid_indices, :] + positive_logits = last_valid_logits[:, positive_token_id] + negative_logits = last_valid_logits[:, negative_token_id] + else: + # Fallback to original behavior if attention_mask is not available + positive_logits = logits[:, -1, positive_token_id] + negative_logits = logits[:, -1, negative_token_id] + return (positive_logits - negative_logits)[:, None]
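
Note on the consolidated helper: the sketch below (not part of the patch) illustrates what `get_generative_reranker_logits` computes. The stub tokenizer and its toy vocabulary are hypothetical, the env-var token overrides are omitted, and the last valid index is re-derived as `attention_mask.sum(-1) - 1`, which assumes contiguous right padding; the repo's `get_last_valid_indices` is the authoritative implementation.

    import torch


    class StubTokenizer:
        """Hypothetical stand-in for a real tokenizer; the toy vocab is an assumption."""
        vocab = {'no': 0, 'yes': 1, '<pad>': 2}

        def convert_tokens_to_ids(self, token):
            return self.vocab[token]


    def yes_no_margin(tokenizer, logits, attention_mask=None):
        # Mirrors the helper's behavior: score each sample by its "yes" minus
        # "no" logit at the last valid position, returned with shape [batch, 1].
        pos_id = tokenizer.convert_tokens_to_ids('yes')
        neg_id = tokenizer.convert_tokens_to_ids('no')
        if attention_mask is not None:
            last_valid = attention_mask.sum(dim=-1) - 1  # assumes contiguous right padding
            rows = torch.arange(logits.shape[0], device=logits.device)
            last_logits = logits[rows, last_valid, :]
        else:
            last_logits = logits[:, -1, :]  # same fallback as the helper
        return (last_logits[:, pos_id] - last_logits[:, neg_id])[:, None]


    logits = torch.randn(2, 4, 3)  # [batch=2, seq_len=4, vocab=3]
    attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])  # sample 1 is right-padded
    scores = yes_no_margin(StubTokenizer(), logits, attention_mask)
    print(scores.shape)           # torch.Size([2, 1])
    print(torch.sigmoid(scores))  # relevance probabilities, as pt_engine.py applies

Returning the single margin logit preserves the old two-class behavior exactly, since softmax([no, yes])[1] == sigmoid(yes - no); that equivalence is what lets the 'generative_reranker' task reuse the plain `reranker_loss` and reranker accuracy paths in place of the removed yes/no softmax code.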