Skip to content

Commit ae83cd4

Browse files
authored
Update generic_post_process.py
1 parent 5b49f88 commit ae83cd4

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

lightllm/server/router/model_infer/mode_backend/generic_post_process.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,8 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]):
64 64   
65 65   if is_all_greedy:
66 66   batch_next_token_ids = torch.argmax(logits, -1)
67    - batch_next_token_probs = torch.nn.functional.log_softmax(logits, dim=-1)
   67 + log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
   68 + batch_next_token_probs = torch.gather(log_probs, dim=1, index=batch_next_token_ids.view(-1, 1))
68 69   return batch_next_token_ids.view(-1), batch_next_token_probs.view(-1)
69 70   
70 71   elif get_env_start_args().sampling_backend == "triton":

0 commit comments

Comments
 (0)