Commit d4e6f59

Fix loading of vllm model from files (#533)
* commit
* commit
* Update src/lighteval/main_vllm.py
* commit
* change doc
* change doc
* change doc
1 parent 86f6225 commit d4e6f59

File tree

5 files changed: +62 −19 lines changed

- docs/source/use-vllm-as-backend.mdx
- examples/model_configs/vllm_model_config.yaml
- src/lighteval/main_vllm.py
- src/lighteval/models/model_input.py
- src/lighteval/models/vllm/vllm_model.py

docs/source/use-vllm-as-backend.mdx

Lines changed: 25 additions & 14 deletions
@@ -29,20 +29,31 @@ lighteval vllm \
     "leaderboard|truthfulqa:mc|0|0"
 ```
 
-Available arguments for `vllm` can be found in the `VLLMModelConfig`:
-
-- **pretrained** (str): HuggingFace Hub model ID name or the path to a pre-trained model to load.
-- **gpu_memory_utilisation** (float): The fraction of GPU memory to use.
-- **revision** (str): The revision of the model.
-- **dtype** (str, None): The data type to use for the model.
-- **tensor_parallel_size** (int): The number of tensor parallel units to use.
-- **data_parallel_size** (int): The number of data parallel units to use.
-- **max_model_length** (int): The maximum length of the model.
-- **swap_space** (int): The CPU swap space size (GiB) per GPU.
-- **seed** (int): The seed to use for the model.
-- **trust_remote_code** (bool): Whether to trust remote code during model loading.
-- **add_special_tokens** (bool): Whether to add special tokens to the input sequences.
-- **multichoice_continuations_start_space** (bool): Whether to add a space at the start of each continuation in multichoice generation.
+## Use a config file
+
+For more advanced configurations, you can use a config file for the model.
+An example of a config file is shown below and can be found at `examples/model_configs/vllm_model_config.yaml`.
+
+```bash
+lighteval vllm \
+    "examples/model_configs/vllm_model_config.yaml" \
+    "leaderboard|truthfulqa:mc|0|0"
+```
+
+```yaml
+model: # Model specific parameters
+  base_params:
+    model_args: "pretrained=HuggingFaceTB/SmolLM-1.7B,revision=main,dtype=bfloat16" # Model args that you would pass in the command line
+  generation: # Generation specific parameters
+    temperature: 0.3
+    repetition_penalty: 1.0
+    frequency_penalty: 0.0
+    presence_penalty: 0.0
+    seed: 42
+    top_k: 0
+    min_p: 0.0
+    top_p: 0.9
+```
 
 > [!WARNING]
 > In the case of OOM issues, you might need to reduce the context size of the
examples/model_configs/vllm_model_config.yaml

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+model:
+  base_params:
+    model_args: "pretrained=HuggingFaceTB/SmolLM-1.7B,revision=main,dtype=bfloat16" # pretrained=model_name,trust_remote_code=boolean,revision=revision_to_use,model_parallel=True ...
+  generation:
+    temperature: 0.3
+    repetition_penalty: 1.0
+    frequency_penalty: 0.0
+    presence_penalty: 0.0
+    seed: 42
+    top_k: -1
+    min_p: 0.0
+    top_p: 0.9
+    max_new_tokens: 100
+    stop_tokens: ["<EOS>", "<PAD>"]
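For reference, a minimal sketch of how this example file is consumed (it mirrors the updated `main_vllm.py` below; the path and dictionary keys come from the config above):

```python
import yaml

# Load the example config; the top-level "model" key holds everything lighteval reads.
with open("examples/model_configs/vllm_model_config.yaml", "r") as f:
    config = yaml.safe_load(f)["model"]

model_args = config["base_params"]["model_args"]  # "pretrained=HuggingFaceTB/SmolLM-1.7B,revision=main,dtype=bfloat16"
generation = config.get("generation", {})         # temperature, seed, max_new_tokens, stop_tokens, ...
```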

src/lighteval/main_vllm.py

Lines changed: 5 additions & 4 deletions
@@ -133,12 +133,13 @@ def vllm(
     if model_args.endswith(".yaml"):
         with open(model_args, "r") as f:
             config = yaml.safe_load(f)["model"]
+        model_args = config["base_params"]["model_args"]
         generation_parameters = GenerationParameters.from_dict(config)
-        model_config = VLLMModelConfig(config, generation_parameters=generation_parameters)
-
     else:
-        model_args_dict: dict = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in model_args.split(",")}
-        model_config = VLLMModelConfig(**model_args_dict)
+        generation_parameters = GenerationParameters()
+
+    model_args_dict: dict = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in model_args.split(",")}
+    model_config = VLLMModelConfig(**model_args_dict, generation_parameters=generation_parameters)
 
     pipeline = Pipeline(
         tasks=tasks,
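The `model_args` parsing now runs on both code paths, so the YAML route only swaps in the `model_args` string from the config before it is split. A rough illustration of what the comprehension produces (the `trust_remote_code` flag is added here only to show the bare-flag case):

```python
# Same comprehension as in main_vllm.py: "k=v" pairs become dict entries,
# and a bare flag without "=" becomes True.
model_args = "pretrained=HuggingFaceTB/SmolLM-1.7B,revision=main,dtype=bfloat16,trust_remote_code"
model_args_dict = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in model_args.split(",")}
# {'pretrained': 'HuggingFaceTB/SmolLM-1.7B', 'revision': 'main', 'dtype': 'bfloat16', 'trust_remote_code': True}
```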

src/lighteval/models/model_input.py

Lines changed: 17 additions & 0 deletions
@@ -59,6 +59,23 @@ def from_dict(cls, config_dict: dict):
         """
         return GenerationParameters(**config_dict.get("generation", {}))
 
+    def to_vllm_dict(self) -> dict:
+        """Selects relevant generation and sampling parameters for vllm models.
+        Doc: https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html
+
+        Returns:
+            dict: The parameters to create a vllm.SamplingParams in the model config.
+        """
+        sampling_params_to_vllm_naming = {
+            "max_new_tokens": "max_tokens",
+            "min_new_tokens": "min_tokens",
+            "stop_tokens": "stop",
+        }
+
+        # Task specific sampling params to set in model: n, best_of, use_beam_search
+        # Generation specific params to set in model: logprobs, prompt_logprobs
+        return {sampling_params_to_vllm_naming.get(k, k): v for k, v in asdict(self).items() if v is not None}
+
     def to_vllm_openai_dict(self) -> dict:
         """Selects relevant generation and sampling parameters for vllm and openai models.
         Doc: https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html
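Applied to the generation block of the example YAML, the new mapping renames lighteval's field names to the ones `vllm.SamplingParams` expects and drops values that were left unset. A rough sketch, with a plain dict standing in for the `GenerationParameters` dataclass:

```python
# Illustration only: this dict stands in for asdict(self) on GenerationParameters.
generation = {"temperature": 0.3, "seed": 42, "top_p": 0.9, "max_new_tokens": 100,
              "stop_tokens": ["<EOS>", "<PAD>"], "min_new_tokens": None}

rename = {"max_new_tokens": "max_tokens", "min_new_tokens": "min_tokens", "stop_tokens": "stop"}
vllm_kwargs = {rename.get(k, k): v for k, v in generation.items() if v is not None}
# {'temperature': 0.3, 'seed': 42, 'top_p': 0.9, 'max_tokens': 100, 'stop': ['<EOS>', '<PAD>']}
```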

src/lighteval/models/vllm/vllm_model.py

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ def __init__(
         self.precision = _get_dtype(config.dtype, config=self._config)
 
         self.model_info = ModelInfo(model_name=self.model_name, model_sha=self.model_sha)
-        self.sampling_params = SamplingParams(**config.generation_parameters.to_vllm_openai_dict())
+        self.sampling_params = SamplingParams(**config.generation_parameters.to_vllm_dict())
         self.pairwise_tokenization = config.pairwise_tokenization
 
     @property
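With the vllm-specific mapping, the expanded keyword names line up with `vllm.SamplingParams` directly. Roughly equivalent, using the values from the example config (a sketch, not the library code):

```python
from vllm import SamplingParams

# Keys already match SamplingParams arguments after to_vllm_dict()'s renaming.
sampling_params = SamplingParams(temperature=0.3, seed=42, top_p=0.9,
                                 max_tokens=100, stop=["<EOS>", "<PAD>"])
```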
