File tree Expand file tree Collapse file tree 1 file changed +11
-1
lines changed
trinity/common/models/vllm_patch Expand file tree Collapse file tree 1 file changed +11
-1
lines changed Original file line number Diff line number Diff line change @@ -23,6 +23,16 @@ def _get_prompt_logprobs_dict(
2323 hidden_states : torch .Tensor ,
2424 num_scheduled_tokens : dict [str , int ],
2525 ) -> dict [str , Optional [LogprobsTensors ]]:
26+ """Patched version of _get_prompt_logprobs_dict.
27+
28+ This is a monkey-patched version of `_get_prompt_logprobs_dict` from
29+ `vllm.v1.worker.gpu_model_runner.GPUModelRunner` (vLLM versions 0.10.0 to 0.11.0).
30+
31+ The original function does not apply temperature scaling to logits when
32+ calculating prompt logprobs, which can lead to incorrect logprob values
33+ when the temperature is not 1.0. This patch adds the missing
34+ temperature scaling.
35+ """
2636 num_prompt_logprobs_dict = self .input_batch .num_prompt_logprobs
2737 if not num_prompt_logprobs_dict :
2838 return {}
@@ -89,7 +99,7 @@ def _get_prompt_logprobs_dict(
8999
90100 # PATCH START
91101 temp = request .sampling_params .temperature
92- if temp is None or temp >= 1e-5 :
102+ if temp >= 1e-5 :
93103 logits .div_ (temp )
94104 # PATCH END
95105
You can’t perform that action at this time.
0 commit comments