1- import os
2- from pathlib import Path
1+ from importlib import import_module
32from string import ascii_uppercase
43
54import yaml
65from huggingface_hub import hf_hub_download
76from inspect_ai import Epochs , Task , task
87from inspect_ai .dataset import FieldSpec , Sample , hf_dataset
9- from inspect_ai .scorer import choice , exact , match , model_graded_fact
10- from inspect_ai .solver import (
11- chain_of_thought ,
12- generate ,
13- multiple_choice ,
14- prompt_template ,
15- system_message ,
16- )
17-
18-
19- def load_config (yaml_path : str = None ) -> dict :
20- """Load and parse the YAML configuration file."""
21- if yaml_path is None :
22- yaml_path = os .getenv ("EVAL_YAML" , "eval.yaml" )
23-
24- yaml_path = Path (yaml_path )
25- if not yaml_path .is_absolute ():
26- yaml_path = Path (__file__ ).parent / yaml_path
27-
28- with open (yaml_path , "r" ) as f :
29- return yaml .safe_load (f )
308
319
3210def record_to_sample (record , field_spec : dict ):
@@ -51,9 +29,9 @@ def record_to_sample(record, field_spec: dict):
5129 return Sample (** sample_kwargs )
5230
5331
54- def load_dataset (repo_id : str , revision : str = "main" , task_config : dict = None , global_config : dict = None ):
32+ def load_dataset (repo_id : str , revision : str = "main" , task_config : dict = None ):
5533 """Load dataset based on task configuration."""
56- subset = task_config .get ("subset" )
34+ subset = task_config .get ("subset" , "default" )
5735 split = task_config .get ("splits" , "test" )
5836 field_spec = task_config ["field_spec" ]
5937
@@ -76,85 +54,115 @@ def load_dataset(repo_id: str, revision: str = "main", task_config: dict = None,
7654 sample_fields = FieldSpec (
7755 input = field_spec ["input" ],
7856 target = field_spec ["target" ],
79- ** ({ k : v for k , v in field_spec .items () if k not in [ "input " , "target" ]} ),
57+ metadata = field_spec .get ( "metadata " , [] ),
8058 ),
8159 )
8260
8361 return dataset
8462
8563
def build_solvers(task_config: dict):
    """
    Build a list of solvers from the task configuration.

    Each entry under ``solvers`` names a callable exposed by the
    ``inspect_ai.solver`` module. The entry's optional ``args`` mapping is
    passed to that callable as keyword arguments. A bare string entry is
    accepted as shorthand for ``{"name": <string>}`` (no arguments).

    task_config example:

    ```yaml
    solvers:
      - name: prompt_template
        args:
          template: >
            You are a helpful assistant.
            {prompt}
      - name: generate
        args:
          cache: true
    ```

    Args:
        task_config: Parsed task configuration mapping.

    Returns:
        A list with one constructed solver per configured entry, in the
        order given. Empty when ``solvers`` is missing, null or empty.

    Raises:
        KeyError: If a mapping entry has no ``name`` key.
        ValueError: If ``name`` does not refer to a callable attribute of
            ``inspect_ai.solver``.
    """
    # `solvers:` with no value parses to None in YAML; treat it like "absent".
    solver_configs = task_config.get("solvers") or []
    solver_module = import_module("inspect_ai.solver")

    solvers = []
    for solver_config in solver_configs:
        # Allow the short form `- generate` in addition to `- name: generate`.
        if isinstance(solver_config, str):
            solver_config = {"name": solver_config}

        solver_name = solver_config["name"]
        solver_fn = getattr(solver_module, solver_name, None)

        # Reject missing names and non-callable module attributes alike, so
        # a typo yields a clear error instead of "'X' object is not callable".
        if not callable(solver_fn):
            raise ValueError(f"Unknown solver: {solver_name}")

        # `args:` with no value parses to None; `**None` would raise TypeError.
        solver_args = solver_config.get("args") or {}
        solvers.append(solver_fn(**solver_args))

    return solvers
12598
12699
127- def create_task_from_config (
128- repo_id : str , revision : str = "main" , task_config : dict = None , global_config : dict = None
129- ):
def build_scorer(task_config: dict):
    """
    Build a list of scorers from the task configuration.

    Each entry under ``scorers`` names a callable exposed by the
    ``inspect_ai.scorer`` module. The entry's optional ``args`` mapping is
    passed to that callable as keyword arguments. A bare string entry is
    accepted as shorthand for ``{"name": <string>}`` (no arguments).

    task_config example:

    ```yaml
    scorers:
      - name: model_graded_fact
        args:
          template: |
            grade this,

            question:
            {question}
            criterion:
            {criterion}
            answer:
            {answer}
    ```

    Args:
        task_config: Parsed task configuration mapping.

    Returns:
        A list with one constructed scorer per configured entry, in the
        order given. Empty when ``scorers`` is missing, null or empty.

    Raises:
        KeyError: If a mapping entry has no ``name`` key.
        ValueError: If ``name`` does not refer to a callable attribute of
            ``inspect_ai.scorer``.
    """
    # `scorers:` with no value parses to None in YAML; treat it like "absent".
    scorer_configs = task_config.get("scorers") or []
    scorer_module = import_module("inspect_ai.scorer")

    scorers = []
    for scorer_config in scorer_configs:
        # Allow the short form `- match` in addition to `- name: match`.
        if isinstance(scorer_config, str):
            scorer_config = {"name": scorer_config}

        scorer_name = scorer_config["name"]
        scorer_fn = getattr(scorer_module, scorer_name, None)

        # Reject missing names and non-callable module attributes alike, so
        # a typo yields a clear error instead of "'X' object is not callable".
        if not callable(scorer_fn):
            raise ValueError(f"Unknown scorer: {scorer_name}")

        # `args:` with no value parses to None; `**None` would raise TypeError.
        scorer_args = scorer_config.get("args") or {}
        scorers.append(scorer_fn(**scorer_args))

    return scorers
134+
135+
@task
def create_task_from_config(repo_id: str, revision: str = "main", task_config: dict = None):
    """Assemble an inspect.ai Task from one task configuration entry.

    The dataset, solvers and scorers are each built from ``task_config`` by
    the corresponding helper. ``epochs`` falls back to 1 and
    ``epochs_reducer`` to "mean" when the configuration does not set them.
    The task name is read from the required ``name`` key.
    """
    # Dict literals evaluate top to bottom: dataset, then solvers, then
    # scorers, then the name lookup, then the Epochs construction.
    task_kwargs = {
        "dataset": load_dataset(repo_id, revision, task_config),
        "solver": build_solvers(task_config),
        "scorer": build_scorer(task_config),
        "name": task_config["name"],
        "epochs": Epochs(
            task_config.get("epochs", 1),
            task_config.get("epochs_reducer", "mean"),
        ),
    }
    return Task(**task_kwargs)
144152
145153
146- def create_task_function (repo_id : str , revision : str = "main" ):
def create_task_function(repo_id: str, revision: str = "main") -> list:
    """Build one task per entry in a dataset repository's ``eval.yaml``.

    Downloads ``eval.yaml`` from the dataset repository ``repo_id`` at
    ``revision``, parses it, and passes each entry of its top-level
    ``tasks`` list to ``create_task_from_config``.

    Args:
        repo_id: Identifier of the dataset repository holding ``eval.yaml``.
        revision: Repository revision to download from.

    Returns:
        A list with the result of ``create_task_from_config`` for every
        configured task, in file order.

    Raises:
        KeyError: If the parsed YAML has no ``tasks`` key.
    """
    # Fetch the YAML from the hub; hf_hub_download returns a local file path.
    yaml_path = hf_hub_download(repo_id=repo_id, filename="eval.yaml", repo_type="dataset", revision=revision)

    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default text encoding.
    with open(yaml_path, "r", encoding="utf-8") as f:
        global_config = yaml.safe_load(f)

    return [
        create_task_from_config(repo_id, revision, task_config)
        for task_config in global_config["tasks"]
    ]
0 commit comments