
Commit fd479ee

Add extended task for LiveCodeBench codegeneration (#548)
```shell
lighteval vllm \
    "pretrained=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B,dtype=float16,data_parallel_size=4,max_model_length=32768,gpu_memory_utilisation=0.8,generation_parameters={temperature: 0.7}" \
    "extended|lcb:codegeneration|0|0" \
    --use-chat-template
```
1 parent d6de1fe commit fd479ee

File tree

10 files changed: +976 −6 lines changed

docs/source/use-vllm-as-backend.mdx

Lines changed: 29 additions & 0 deletions
````diff
@@ -58,3 +58,32 @@ model: # Model specific parameters
 > [!WARNING]
 > In the case of OOM issues, you might need to reduce the context size of the
 > model as well as reduce the `gpu_memory_utilization` parameter.
+
+
+## Dynamically changing the metric configuration
+
+For special kinds of metrics like `Pass@K` or LiveCodeBench's `codegen` metric, you may need to pass specific values like the number of
+generations. This can be done in the `yaml` file in the following way:
+
+```yaml
+model: # Model specific parameters
+  base_params:
+    model_args: "pretrained=HuggingFaceTB/SmolLM-1.7B,revision=main,dtype=bfloat16" # Model args that you would pass in the command line
+  generation: # Generation specific parameters
+    temperature: 0.3
+    repetition_penalty: 1.0
+    frequency_penalty: 0.0
+    presence_penalty: 0.0
+    seed: 42
+    top_k: 0
+    min_p: 0.0
+    top_p: 0.9
+  metric_options: # Optional metric arguments
+    codegen_pass@1:16:
+      num_samples: 16
+```
+
+An optional key `metric_options` can be passed in the yaml file,
+using the name of the metric or metrics, as defined in the `Metric.metric_name`.
+In this case, the `codegen_pass@1:16` metric defined in our tasks will have the `num_samples` updated to 16,
+independently of the number defined by default.
````
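To see the plumbing behind this, here is a minimal sketch of how such a config file could be read to recover the optional `metric_options` block; it mirrors the change to `main_vllm.py` shown below, and the `config.yaml` path is only a placeholder.

```python
# Minimal sketch: load a model config yaml like the one above and extract the
# optional metric_options block. "config.yaml" is a placeholder path.
import yaml

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)["model"]

metric_options = config.get("metric_options", {})
print(metric_options)  # e.g. {'codegen_pass@1:16': {'num_samples': 16}}
```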

src/lighteval/main_vllm.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -134,9 +134,11 @@ def vllm(
         with open(model_args, "r") as f:
             config = yaml.safe_load(f)["model"]
         model_args = config["base_params"]["model_args"]
+        metric_options = config.get("metric_options", {})
         generation_parameters = GenerationParameters.from_dict(config)
     else:
-        generation_parameters = GenerationParameters()
+        generation_parameters = GenerationParameters.from_model_args(model_args)
+        metric_options = {}

     model_args_dict: dict = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in model_args.split(",")}
     model_config = VLLMModelConfig(**model_args_dict, generation_parameters=generation_parameters)
@@ -146,6 +148,7 @@ def vllm(
         pipeline_parameters=pipeline_params,
         evaluation_tracker=evaluation_tracker,
         model_config=model_config,
+        metric_options=metric_options,
     )

     pipeline.evaluate()
```

src/lighteval/models/model_input.py

Lines changed: 28 additions & 0 deletions
```diff
@@ -59,6 +59,34 @@ def from_dict(cls, config_dict: dict):
         """
         return GenerationParameters(**config_dict.get("generation", {}))

+    @classmethod
+    def from_model_args(cls, model_args: str):
+        """Creates a GenerationParameters object from a model_args string.
+
+        It's used when the model_args are passed as a string in the command line.
+        The generation parameters must follow the following format (at any place in the string):
+        "generation_parameters={key1:value1,key2:value2}"
+
+        Args:
+            model_args (str): A string like the following:
+                "pretrained=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B,dtype=float16,max_model_length=32768,generation_parameters={temperature:0.7,top_p:5}"
+        """
+
+        def parse_model_args(model_args):
+            import json
+            import re
+
+            pattern = re.compile(r"(\w+)=(\{.*\}|[^,]+)")
+            matches = pattern.findall(model_args)
+            for key, value in matches:
+                key = key.strip()
+                if key == "generation_parameters":
+                    gen_params = re.sub(r"(\w+):", r'"\1":', value)
+                    return json.loads(gen_params)
+
+        params: dict = parse_model_args(model_args) or {}
+        return GenerationParameters(**params)
+
     def to_litellm_dict(self) -> dict:
         """Selects relevant generation and sampling parameters for litellm models.
         Doc: https://docs.litellm.ai/docs/completion/input#input-params-1
```
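As an illustration of what the new parser accepts, here is a small standalone sketch of the same regex-based extraction; it reimplements the inner `parse_model_args` helper outside the class, and the example string is only for demonstration.

```python
# Standalone sketch of the generation_parameters extraction performed in
# GenerationParameters.from_model_args; the model_args string is illustrative.
import json
import re

model_args = (
    "pretrained=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B,dtype=float16,"
    "max_model_length=32768,generation_parameters={temperature:0.7,top_p:0.9}"
)

# key=value pairs are comma-separated; a {...} value is captured as a whole.
pattern = re.compile(r"(\w+)=(\{.*\}|[^,]+)")
gen_params = {}
for key, value in pattern.findall(model_args):
    if key.strip() == "generation_parameters":
        # Quote the bare keys so the brace block becomes valid JSON.
        gen_params = json.loads(re.sub(r"(\w+):", r'"\1":', value))

print(gen_params)  # {'temperature': 0.7, 'top_p': 0.9}
```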

src/lighteval/pipeline.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -132,6 +132,7 @@ def __init__(
         evaluation_tracker: EvaluationTracker,
         model_config=None,
         model=None,
+        metric_options=None,
     ):
         if not (model or model_config):
             raise ValueError("Must provide either a model or model config when creating a pipeline.")
@@ -145,6 +146,7 @@ def __init__(

         self.model_config = model_config
         self.evaluation_tracker = evaluation_tracker
+        self._metric_options = metric_options or {}
         self.accelerator, self.parallel_context = self._init_parallelism_manager()
         self.model = self._init_model(model_config, model)

@@ -209,6 +211,10 @@ def _init_tasks_and_requests(self, tasks: str):
         )
         task_names_list, fewshots_dict = taskinfo_selector(tasks, registry)
         task_dict = registry.get_task_dict(task_names_list)
+        # If there are metric_options defined from the yaml file,
+        # review if they have to be updated.
+        if self._metric_options:
+            self._update_num_samples(task_dict)
         LightevalTask.load_datasets(list(task_dict.values()), self.pipeline_parameters.dataset_loading_processes)

         self.evaluation_tracker.task_config_logger.log(task_dict)
@@ -230,6 +236,19 @@ def _init_tasks_and_requests(self, tasks: str):
         self.requests = requests
         self.docs = docs

+    def _update_num_samples(self, task_dict: dict[str, LightevalTask]):
+        """Helper function to update the num_samples of a given metric via the yaml file.
+        As it has to be done at the metric level, it's better to update the value per metric.
+        It will add a num_samples to the already defined metrics' num_samples if defined in the yaml file.
+        As later when constructing the requests the max is taken over the num_samples, this is valid.
+        """
+        for _, task in task_dict.items():
+            for metric in task.metrics:
+                if metric_data := self._metric_options.get(metric.metric_name, None):
+                    num_samples = metric_data.get("num_samples", None)
+                    if num_samples:
+                        task.num_samples = [num_samples]
+
     def _init_random_seeds(self):
         logger.info("--- INIT SEEDS ---")
         random.seed(1234)
```
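To make the override semantics concrete, here is an illustrative sketch with hypothetical stand-in classes (not lighteval's real `LightevalTask`/`Metric` types); it applies the same per-metric lookup that `_update_num_samples` performs.

```python
# Illustrative sketch of the per-metric num_samples override; FakeMetric and
# FakeTask are hypothetical stand-ins, not lighteval's real classes.
from dataclasses import dataclass, field


@dataclass
class FakeMetric:
    metric_name: str


@dataclass
class FakeTask:
    metrics: list
    num_samples: list = field(default_factory=lambda: [1])


metric_options = {"codegen_pass@1:16": {"num_samples": 16}}  # as parsed from the yaml
task_dict = {"extended|lcb:codegeneration": FakeTask(metrics=[FakeMetric("codegen_pass@1:16")])}

# Same lookup as in the diff above: match on metric_name, then override num_samples.
for task in task_dict.values():
    for metric in task.metrics:
        if metric_data := metric_options.get(metric.metric_name):
            if num_samples := metric_data.get("num_samples"):
                task.num_samples = [num_samples]

print(task_dict["extended|lcb:codegeneration"].num_samples)  # [16]
```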

src/lighteval/tasks/extended/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -25,12 +25,13 @@

 if can_load_extended_tasks():
     import lighteval.tasks.extended.ifeval.main as ifeval
+    import lighteval.tasks.extended.lcb.main as lcb
     import lighteval.tasks.extended.mix_eval.main as mix_eval
     import lighteval.tasks.extended.mt_bench.main as mt_bench
     import lighteval.tasks.extended.olympiade_bench.main as olympiad_bench
     import lighteval.tasks.extended.tiny_benchmarks.main as tiny_benchmarks

-    AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks, mt_bench, mix_eval, olympiad_bench]
+    AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks, mt_bench, mix_eval, olympiad_bench, lcb]

 else:
     AVAILABLE_EXTENDED_TASKS_MODULES = []
```
