11import importlib # noqa: D100
22import json
33import os
4+ import re
45import sys
56from collections import defaultdict
6- from typing import Any , Dict , List
7+ from typing import Any , Dict , List , Tuple
78
89from src .model import Model
910from src .utils .capability_utils import parse_python_class_str , read_score_inspect_json
1011from src .utils .constants import (
12+ NO_ANSWER_STR ,
1113 NON_SEED_CAPABILITIES_SCORE_DIR ,
1214 SEED_CAPABILITIES_SCORE_DIR ,
15+ TAB_W_SPACES ,
1316)
1417from src .utils .data_utils import load_data
18+ from src .utils .prompts import TASK_SOLVER_SYSTEM_PROMPT
1519
1620
1721class CapabilitySeedDataset :
@@ -137,7 +141,10 @@ def from_dict(cls, capability_dict: Dict[str, Any], base_dir: str) -> "Capabilit
137141 f"capability_{ c_dict ['name' ]} " , os .path .join (c_dir , "capability.py" )
138142 )
139143 c_obj = c_module .Capability ()
140- initial_tasks = list (c_obj .repr_tasks ().values ())
144+ initial_tasks = [
145+ {"id" : k , "problem" : v ["problem" ], "answer" : v ["answer" ]}
146+ for k , v in c_obj .repr_tasks ().items ()
147+ ]
141148 template_instructions = c_obj .get_instructions ({"problem" : '{t["problem"]}' })
142149 template_instructions = f'f"""{ template_instructions } """'
143150
@@ -163,7 +170,7 @@ def _load_capability_json(self) -> None:
163170 self .domain = _cfg ["capability_domain" ]
164171 self .instructions = _cfg ["capability_instructions" ]
165172 # TODO: Store data is stored in json or elsewhere?
166- self ._data = _cfg ["capability_data" ]
173+ self ._data : List [ Dict [ str , Any ]] = _cfg ["capability_data" ]
167174 # Check if the capability is a seed capability, use source_dataset as indicator
168175 self .is_seed = "source_dataset" in _cfg
169176
@@ -209,6 +216,114 @@ def load_scores(self, scores_dir: str | None = None) -> Dict[str, float]:
209216 scores_dict [model ] = read_score_inspect_json (scores_file )
210217 return scores_dict
211218
219+ def get_repr_tasks (self ) -> List [Dict [str , Any ]]:
220+ """
221+ Get the representative tasks for the capability.
222+
223+ Returns
224+ -------
225+ List[Dict[Any]]: A list of dictionaries containing the representative tasks.
226+ Each task dict consists of id, problem, and answer keys.
227+ """
228+ repr_tasks = []
229+ for task_id , task_data in self .capability_repr_class .repr_tasks ().items ():
230+ repr_tasks .append (
231+ {
232+ "id" : task_id ,
233+ "problem" : task_data ["problem" ],
234+ "answer" : task_data ["answer" ],
235+ }
236+ )
237+ return repr_tasks
238+
239+ def add_and_update_tasks (self , tasks : List [Dict [str , Any ]]) -> None :
240+ """
241+ Add and/or update tasks for the capability.
242+
243+ Args
244+ ----
245+ tasks (List[Dict[str, Any]]): A list of dictionaries containing the tasks
246+ to be added. Each task dict consists of id, problem, and answer keys.
247+ """
248+ if not all (
249+ "id" in task and "problem" in task and "answer" in task for task in tasks
250+ ):
251+ raise ValueError (
252+ "Each task must contain 'id', 'problem', and 'answer' keys."
253+ )
254+
255+ existing_tasks = self .get_tasks ()
256+ existing_task_ids = [task ["id" ] for task in existing_tasks ]
257+ new_task_ids = [task ["id" ] for task in tasks ]
258+ # Keep new task for overlapping tasks
259+ # TODO: Add `overwrite` flag to update existing tasks
260+ tasks_to_keep = [
261+ task
262+ for task in existing_tasks
263+ if task ["id" ]
264+ not in list (set .intersection (set (existing_task_ids ), set (new_task_ids )))
265+ ] + tasks
266+ # Sort by task id
267+ tasks_to_keep .sort (key = lambda x : x ["id" ])
268+
269+ # Check if the new task list consists of representative tasks
270+ # If yes, update the capability class python file
271+ repr_tasks = [
272+ task
273+ for task in tasks
274+ if task ["id" ] in self .capability_repr_class .repr_tasks ()
275+ ]
276+ if repr_tasks :
277+ partial_repr_task_ids = [task ["id" ] for task in repr_tasks ]
278+ missing_repr_tasks = {
279+ k : v
280+ for k , v in self .capability_repr_class .repr_tasks ().items ()
281+ if k not in partial_repr_task_ids
282+ }
283+ for task_id , task_data in missing_repr_tasks .items ():
284+ repr_tasks .append ({"id" : task_id , ** task_data })
285+ repr_tasks .sort (key = lambda x : x ["id" ])
286+ # Update the capability class python file
287+ # Extract str which contains the repr_tasks dictionary
288+ # TODO: Since these are hardcoded, update when the format changes
289+ prefix_str = f"def repr_tasks() -> dict[str, dict]:\n { TAB_W_SPACES } { TAB_W_SPACES } return "
290+ suffix_str = f"\n \n { TAB_W_SPACES } @staticmethod\n { TAB_W_SPACES } def get_instructions(t: dict) -> str:"
291+ prev_repr_tasks_str = self .capability_repr_class_str .split (prefix_str )[
292+ 1
293+ ].split (suffix_str )[0 ]
294+ # Restructure to match the original format
295+ repr_tasks_dict = {}
296+ for elm in repr_tasks :
297+ repr_tasks_dict [elm ["id" ]] = {k : v for k , v in elm .items () if k != "id" }
298+ # Replace the repr_tasks dictionary in the capability class string
299+ # with the updated one
300+ updated_repr_tasks_str = json .dumps (repr_tasks_dict , indent = 4 )
301+ newline = "\n "
302+ capability_repr_class_str = self .capability_repr_class_str .lstrip (
303+ f"```python{ newline } "
304+ ).rstrip (f"{ newline } ```" )
305+ capability_repr_class_str = capability_repr_class_str .replace (
306+ prev_repr_tasks_str ,
307+ updated_repr_tasks_str ,
308+ )
309+ with open (os .path .join (self .source_dir , "capability.py" ), "w" ) as f :
310+ f .write (capability_repr_class_str )
311+
312+ # Update the capability data in the capability json file
313+ c_dict = {
314+ "capability_name" : self .name ,
315+ "capability_description" : self .description ,
316+ "capability_domain" : self .domain ,
317+ "capability_instructions" : self .instructions ,
318+ "capability_data" : tasks_to_keep ,
319+ }
320+ with open (os .path .join (self .source_dir , "capability.json" ), "w" ) as f :
321+ json .dump (c_dict , f , indent = 4 )
322+
323+ # Reload the capability class to reflect these changes
324+ self ._load_capability_json ()
325+ self ._load_capability_repr_class ()
326+
212327 def _to_dict (self ) -> Dict [str , Any ]:
213328 return {
214329 "name" : self .name ,
@@ -252,6 +367,105 @@ def encode(self, encoder_model: Any) -> None:
252367 self .encoding = None
253368 raise NotImplementedError
254369
370+ def _solve_task (
371+ self , task : Dict [str , Any ], llm : Model , gen_cfg : Dict [str , Any ]
372+ ) -> Tuple [str , Dict [str , Any ]]:
373+ """
374+ Solve the task using the given LLM.
375+
376+ Args
377+ ----
378+ task (Dict[str, Any]): The task dictionary containing the ID
379+ and the problem to solve.
380+ llm (Model): The LLM to use for solving the task.
381+ gen_cfg (Dict[str, Any]): The generation configuration for the LLM.
382+
383+ Returns
384+ -------
385+ Tuple[str, Dict[str, Any]]: A tuple containing the answer as a string
386+ and metadata as a dictionary, which includes raw response and
387+ input/output tokens.
388+ """
389+ # Generate answer using the LLM
390+ # TODO:
391+ # 1. Enable tool use
392+ # 2. How to link this function with the Inspect Solver
393+ # to be used in _evaluate_using_inspect()?
394+ print (f"Solving task { task ['id' ]} ..." )
395+ sys_prompt = TASK_SOLVER_SYSTEM_PROMPT .format (
396+ capability_name = self .name , capability_domain = self .domain
397+ )
398+ user_prompt = self .capability_repr_class .get_instructions (task )
399+ response , metadata = llm .generate (
400+ sys_prompt = sys_prompt ,
401+ user_prompt = user_prompt ,
402+ generation_config = gen_cfg ,
403+ )
404+ # Extract answer from response
405+ # Borrowed from:
406+ # https://github.com/UKGovernmentBEIS/inspect_ai/blob/main/src/inspect_ai/_util/pattern.py#L3
407+ # TODO:
408+ # 1. Dynamically set pattern based on capability instructions
409+ # 2. For some capabilities the reasoning is the answer and the actual answer
410+ # is only a final statement, how to handle this?
411+ # 3. How to gracefully handle cases where tokens are insufficient
412+ # and the answer is incomplete?
413+ answer_pattern = r"(?i)ANSWER\s*:\s*([^\n]+)"
414+ match = re .search (answer_pattern , response )
415+ answer = match .group (1 ) if match else NO_ANSWER_STR
416+ metadata = {
417+ "raw_response" : response ,
418+ "api_metadata" : metadata ,
419+ }
420+ return (answer , metadata )
421+
422+ def solve_tasks (
423+ self , tasks : List [Dict [str , Any ]], llm : Model , gen_cfg : Dict [str , Any ]
424+ ) -> Tuple [List [Dict [str , Any ]], Dict [str , Any ]]:
425+ """
426+ Solve the tasks using the given LLM.
427+
428+ Args
429+ ----
430+ tasks (List[Dict[str, Any]]): The list of tasks to solve.
431+ llm (Model): The LLM to use for solving the tasks.
432+ gen_cfg (Dict[str, Any]): The generation configuration for the LLM.
433+
434+ Returns
435+ -------
436+ Tuple[List[Dict[str, Any]], Dict[str, Any]]: A tuple containing a list of
437+ dictionaries with the solved tasks and a dictionary with metadata
438+ for each task.
439+ """
440+ solved_tasks = []
441+ metadata = {}
442+ for task in tasks :
443+ answer , _metadata = self ._solve_task (
444+ task = task ,
445+ llm = llm ,
446+ gen_cfg = gen_cfg ,
447+ )
448+ solved_tasks .append (
449+ {
450+ "id" : task ["id" ],
451+ "problem" : task ["problem" ],
452+ "answer" : answer ,
453+ "reasoning" : _metadata ["raw_response" ],
454+ }
455+ )
456+ metadata [task ["id" ]] = _metadata ["api_metadata" ]
457+ return (solved_tasks , metadata )
458+
459+ def get_tasks (self ) -> List [Dict [str , Any ]]:
460+ """
461+ Get the existing tasks for the capability.
462+
463+ Returns
464+ -------
465+ List[Dict[str, Any]]: A list of dictionaries containing the tasks.
466+ """
467+ return self ._data
468+
255469 def _create_inspect_file (self ) -> None :
256470 """
257471 Implement pipeline to evaluate the capability using the inspect framework.
0 commit comments