|
 from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer, safe_snapshot_download, to_device
 from swift.plugin import Metric
 from swift.tuners import Swift
+from swift.utils import get_last_valid_indices
 from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
                         ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
                         EmbeddingResponseData, RequestConfig, random_uuid)
@@ -349,7 +350,13 @@ def _infer_forward(self, template: Template, inputs: Dict[str, Any], adapter_req |
         negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
         token_false_id = template.tokenizer.convert_tokens_to_ids(negative_token)
         token_true_id = template.tokenizer.convert_tokens_to_ids(positive_token)
-        batch_scores = logits[:, -1, :]
+        attention_mask = inputs.get('attention_mask')
+        if attention_mask is None:
+            batch_scores = logits[:, -1, :]
+        else:
+            last_valid_indices = get_last_valid_indices(attention_mask)
+            batch_indices = torch.arange(attention_mask.shape[0], device=logits.device)
+            batch_scores = logits[batch_indices, last_valid_indices, :]
         true_vector = batch_scores[:, token_true_id]
         false_vector = batch_scores[:, token_false_id]
         batch_scores = torch.stack([false_vector, true_vector], dim=1).float()
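
Context for the hunk above: in a right-padded batch, `logits[:, -1, :]` reads the logits of padding tokens for every row shorter than the batch maximum, so the reranker's yes/no scores must instead be gathered from each row's last non-padded position. Below is a minimal, self-contained sketch of that gather. The local `get_last_valid_indices` is an assumed stand-in for the `swift.utils` helper the patch imports; its real implementation is not shown in this hunk.

```python
import torch

# Assumed behavior of the helper: return, for each row of a 0/1
# (batch, seq_len) attention mask, the index of its last 1-valued entry.
# This sum-based version is correct for contiguous, right-padded masks.
def get_last_valid_indices(attention_mask: torch.Tensor) -> torch.Tensor:
    return attention_mask.sum(dim=1) - 1

# Toy batch: two sequences of length 5, the second right-padded by two tokens.
logits = torch.randn(2, 5, 10)  # (batch, seq_len, vocab)
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

last_valid_indices = get_last_valid_indices(attention_mask)  # tensor([4, 2])
batch_indices = torch.arange(attention_mask.shape[0], device=logits.device)
batch_scores = logits[batch_indices, last_valid_indices, :]  # (batch, vocab)

# The pre-patch logits[:, -1, :] would have scored row 1 at position 4
# (a padding token) instead of its last real token at position 2.
```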
|