@@ -538,71 +538,6 @@ def reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, **k
538538 return loss
539539
540540
541- def generative_reranker_loss (outputs ,
542- labels ,
543- loss_scale = None ,
544- num_items_in_batch = None ,
545- trainer = None ,
546- attention_mask = None ,
547- ** kwargs ) -> torch .Tensor :
548- """
549- Generative reranker loss function.
550-
551- This loss function is designed for generative rerankers that use token probabilities
552- (e.g., "yes"/"no") to determine relevance scores. It only computes loss on the
553- last token position for specific tokens.
554-
555- Args:
556- outputs: Model outputs containing logits
557- labels: Binary labels (0/1) for irrelevant/relevant pairs
558- loss_scale: Not used for generative reranker
559- num_items_in_batch: Not used for generative reranker
560- trainer: Trainer instance to access tokenizer
561-
562- Returns:
563- torch.Tensor: Cross entropy loss for yes/no classification
564- """
565- if trainer is None :
566- raise ValueError ('trainer is required for generative_reranker_loss to access tokenizer' )
567-
568- logits = outputs .logits
569- tokenizer = trainer .processing_class
570-
571- # Get token IDs for positive and negative tokens
572- # Default to "yes"/"no", but can be configured via environment variables
573- positive_token = os .environ .get ('GENERATIVE_RERANKER_POSITIVE_TOKEN' , 'yes' )
574- negative_token = os .environ .get ('GENERATIVE_RERANKER_NEGATIVE_TOKEN' , 'no' )
575-
576- try :
577- positive_token_id = tokenizer .convert_tokens_to_ids (positive_token )
578- negative_token_id = tokenizer .convert_tokens_to_ids (negative_token )
579- except Exception as e :
580- raise ValueError (f"Failed to convert tokens '{ positive_token } '/'{ negative_token } ' to IDs. "
581- f'Please check if these tokens exist in the tokenizer vocabulary. Error: { e } ' )
582-
583- # Extract logits at the last valid (non-padding) token position for each sample
584- batch_size = logits .shape [0 ]
585- last_valid_indices = - 1 if attention_mask is None else get_last_valid_indices (attention_mask )
586- batch_indices = torch .arange (batch_size , device = logits .device )
587- last_valid_logits = logits [batch_indices , last_valid_indices , :]
588-
589- positive_logits = last_valid_logits [:, positive_token_id ] # [batch_size]
590- negative_logits = last_valid_logits [:, negative_token_id ] # [batch_size]
591-
592- # Stack to create binary classification logits
593- # Shape: [batch_size, 2] where dim=1 represents [negative, positive]
594- binary_logits = torch .stack ([negative_logits , positive_logits ], dim = 1 )
595-
596- # Convert labels to the correct device and type
597- binary_labels = labels .to (binary_logits .device ).long ()
598-
599- # Compute cross entropy loss
600- loss_fct = CrossEntropyLoss ()
601- loss = loss_fct (binary_logits , binary_labels )
602-
603- return loss
604-
605-
606541def listwise_reranker_loss (outputs , labels , loss_scale = None , num_items_in_batch = None , ** kwargs ) -> torch .Tensor :
607542 """
608543 List-wise reranker loss function.
@@ -692,128 +627,6 @@ def listwise_reranker_loss(outputs, labels, loss_scale=None, num_items_in_batch=
692627 return total_loss / num_groups
693628
694629
695- def listwise_generative_reranker_loss (outputs ,
696- labels ,
697- loss_scale = None ,
698- num_items_in_batch = None ,
699- trainer = None ,
700- attention_mask = None ,
701- ** kwargs ) -> torch .Tensor :
702- """
703- List-wise generative reranker loss function.
704-
705- This loss function combines the generative reranker approach (using token probabilities)
706- with list-wise ranking. It groups samples by query based on the pattern where each group
707- consists of 1 positive document followed by n negative documents, then uses the
708- probabilities of specific tokens (e.g., "yes"/"no") to perform ranking within each group.
709-
710- Data format expected:
711- - labels: [1, 0, 0, 0, 1, 0, 0, ...] where 1 indicates positive, 0 indicates negative
712- - Each 1 is followed by its corresponding negative documents until the next 1
713-
714- Environment variables for configuration:
715- - GENERATIVE_RERANKER_POSITIVE_TOKEN: Token for positive relevance (default: "yes")
716- - GENERATIVE_RERANKER_NEGATIVE_TOKEN: Token for negative relevance (default: "no")
717- - LISTWISE_RERANKER_TEMPERATURE: Temperature for softmax (default: 1.0)
718- - LISTWISE_RERANKER_MIN_GROUP_SIZE: Minimum group size to include (default: 2)
719-
720- Args:
721- outputs: Model outputs containing logits [batch_size, seq_len, vocab_size]
722- labels: Binary labels (1 for positive, 0 for negative) [batch_size]
723- loss_scale: Not used for listwise generative reranker
724- num_items_in_batch: Not used for listwise generative reranker
725- trainer: Trainer instance to access tokenizer
726-
727- Returns:
728- torch.Tensor: Cross entropy loss for ranking classification based on token probabilities
729- """
730- if trainer is None :
731- raise ValueError ('trainer is required for listwise_generative_reranker_loss to access tokenizer' )
732-
733- logits = outputs .logits
734- tokenizer = trainer .processing_class
735- labels = labels .float ()
736-
737- # Configuration from environment variables
738- positive_token = os .environ .get ('GENERATIVE_RERANKER_POSITIVE_TOKEN' , 'yes' )
739- negative_token = os .environ .get ('GENERATIVE_RERANKER_NEGATIVE_TOKEN' , 'no' )
740- temperature = float (os .environ .get ('LISTWISE_RERANKER_TEMPERATURE' , '1.0' ))
741- min_group_size = int (os .environ .get ('LISTWISE_RERANKER_MIN_GROUP_SIZE' , '2' ))
742-
743- # Get token IDs for positive and negative tokens
744- try :
745- positive_token_id = tokenizer .convert_tokens_to_ids (positive_token )
746- negative_token_id = tokenizer .convert_tokens_to_ids (negative_token )
747- except Exception as e :
748- raise ValueError (f"Failed to convert tokens '{ positive_token } '/'{ negative_token } ' to IDs. "
749- f'Please check if these tokens exist in the tokenizer vocabulary. Error: { e } ' )
750-
751- # Extract logits at the last valid (non-padding) token position for each sample
752- batch_size = logits .shape [0 ]
753- last_valid_indices = - 1 if attention_mask is None else get_last_valid_indices (attention_mask )
754- batch_indices = torch .arange (batch_size , device = logits .device )
755- last_valid_logits = logits [batch_indices , last_valid_indices , :]
756-
757- positive_logits = last_valid_logits [:, positive_token_id ] # [batch_size]
758- negative_logits = last_valid_logits [:, negative_token_id ] # [batch_size]
759-
760- logits = F .logsigmoid (positive_logits - negative_logits )
761-
762- # Find positive sample indices to determine group boundaries
763- positive_indices = torch .nonzero (labels == 1 , as_tuple = False ).squeeze (- 1 )
764-
765- if len (positive_indices ) == 0 :
766- # No positive samples in this batch, return zero loss
767- return torch .tensor (0.0 , device = logits .device , requires_grad = True )
768-
769- # Ensure positive_indices is 1D
770- if positive_indices .dim () == 0 :
771- positive_indices = positive_indices .unsqueeze (0 )
772-
773- total_loss = 0.0
774- num_groups = 0
775-
776- for i , pos_idx in enumerate (positive_indices ):
777- # Determine group boundaries
778- group_start = pos_idx .item ()
779-
780- # Find the end of current group (start of next group or end of batch)
781- if i + 1 < len (positive_indices ):
782- group_end = positive_indices [i + 1 ].item ()
783- else :
784- group_end = len (labels )
785-
786- # Extract group relevance scores and labels
787- group_scores = logits [group_start :group_end ] # [group_size]
788- group_labels = labels [group_start :group_end ] # [group_size]
789-
790- # Skip groups that are too small
791- if len (group_scores ) < min_group_size :
792- continue
793-
794- # Verify that the first sample in the group is positive
795- if group_labels [0 ] != 1 :
796- continue # Skip malformed groups
797-
798- group_logits = group_scores / temperature
799-
800- # The positive document is always at index 0 within the group
801- target = torch .tensor (0 , dtype = torch .long , device = logits .device )
802-
803- # Apply cross-entropy loss: positive document should have highest relevance score
804- loss_fct = CrossEntropyLoss ()
805- group_loss = loss_fct (group_logits .unsqueeze (0 ), target .unsqueeze (0 ))
806-
807- total_loss += group_loss
808- num_groups += 1
809-
810- if num_groups == 0 :
811- return torch .tensor (0.0 , device = logits .device , requires_grad = True )
812-
813- # Return average loss across all groups
814- return total_loss / num_groups
815-
816-
817630loss_mapping = {
818631 'cross_entropy' : cross_entropy_loss_func , # examples
819632 # embedding
@@ -823,9 +636,10 @@ def listwise_generative_reranker_loss(outputs,
823636 'infonce' : infonce_loss ,
824637 # reranker
825638 'reranker' : reranker_loss ,
826- 'generative_reranker' : generative_reranker_loss ,
639+ 'generative_reranker' : reranker_loss ,
640+ # Deprecated for compatibility; scheduled for removal
827641 'listwise_reranker' : listwise_reranker_loss ,
828- 'listwise_generative_reranker' : listwise_generative_reranker_loss ,
642+ 'listwise_generative_reranker' : listwise_reranker_loss ,
829643}
830644
831645
0 commit comments