
Commit bc0b8bc

thomwolf, clefourrier, and NathanHB authored
Last PR to make custom tasks work for everyone (#23)
Small one to be in the release. Can iterate later if needed.

---------

Co-authored-by: [email protected] <[email protected]>
Co-authored-by: Clémentine Fourrier <[email protected]>
Co-authored-by: Nathan Habib <[email protected]>
1 parent f8bd2ab commit bc0b8bc

File tree

11 files changed: +177 -275 lines changed


run_evals_accelerate.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def get_parser():
     parser.add_argument("--override_batch_size", type=int, default=-1)
     parser.add_argument("--dataset_loading_processes", type=int, default=1)
     parser.add_argument(
-        "--custom_tasks_file",
+        "--custom_tasks",
         type=str,
         default=None,
         help="Path to a file with custom tasks (a TASK list of dict and potentially prompt formating functions)",

run_evals_nanotron.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ def get_parser():
     parser.add_argument(
         "--cache-dir",
         type=str,
-        default="",
+        default=None,
         help="Cache directory",
     )

src/lighteval/logging/info_loggers.py

Lines changed: 8 additions & 4 deletions
@@ -69,8 +69,12 @@ class GeneralConfigLogger:
 
     def __init__(self) -> None:
         """Stores the current lighteval commit for reproducibility, and starts the evaluation timer."""
-        repo = git.Repo(os.path.dirname(__file__).split("src")[0])
-        self.lighteval_sha = repo.git.rev_parse("HEAD")
+        try:
+            repo = git.Repo(os.path.dirname(__file__).split("src")[0])
+        except git.InvalidGitRepositoryError:
+            repo = None
+
+        self.lighteval_sha = repo.git.rev_parse("HEAD") if repo is not None else "?"
         self.start_time = time.perf_counter()
 
     def log_args_info(
@@ -543,5 +547,5 @@ def log(self, task_dict: dict[str, LightevalTask]) -> None:
         self.tasks_configs = {name: task.cfg for name, task in task_dict.items()}
 
     def log_num_docs(self, task_name: str, original_num_docs: int, effective_num_docs: int) -> None:
-        self.tasks_configs[task_name]["original_num_docs"] = original_num_docs
-        self.tasks_configs[task_name]["effective_num_docs"] = effective_num_docs
+        self.tasks_configs[task_name].original_num_docs = original_num_docs
+        self.tasks_configs[task_name].effective_num_docs = effective_num_docs
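The try/except above exists because `git.Repo` raises `InvalidGitRepositoryError` when lighteval is installed as a package rather than run from a git checkout. A standalone sketch of the same fallback (the helper name is illustrative, not part of the commit):

```python
# Sketch of the fallback introduced above: record "?" instead of crashing when
# no git checkout surrounds the installed package. Requires GitPython.
import os
import git

def resolve_lighteval_sha(module_file: str) -> str:
    try:
        repo = git.Repo(os.path.dirname(module_file).split("src")[0])
    except git.InvalidGitRepositoryError:
        return "?"
    return repo.git.rev_parse("HEAD")

print(resolve_lighteval_sha(__file__))
```

The same fallback is reflected in the logged results: `lighteval_sha` becomes `"?"` rather than a commit hash when the package is not in a repository.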

src/lighteval/main_accelerate.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def main(args):
     with accelerator.main_process_first() if accelerator is not None else nullcontext():
         task_names_list, few_shots_dict = taskinfo_selector(args.tasks)
         task_dict = Registry(cache_dir=env_config.cache_dir).get_task_dict(
-            task_names_list, custom_tasks_file=args.custom_tasks_file
+            task_names_list, custom_tasks=args.custom_tasks
         )
         # Loading all the dataset in a distributed manner
         LightevalTask.load_datasets(task_dict.values(), args.dataset_loading_processes)

src/lighteval/main_nanotron.py

Lines changed: 4 additions & 4 deletions
@@ -38,7 +38,7 @@
 def main(
     checkpoint_config_path: str,
     lighteval_config_path: Optional[str] = None,
-    cache_dir: str = None,
+    cache_dir: Optional[str] = None,
     config_cls: Type = Config,
     model_config_cls: Optional[Type] = None,
     model_cls: Optional[Type] = None,
@@ -109,14 +109,14 @@ def main(
     with htrack_block("Tasks loading"):
         with local_ranks_zero_first():
             tasks_selection = lighteval_config.tasks.tasks
-            if lighteval_config.tasks.custom_tasks_file:
-                _, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks_file)
+            if lighteval_config.tasks.custom_tasks:
+                _, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks)
                 if tasks_groups_dict and lighteval_config.tasks.tasks in tasks_groups_dict:
                     tasks_selection = tasks_groups_dict[lighteval_config.tasks.tasks]
 
             task_names_list, few_shots_dict = taskinfo_selector(tasks_selection)
             task_dict = Registry(cache_dir=cache_dir).get_task_dict(
-                task_names_list, custom_tasks_file=lighteval_config.tasks.custom_tasks_file
+                task_names_list, custom_tasks=lighteval_config.tasks.custom_tasks
             )
             # Loading all the dataset in a distributed manner
             LightevalTask.load_datasets(task_dict.values(), lighteval_config.tasks.dataset_loading_processes)

src/lighteval/models/base_model.py

Lines changed: 2 additions & 4 deletions
@@ -22,9 +22,7 @@
     LoglikelihoodSingleTokenRequest,
     Request,
 )
-from lighteval.utils import (
-    is_accelerate_available,
-)
+from lighteval.utils import as_list, is_accelerate_available
 from lighteval.utils_parallelism import find_executable_batch_size
 
 
@@ -342,7 +340,7 @@ def greedy_until(
             list[GenerateReturn]: list of generated responses.
         """
         for request in requests:
-            request.stop_sequence = request.stop_sequence + [self.tokenizer.eos_token]
+            request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
            request.tokenized_context = self.tok_encode(request.context)
 
         dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
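Wrapping the request's `stop_sequence` in `as_list` lets `greedy_until` accept a single string or a tuple (as stored by the new `LightevalTaskConfig` below) rather than only a list. A hedged sketch of the normalization that `lighteval.utils.as_list` is assumed to perform (not copied from this commit):

```python
# Assumed behavior of lighteval.utils.as_list: wrap scalars and convert tuples
# so that list concatenation with the EOS token always works.
from typing import Any, List

def as_list(item: Any) -> List[Any]:
    if isinstance(item, list):
        return item
    if isinstance(item, tuple):
        return list(item)
    return [item]

print(as_list(("\n",)) + ["</s>"])  # ['\n', '</s>']
print(as_list("\n") + ["</s>"])     # ['\n', '</s>']
```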

src/lighteval/tasks/lighteval_task.py

Lines changed: 76 additions & 21 deletions
@@ -1,8 +1,9 @@
 import collections
 import random
+from dataclasses import dataclass
 from multiprocessing import Pool
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 from datasets import load_dataset
 
@@ -39,8 +40,62 @@
     from lighteval.logging.evaluation_tracker import EvaluationTracker
 
 
+@dataclass
+class LightevalTaskConfig:
+    name: str
+    prompt_function: str
+    hf_repo: str
+    hf_subset: str
+    metric: Tuple[Union[str, Metrics]]
+    hf_avail_splits: Optional[Tuple[str]] = None
+    evaluation_splits: Optional[Tuple[str]] = None
+    few_shots_split: Optional[str] = None
+    few_shots_select: Optional[str] = None
+    generation_size: int = -1
+    stop_sequence: Optional[Tuple[str]] = None
+    output_regex: Optional[str] = None
+
+    frozen: bool = False
+    suite: Optional[Tuple[str]] = None  # we use this to know if we should use a custom lighteval or bigcode task
+
+    def as_dict(self):
+        return {
+            "name": self.name,
+            "prompt_function": self.prompt_function,
+            "hf_repo": self.hf_repo,
+            "hf_subset": self.hf_subset,
+            "metric": tuple(str(m) for m in self.metric),
+            "hf_avail_splits": self.hf_avail_splits,
+            "evaluation_splits": self.evaluation_splits,
+            "few_shots_split": self.few_shots_split,
+            "few_shots_select": self.few_shots_select,
+            "generation_size": self.generation_size,
+            "stop_sequence": self.stop_sequence,
+            "output_regex": self.output_regex,
+            "frozen": self.frozen,
+            "suite": self.suite,
+        }
+
+    def __post_init__(self):
+        if self.suite is None:
+            self.suite = ["custom"]
+        if self.hf_avail_splits is None:
+            self.hf_avail_splits = ["train", "validation", "test"]
+        if self.evaluation_splits is None:
+            self.evaluation_splits = ["validation"]
+        if self.stop_sequence is None:
+            self.stop_sequence = ["\n"]
+
+        # Convert list to tuple for hashing
+        self.metric = tuple(self.metric)
+        self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits is not None else None
+        self.evaluation_splits = tuple(self.evaluation_splits) if self.evaluation_splits is not None else None
+        self.suite = tuple(self.suite) if self.suite is not None else None
+        self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence is not None else None
+
+
 class LightevalTask:
-    def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom_tasks_module=None):
+    def __init__(self, name: str, cfg: LightevalTaskConfig, cache_dir: Optional[str] = None, custom_tasks_module=None):
         """
         Initialize a LightEval task.
 
@@ -60,8 +115,8 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
         self._cfg = cfg
 
         # Dataset info
-        self.hf_repo = cfg["hf_repo"]
-        self.hf_subset = cfg["hf_subset"]
+        self.hf_repo = cfg.hf_repo
+        self.hf_subset = cfg.hf_subset
         self.dataset_path = self.hf_repo
         self.dataset_config_name = self.hf_subset
         self.dataset = None  # Delayed download
@@ -70,22 +125,22 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
         self._docs = None
 
         # Managing splits and few shot
-        self.all_available_splits = as_list(cfg["hf_avail_splits"])
-        if cfg.get("evaluation_splits", None) is None:
+        self.all_available_splits = as_list(cfg.hf_avail_splits)
+        if cfg.evaluation_splits is None:
             raise ValueError(f"The evaluation split for task {self.name} is None. Please select a valid split.")
 
-        self.evaluation_split = as_list(cfg["evaluation_splits"])
-        if cfg.get("few_shots_split", None) is not None:
-            self.fewshot_split = as_list(cfg["few_shots_split"])
+        self.evaluation_split = as_list(cfg.evaluation_splits)
+        if cfg.few_shots_split is not None:
+            self.fewshot_split = as_list(cfg.few_shots_split)
         else:
            self.fewshot_split = as_list(self.get_first_possible_fewshot_splits())
         self.fewshot_sampler = FewShotSampler(
-            few_shots_select=cfg["few_shots_select"], few_shots_split=self.fewshot_split
+            few_shots_select=cfg.few_shots_select, few_shots_split=self.fewshot_split
         )
 
         # Metrics
-        self.metrics = as_list(cfg["metric"])
-        self.suite = as_list(cfg["suite"])
+        self.metrics = as_list(cfg.metric)
+        self.suite = as_list(cfg.suite)
         ignored = [metric for metric in self.metrics if Metrics[metric].value.category == MetricCategory.IGNORED]
         if len(ignored) > 0:
             hlog_warn(f"[WARNING] Not implemented yet: ignoring the metric {' ,'.join(ignored)} for task {self.name}.")
@@ -95,20 +150,20 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
         # Data processing
         # to use once prompt formatting is managed as a module
         if custom_tasks_module is None:
-            self.formatter = getattr(tasks_prompt_formatting, cfg["prompt_function"])
-        elif hasattr(custom_tasks_module, cfg["prompt_function"]):
+            self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function)
+        elif hasattr(custom_tasks_module, cfg.prompt_function):
             # If we have a prompt in both the custom_tasks_module and our tasks_prompt_formatting
             # We take the prompt from the custom_tasks_module
-            if hasattr(tasks_prompt_formatting, cfg["prompt_function"]):
+            if hasattr(tasks_prompt_formatting, cfg.prompt_function):
                 hlog_warn(
-                    f"Be careful you are using custom prompt function {cfg['prompt_function']} and not the default one."
+                    f"Be careful you are using custom prompt function {cfg.prompt_function} and not the default one."
                 )
-            self.formatter = getattr(custom_tasks_module, cfg["prompt_function"])
+            self.formatter = getattr(custom_tasks_module, cfg.prompt_function)
         else:
-            self.formatter = getattr(tasks_prompt_formatting, cfg["prompt_function"])
-        self.generation_size = cfg["generation_size"]
-        self.stop_sequence = cfg["stop_sequence"]
-        self.output_regex = cfg["output_regex"]
+            self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function)
+        self.generation_size = cfg.generation_size
+        self.stop_sequence = cfg.stop_sequence
+        self.output_regex = cfg.output_regex
 
         # Save options
         self.save_queries: bool = False
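Taken together, these changes let a custom tasks file define its configurations as `LightevalTaskConfig` instances instead of raw dicts. A hedged sketch of what such a file might look like; only `LightevalTaskConfig` and its fields come from this diff, while the module-level list name, the dataset, the metric name, and the prompt function are illustrative:

```python
# Illustrative custom tasks module, passed via --custom_tasks.
# LightevalTaskConfig and its fields are from this commit; everything else
# (TASKS_TABLE, the dataset, the metric name, the prompt body) is assumed.
from lighteval.tasks.lighteval_task import LightevalTaskConfig

def my_prompt_fn(line, task_name: str = None):
    # Convert one dataset row into the Doc format expected by the evaluator.
    ...

my_task = LightevalTaskConfig(
    name="my_custom_task",
    prompt_function="my_prompt_fn",   # resolved by name on this module via getattr
    hf_repo="some_org/some_dataset",
    hf_subset="default",
    metric=("exact_match",),
    evaluation_splits=("test",),
    generation_size=32,
    stop_sequence=("\n",),
    # suite defaults to ("custom",) and hf_avail_splits to
    # ("train", "validation", "test") in __post_init__, so they can be omitted.
)

TASKS_TABLE = [my_task]
```

Because `task.cfg` is now a dataclass instance rather than a dict, the logger change in info_loggers.py above (attribute assignment instead of item assignment in `log_num_docs`) follows directly.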
