pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -98,7 +98,7 @@ nanotron = [
   "tensorboardX"
 ]
 tensorboardX = ["tensorboardX"]
-vllm = ["vllm>=0.10.0,<0.10.2", "ray", "more_itertools"]
+vllm = ["vllm>=0.10.0", "ray", "more_itertools"]
 sglang = ["sglang"]
 quality = ["ruff>=v0.11.0","pre-commit"]
 tests = ["pytest>=7.4.0","deepdiff","pip>=25.2"]
src/lighteval/models/vllm/vllm_model.py (7 changes: 5 additions & 2 deletions)
@@ -30,6 +30,7 @@
 import torch
 from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
 from tqdm import tqdm
+from vllm.inputs.data import TokensPrompt

 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
 from lighteval.models.abstract_model import LightevalModel, ModelConfig
@@ -415,6 +416,8 @@ def _generate(
         generate: bool = True,
     ) -> list:
         """Contains the actual logic of the generation."""
+        # Wrap inputs with TokensPrompt to make compatible with VLLM >= 0.10.2
+        inputs = [TokensPrompt(prompt_token_ids=token_ids) for token_ids in inputs]
         sampling_params = SamplingParams(**self.config.generation_parameters.to_vllm_dict())

         if generate:
@@ -437,7 +440,7 @@
             @ray.remote(num_gpus=self.tensor_parallel_size)
             def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, requests):
                 llm = LLM(**model_args)
-                return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)
+                return llm.generate(requests, sampling_params=sampling_params)

             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
@@ -455,7 +458,7 @@ def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, requests):
             ]
         else:
             outputs = self.model.generate(
-                prompt_token_ids=inputs,
+                inputs,
                 sampling_params=sampling_params,
                 use_tqdm=True,
             )
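
For reference, here is a minimal standalone sketch (not part of the diff) of the calling pattern this change adopts: pre-tokenized prompts are wrapped in TokensPrompt objects and passed positionally to LLM.generate, which the PR comment indicates is required on vLLM >= 0.10.2, where the old prompt_token_ids= keyword is no longer accepted. The model name and token ids below are placeholders, not values from the PR.

from vllm import LLM, SamplingParams
from vllm.inputs.data import TokensPrompt

# Placeholder pre-tokenized prompts; in lighteval these come from the tokenized requests.
token_ids_batch = [[1, 2, 3, 4], [1, 5, 6]]

# Wrap each list of token ids so LLM.generate accepts it as a prompt input
# (mirrors the wrapping added at the top of _generate above).
prompts = [TokensPrompt(prompt_token_ids=ids) for ids in token_ids_batch]

llm = LLM(model="HuggingFaceTB/SmolLM-135M")  # placeholder model
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

# Passing the prompts positionally replaces the removed prompt_token_ids= keyword.
outputs = llm.generate(prompts, sampling_params=sampling_params)
for output in outputs:
    print(output.outputs[0].text)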