diff --git a/tunix/generate/vllm_sampler.py b/tunix/generate/vllm_sampler.py index bab19b8bc..8f81bddb8 100644 --- a/tunix/generate/vllm_sampler.py +++ b/tunix/generate/vllm_sampler.py @@ -483,6 +483,10 @@ def __call__( decoded_outputs, out_logprobs, out_tokens = self.detokenize( input_strings, outputs ) + if self.config.return_logprobs and ( + out_logprobs is None or out_logprobs[0] is None + ): + raise ValueError("Logprobs are not returned from the vLLM.") max_tokens_length = max(len(x) for x in prompt_ids) @@ -505,5 +509,5 @@ def __call__( logits=None, tokens=out_tokens[0], padded_prompt_tokens=all_input_ids, - logprobs=out_logprobs[0] if len(out_logprobs[0]) else None, + logprobs=out_logprobs[0] if self.config.return_logprobs else None, )