Skip to content

Commit 532b83f

Browse files
authored
Fix the computation of accuracy for the reranker (#6089)
1 parent 5de484f commit 532b83f

File tree

1 file changed

+70
-27
lines changed

1 file changed

+70
-27
lines changed

swift/trainers/mixin.py

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -870,33 +870,76 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
870870

871871
def _compute_acc(self, outputs, labels) -> None:
    """Accumulate batch accuracy into ``self.custom_metrics``.

    Three head layouts are recognized:
      * generative reranker (``loss_type`` in {'generative_reranker',
        'listwise_generative_reranker'} with 3-D token logits): compare the
        logits of the positive/negative decision tokens at the last position;
      * pointwise reranker (1-D logits, or 2-D with a trailing size-1 dim):
        threshold the score at 0;
      * default LM head: argmax over the vocab dim, with an optional gather
        across the sequence-parallel group.

    NOTE(review): assumes ``labels`` holds 0/1 targets in the two reranker
    branches — confirm against the data collator.
    """
    args = self.args
    logits = outputs.logits

    def _binary_metrics(binary_preds):
        # Shared accuracy call for the two binary (reranker) branches.
        return compute_acc(
            binary_preds,
            labels.long(),
            acc_strategy=args.acc_strategy,
            is_encoder_decoder=self.template.is_encoder_decoder)

    metrics = None
    is_generative_reranker = (
        getattr(args, 'loss_type', None) in {'generative_reranker', 'listwise_generative_reranker'}
        and logits is not None and logits.dim() == 3)
    if is_generative_reranker:
        # Resolve a tokenizer: prefer the trainer's processing_class, fall
        # back to the template's tokenizer.
        tokenizer = getattr(self, 'processing_class', None)
        if tokenizer is None and getattr(self, 'template', None) is not None:
            tokenizer = self.template.tokenizer
        if tokenizer is None:
            raise RuntimeError('tokenizer not available for generative_reranker acc')

        # Decision tokens are overridable via environment variables.
        pos_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
        neg_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
        try:
            pos_id = tokenizer.convert_tokens_to_ids(pos_token)
            neg_id = tokenizer.convert_tokens_to_ids(neg_token)
        except Exception as e:
            logger.warning(f'Failed to convert reranker tokens to ids: {e}')
            pos_id = None
            neg_id = None

        # Only compute accuracy when both token ids resolved to valid,
        # non-negative ints; otherwise skip silently (metrics stays None).
        if (isinstance(pos_id, int) and isinstance(neg_id, int)
                and pos_id >= 0 and neg_id >= 0):
            # Predict positive iff the "yes" logit beats the "no" logit at
            # the final position.
            binary_preds = (logits[:, -1, pos_id] > logits[:, -1, neg_id]).long()
            metrics = _binary_metrics(binary_preds)
    elif logits.dim() == 1 or (logits.dim() == 2 and logits.size(-1) == 1):
        # Pointwise reranker head: a single score per sample; positive iff > 0.
        scores = logits.squeeze(-1) if logits.dim() == 2 else logits
        metrics = _binary_metrics((scores > 0).long())
    else:
        # Default token-classification path: argmax over the last dim.
        preds = logits.argmax(dim=-1)
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            # Gather preds and labels across the sp group
            if isinstance(preds, np.ndarray):
                preds = torch.from_numpy(preds).to(get_current_device())
            if isinstance(labels, np.ndarray):
                labels = torch.from_numpy(labels).to(get_current_device())
            assert labels.shape[1] == preds.shape[1]

            if sequence_parallel.rp_world_size > 1:
                position_ids = sequence_parallel.real_position_ids
                position_ids = sequence_parallel.pad(position_ids, padding_value=-1, position_ids=position_ids)
            else:
                position_ids = None
            preds = sequence_parallel.gather(preds, dim=1, position_ids=position_ids)
            gathered_labels = sequence_parallel.gather(labels, dim=1, position_ids=position_ids)
            # roll back to fit compute_acc
            labels = torch.roll(gathered_labels, shifts=1, dims=1).int()

        metrics = compute_acc(
            preds, labels, acc_strategy=args.acc_strategy, is_encoder_decoder=self.template.is_encoder_decoder)

    if metrics:
        mode = 'train' if self.model.training else 'eval'
        for k, v in metrics.items():
            self.custom_metrics[mode][k].update(v)
900943

901944
@torch.no_grad()
902945
def _evalscope_eval(self):

0 commit comments

Comments
 (0)