@@ -63,13 +63,35 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
6363 return (loss , outputs ) if return_outputs else loss
6464
6565
66+ def gather_for_unpadded_tensors (input_data , use_gather_object = False ):
67+ from accelerate .utils import gather_object
68+ input_data = gather_object (input_data )
69+ output = []
70+ for _data in input_data :
71+ if len (_data .shape ) == 0 :
72+ _data = _data .unsqueeze (0 )
73+ _data = _data .cpu ()
74+ output .append (_data )
75+ if len (output [0 ].shape ) == 1 and output [0 ].shape [0 ] > 1 :
76+ data = torch .stack (output , dim = 0 )
77+ else :
78+ data = torch .concat (output , dim = 0 )
79+ return data
80+
81+
6682class EmbeddingTrainer (Trainer ):
6783
6884 def __init__ (self , * args , ** kwargs ):
6985 super ().__init__ (* args , ** kwargs )
7086 self .compute_metrics = self .calculate_metric
7187 self .preprocess_logits_for_metrics = None
7288 self .label_names = ['labels' ]
89+ self .gather_function = gather_for_unpadded_tensors
90+
91+ def evaluation_loop (self , * args , ** kwargs ):
92+ output = super ().evaluation_loop (* args , ** kwargs )
93+ self .gather_function = gather_for_unpadded_tensors
94+ return output
7395
7496 def calculate_metric (self , eval_prediction : EvalPrediction ) -> Dict [str , float ]:
7597 from swift .plugin .loss import infonce_loss , calculate_paired_metrics , calculate_infonce_metrics
@@ -95,6 +117,7 @@ def __init__(self, *args, **kwargs):
95117 self .preprocess_logits_for_metrics = self ._preprocess_generative_reranker_logits
96118 else :
97119 self .preprocess_logits_for_metrics = None
120+ self .gather_function = gather_for_unpadded_tensors
98121
99122 def _preprocess_generative_reranker_logits (self , logits , labels ):
100123 """
@@ -133,6 +156,11 @@ def _preprocess_generative_reranker_logits(self, logits, labels):
133156 # Unexpected shape, return as-is
134157 return logits
135158
159+ def evaluation_loop (self , * args , ** kwargs ):
160+ output = super ().evaluation_loop (* args , ** kwargs )
161+ self .gather_function = gather_for_unpadded_tensors
162+ return output
163+
136164 def calculate_metric (self , eval_prediction : EvalPrediction ) -> Dict [str , float ]:
137165 from swift .plugin .loss import (get_loss_func , LossType , calculate_reranker_metrics )
138166
0 commit comments