import collections
import random
+ from dataclasses import dataclass
from multiprocessing import Pool
from pathlib import Path
- from typing import TYPE_CHECKING, List, Optional, Tuple
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from datasets import load_dataset

from lighteval.logging.evaluation_tracker import EvaluationTracker


+ @dataclass
+ class LightevalTaskConfig:
+     name: str
+     prompt_function: str
+     hf_repo: str
+     hf_subset: str
+     metric: Tuple[Union[str, Metrics]]
+     hf_avail_splits: Optional[Tuple[str]] = None
+     evaluation_splits: Optional[Tuple[str]] = None
+     few_shots_split: Optional[str] = None
+     few_shots_select: Optional[str] = None
+     generation_size: int = -1
+     stop_sequence: Optional[Tuple[str]] = None
+     output_regex: Optional[str] = None
+
+     frozen: bool = False
+     suite: Optional[Tuple[str]] = None  # we use this to know if we should use a custom lighteval or bigcode task
+
+     def as_dict(self):
+         return {
+             "name": self.name,
+             "prompt_function": self.prompt_function,
+             "hf_repo": self.hf_repo,
+             "hf_subset": self.hf_subset,
+             "metric": tuple(str(m) for m in self.metric),
+             "hf_avail_splits": self.hf_avail_splits,
+             "evaluation_splits": self.evaluation_splits,
+             "few_shots_split": self.few_shots_split,
+             "few_shots_select": self.few_shots_select,
+             "generation_size": self.generation_size,
+             "stop_sequence": self.stop_sequence,
+             "output_regex": self.output_regex,
+             "frozen": self.frozen,
+             "suite": self.suite,
+         }
+
+     def __post_init__(self):
+         if self.suite is None:
+             self.suite = ["custom"]
+         if self.hf_avail_splits is None:
+             self.hf_avail_splits = ["train", "validation", "test"]
+         if self.evaluation_splits is None:
+             self.evaluation_splits = ["validation"]
+         if self.stop_sequence is None:
+             self.stop_sequence = ["\n"]
+
+         # Convert list to tuple for hashing
+         self.metric = tuple(self.metric)
+         self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits is not None else None
+         self.evaluation_splits = tuple(self.evaluation_splits) if self.evaluation_splits is not None else None
+         self.suite = tuple(self.suite) if self.suite is not None else None
+         self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence is not None else None
+
class LightevalTask:
-     def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom_tasks_module=None):
+     def __init__(self, name: str, cfg: LightevalTaskConfig, cache_dir: Optional[str] = None, custom_tasks_module=None):
        """
        Initialize a LightEval task.

@@ -60,8 +115,8 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
        self._cfg = cfg

        # Dataset info
-         self.hf_repo = cfg["hf_repo"]
-         self.hf_subset = cfg["hf_subset"]
+         self.hf_repo = cfg.hf_repo
+         self.hf_subset = cfg.hf_subset
        self.dataset_path = self.hf_repo
        self.dataset_config_name = self.hf_subset
        self.dataset = None  # Delayed download
@@ -70,22 +125,22 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
        self._docs = None

        # Managing splits and few shot
-         self.all_available_splits = as_list(cfg["hf_avail_splits"])
-         if cfg.get("evaluation_splits", None) is None:
+         self.all_available_splits = as_list(cfg.hf_avail_splits)
+         if cfg.evaluation_splits is None:
            raise ValueError(f"The evaluation split for task {self.name} is None. Please select a valid split.")

-         self.evaluation_split = as_list(cfg["evaluation_splits"])
-         if cfg.get("few_shots_split", None) is not None:
-             self.fewshot_split = as_list(cfg["few_shots_split"])
+         self.evaluation_split = as_list(cfg.evaluation_splits)
+         if cfg.few_shots_split is not None:
+             self.fewshot_split = as_list(cfg.few_shots_split)
        else:
            self.fewshot_split = as_list(self.get_first_possible_fewshot_splits())
        self.fewshot_sampler = FewShotSampler(
-             few_shots_select=cfg["few_shots_select"], few_shots_split=self.fewshot_split
+             few_shots_select=cfg.few_shots_select, few_shots_split=self.fewshot_split
        )

        # Metrics
-         self.metrics = as_list(cfg["metric"])
-         self.suite = as_list(cfg["suite"])
+         self.metrics = as_list(cfg.metric)
+         self.suite = as_list(cfg.suite)
        ignored = [metric for metric in self.metrics if Metrics[metric].value.category == MetricCategory.IGNORED]
        if len(ignored) > 0:
            hlog_warn(f"[WARNING] Not implemented yet: ignoring the metric {','.join(ignored)} for task {self.name}.")
@@ -95,20 +150,20 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
        # Data processing
        # to use once prompt formatting is managed as a module
        if custom_tasks_module is None:
-             self.formatter = getattr(tasks_prompt_formatting, cfg["prompt_function"])
-         elif hasattr(custom_tasks_module, cfg["prompt_function"]):
+             self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function)
+         elif hasattr(custom_tasks_module, cfg.prompt_function):
            # If we have a prompt in both the custom_tasks_module and our tasks_prompt_formatting
            # We take the prompt from the custom_tasks_module
-             if hasattr(tasks_prompt_formatting, cfg["prompt_function"]):
+             if hasattr(tasks_prompt_formatting, cfg.prompt_function):
                hlog_warn(
-                     f"Be careful you are using custom prompt function {cfg['prompt_function']} and not the default one."
+                     f"Be careful you are using custom prompt function {cfg.prompt_function} and not the default one."
                )
-             self.formatter = getattr(custom_tasks_module, cfg["prompt_function"])
+             self.formatter = getattr(custom_tasks_module, cfg.prompt_function)
        else:
-             self.formatter = getattr(tasks_prompt_formatting, cfg["prompt_function"])
-         self.generation_size = cfg["generation_size"]
-         self.stop_sequence = cfg["stop_sequence"]
-         self.output_regex = cfg["output_regex"]
+             self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function)
+         self.generation_size = cfg.generation_size
+         self.stop_sequence = cfg.stop_sequence
+         self.output_regex = cfg.output_regex

        # Save options
        self.save_queries: bool = False
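For reference, a minimal sketch of how the new dataclass-based config might be constructed. The import path and every field value below (`name`, `prompt_function`, `hf_repo`, `hf_subset`, `metric`) are illustrative assumptions, not values taken from this diff:

```python
from lighteval.tasks.lighteval_task import LightevalTaskConfig  # assumed import path

# Hypothetical config: all values are placeholders for illustration.
cfg = LightevalTaskConfig(
    name="mytask",
    prompt_function="mytask_prompt",  # must resolve in tasks_prompt_formatting or the custom tasks module
    hf_repo="org/my-dataset",
    hf_subset="default",
    metric=["exact_match"],
    evaluation_splits=["validation"],
)

# __post_init__ fills in defaults and normalizes list-valued fields to tuples:
assert cfg.suite == ("custom",)
assert cfg.stop_sequence == ("\n",)
assert isinstance(cfg.metric, tuple)
print(cfg.as_dict()["metric"])  # ('exact_match',)
```

Because `as_dict()` casts each metric to `str`, the config can be serialized for logging even when `metric` holds `Metrics` enum members rather than plain strings.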