From 862bb937c203eb07f45c0c682c26a5a1269fff5d Mon Sep 17 00:00:00 2001
From: elichen
Date: Thu, 2 Oct 2025 22:04:24 +0800
Subject: [PATCH] Wrap vLLM inputs to be compatible with vLLM >= 0.10.2

---
 pyproject.toml                          | 2 +-
 src/lighteval/models/vllm/vllm_model.py | 7 +++++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 411a7b898..4a07beda4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,7 +98,7 @@ nanotron = [
   "tensorboardX"
 ]
 tensorboardX = ["tensorboardX"]
-vllm = ["vllm>=0.10.0,<0.10.2", "ray", "more_itertools"]
+vllm = ["vllm>=0.10.0", "ray", "more_itertools"]
 sglang = ["sglang"]
 quality = ["ruff>=v0.11.0","pre-commit"]
 tests = ["pytest>=7.4.0","deepdiff","pip>=25.2"]
diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py
index 969caf8fa..06c8f39b9 100644
--- a/src/lighteval/models/vllm/vllm_model.py
+++ b/src/lighteval/models/vllm/vllm_model.py
@@ -30,6 +30,7 @@
 import torch
 from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
 from tqdm import tqdm
+from vllm.inputs.data import TokensPrompt
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
 from lighteval.models.abstract_model import LightevalModel, ModelConfig
@@ -415,6 +416,8 @@ def _generate(
         generate: bool = True,
     ) -> list:
         """Contains the actual logic of the generation."""
+        # Wrap inputs with TokensPrompt to be compatible with vLLM >= 0.10.2
+        inputs = [TokensPrompt(prompt_token_ids=token_ids) for token_ids in inputs]
         sampling_params = SamplingParams(**self.config.generation_parameters.to_vllm_dict())
 
         if generate:
@@ -437,7 +440,7 @@ def _generate(
             @ray.remote(num_gpus=self.tensor_parallel_size)
             def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, requests):
                 llm = LLM(**model_args)
-                return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)
+                return llm.generate(requests, sampling_params=sampling_params)
 
             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
@@ -455,7 +458,7 @@ def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, r
             ]
         else:
             outputs = self.model.generate(
-                prompt_token_ids=inputs,
+                inputs,
                 sampling_params=sampling_params,
                 use_tqdm=True,
            )
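
Note (not part of the patch): a minimal sketch of the calling pattern this change moves to. As the diff suggests, newer vLLM releases expect pre-tokenized prompts to be wrapped in TokensPrompt and passed positionally to LLM.generate, rather than via the old prompt_token_ids keyword. The model name and token IDs below are placeholders for illustration only.

    from vllm import LLM, SamplingParams
    from vllm.inputs.data import TokensPrompt

    # Placeholder model and token IDs, for illustration only.
    llm = LLM(model="facebook/opt-125m")
    sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

    # Wrap pre-tokenized prompts in TokensPrompt and pass them positionally,
    # replacing the removed `llm.generate(prompt_token_ids=...)` keyword form.
    token_id_batches = [[1, 2, 3, 4], [5, 6, 7]]
    prompts = [TokensPrompt(prompt_token_ids=ids) for ids in token_id_batches]

    outputs = llm.generate(prompts, sampling_params=sampling_params)
    for output in outputs:
        print(output.outputs[0].text)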