Skip to content

Commit b610cfa

Browse files
tastelikefeet authored and Jintao-Huang committed
fix hang (#5114)
1 parent 76ad21d commit b610cfa

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

swift/trainers/trainers.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,35 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
6363
return (loss, outputs) if return_outputs else loss
6464

6565

66+
def gather_for_unpadded_tensors(input_data, use_gather_object=False):
    """Gather ``input_data`` via accelerate's ``gather_object`` and merge it into one CPU tensor.

    Every element of the gathered result is moved to CPU; 0-dim elements are
    first given a leading axis of size 1. The elements are then merged along
    dim 0: stacked when the first element is 1-D with more than one entry,
    concatenated otherwise.

    Args:
        input_data: Object handed to ``gather_object``. The elements obtained
            by iterating the gathered result must be tensors (they are used
            through ``.shape``, ``.unsqueeze`` and ``.cpu``).
        use_gather_object: Not read by this function; presumably kept so the
            signature matches the gather function it replaces -- TODO confirm.

    Returns:
        A single tensor built from all gathered elements.

    Raises:
        IndexError: If the gathered result contains no elements.
    """
    from accelerate.utils import gather_object

    # Promote 0-dim tensors to shape (1,) *before* moving them to CPU, so every
    # piece has at least one axis to merge along.
    pieces = [(item if item.shape else item.unsqueeze(0)).cpu() for item in gather_object(input_data)]

    # The first piece decides how everything is merged.
    head = pieces[0]
    is_multi_entry_vector = len(head.shape) == 1 and head.shape[0] > 1
    merge = torch.stack if is_multi_entry_vector else torch.concat
    return merge(pieces, dim=0)
80+
81+
6682
class EmbeddingTrainer(Trainer):
6783

6884
def __init__(self, *args, **kwargs):
    """Initialise the parent trainer, then install this trainer's evaluation hooks.

    All positional and keyword arguments are forwarded unchanged to the
    parent constructor.
    """
    super().__init__(*args, **kwargs)
    # Gather evaluation tensors with the module-level helper.
    self.gather_function = gather_for_unpadded_tensors
    # Only the 'labels' entry is treated as a label.
    self.label_names = ['labels']
    # No logits preprocessing; metrics come from this class's calculate_metric.
    self.preprocess_logits_for_metrics = None
    self.compute_metrics = self.calculate_metric
90+
91+
def evaluation_loop(self, *args, **kwargs):
    """Run the parent evaluation loop, then re-install the custom gather function.

    Arguments are forwarded unchanged and the parent's return value is
    returned as-is.
    """
    result = super().evaluation_loop(*args, **kwargs)
    # NOTE(review): presumably the parent loop rebinds ``self.gather_function``
    # before returning, which is why it is restored here -- confirm against
    # the base Trainer implementation.
    self.gather_function = gather_for_unpadded_tensors
    return result
7395

7496
def calculate_metric(self, eval_prediction: EvalPrediction) -> Dict[str, float]:
7597
from swift.plugin.loss import infonce_loss, calculate_paired_metrics, calculate_infonce_metrics
@@ -95,6 +117,7 @@ def __init__(self, *args, **kwargs):
95117
self.preprocess_logits_for_metrics = self._preprocess_generative_reranker_logits
96118
else:
97119
self.preprocess_logits_for_metrics = None
120+
self.gather_function = gather_for_unpadded_tensors
98121

99122
def _preprocess_generative_reranker_logits(self, logits, labels):
100123
"""
@@ -133,6 +156,11 @@ def _preprocess_generative_reranker_logits(self, logits, labels):
133156
# Unexpected shape, return as-is
134157
return logits
135158

159+
def evaluation_loop(self, *args, **kwargs):
    """Delegate to the parent evaluation loop and restore the custom gather function.

    Arguments are forwarded unchanged; whatever the parent returns is handed
    back to the caller.
    """
    loop_output = super().evaluation_loop(*args, **kwargs)
    # NOTE(review): restored after the parent call, presumably because the
    # parent loop replaces ``self.gather_function`` -- confirm against the
    # base Trainer implementation.
    self.gather_function = gather_for_unpadded_tensors
    return loop_output
163+
136164
def calculate_metric(self, eval_prediction: EvalPrediction) -> Dict[str, float]:
137165
from swift.plugin.loss import (get_loss_func, LossType, calculate_reranker_metrics)
138166

0 commit comments

Comments
 (0)