Skip to content

Commit 7dc051a

Browse files
committed
update patch
1 parent b4f0977 commit 7dc051a

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

trinity/common/models/vllm_patch/worker_patch.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ def _get_prompt_logprobs_dict(
2323
hidden_states: torch.Tensor,
2424
num_scheduled_tokens: dict[str, int],
2525
) -> dict[str, Optional[LogprobsTensors]]:
26+
"""Patched version of _get_prompt_logprobs_dict.
27+
28+
This is a monkey-patched version of `_get_prompt_logprobs_dict` from
29+
`vllm.v1.worker.gpu_model_runner.GPUModelRunner` (vLLM versions 0.10.0 to 0.11.0).
30+
31+
The original function does not apply temperature scaling to logits when
32+
calculating prompt logprobs, which can lead to incorrect logprob values
33+
when the temperature is not 1.0. This patch adds the missing
34+
temperature scaling.
35+
"""
2636
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
2737
if not num_prompt_logprobs_dict:
2838
return {}
@@ -89,7 +99,7 @@ def _get_prompt_logprobs_dict(
8999

90100
# PATCH START
91101
temp = request.sampling_params.temperature
92-
if temp is None or temp >= 1e-5:
102+
if temp >= 1e-5:
93103
logits.div_(temp)
94104
# PATCH END
95105

0 commit comments

Comments
 (0)