Skip to content

Commit 583898c

Browse files
authored
Merge pull request #11 from VectorInstitute/develop
Implement task generation pipeline
2 parents e13002a + 2b0f884 commit 583898c

File tree

13 files changed

+660
-125
lines changed

13 files changed

+660
-125
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,10 @@ ignore = [
125125
[tool.ruff.lint.per-file-ignores]
126126
"__init__.py" = ["E402", "F401", "F403", "F811"]
127127
"tests/src/seed_capabilities/math/math_competition_algebra/capability.py" = ["D100", "D101", "D102"]
128+
"tests/src/capabilities_t2/math/math_mathematics_modeling_real_world/capability.py" = ["D100", "D101", "D102"]
128129
"src/run.py" = ["ERA001"]
129130
"src/lbo.py" = ["ERA001"]
131+
"src/utils/capability_utils.py" = ["ERA001"]
130132

131133
[tool.ruff.lint.pep8-naming]
132134
ignore-names = ["X*", "setUp"]

src/capability.py

Lines changed: 217 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
import importlib # noqa: D100
22
import json
33
import os
4+
import re
45
import sys
56
from collections import defaultdict
6-
from typing import Any, Dict, List
7+
from typing import Any, Dict, List, Tuple
78

89
from src.model import Model
910
from src.utils.capability_utils import parse_python_class_str, read_score_inspect_json
1011
from src.utils.constants import (
12+
NO_ANSWER_STR,
1113
NON_SEED_CAPABILITIES_SCORE_DIR,
1214
SEED_CAPABILITIES_SCORE_DIR,
15+
TAB_W_SPACES,
1316
)
1417
from src.utils.data_utils import load_data
18+
from src.utils.prompts import TASK_SOLVER_SYSTEM_PROMPT
1519

1620

1721
class CapabilitySeedDataset:
@@ -137,7 +141,10 @@ def from_dict(cls, capability_dict: Dict[str, Any], base_dir: str) -> "Capabilit
137141
f"capability_{c_dict['name']}", os.path.join(c_dir, "capability.py")
138142
)
139143
c_obj = c_module.Capability()
140-
initial_tasks = list(c_obj.repr_tasks().values())
144+
initial_tasks = [
145+
{"id": k, "problem": v["problem"], "answer": v["answer"]}
146+
for k, v in c_obj.repr_tasks().items()
147+
]
141148
template_instructions = c_obj.get_instructions({"problem": '{t["problem"]}'})
142149
template_instructions = f'f"""{template_instructions}"""'
143150

@@ -163,7 +170,7 @@ def _load_capability_json(self) -> None:
163170
self.domain = _cfg["capability_domain"]
164171
self.instructions = _cfg["capability_instructions"]
165172
# TODO: Store data is stored in json or elsewhere?
166-
self._data = _cfg["capability_data"]
173+
self._data: List[Dict[str, Any]] = _cfg["capability_data"]
167174
# Check if the capability is a seed capability, use source_dataset as indicator
168175
self.is_seed = "source_dataset" in _cfg
169176

@@ -209,6 +216,114 @@ def load_scores(self, scores_dir: str | None = None) -> Dict[str, float]:
209216
scores_dict[model] = read_score_inspect_json(scores_file)
210217
return scores_dict
211218

219+
def get_repr_tasks(self) -> List[Dict[str, Any]]:
    """
    Get the representative tasks for the capability.

    Returns
    -------
    List[Dict[str, Any]]: A list of dictionaries containing the
        representative tasks. Each task dict consists of id, problem,
        and answer keys (any other keys in the stored task data are
        dropped).
    """
    # Flatten {task_id: {...}} into a list of {"id", "problem", "answer"}
    # dicts, preserving the iteration order of repr_tasks().
    return [
        {"id": task_id, "problem": data["problem"], "answer": data["answer"]}
        for task_id, data in self.capability_repr_class.repr_tasks().items()
    ]
238+
239+
def add_and_update_tasks(self, tasks: List[Dict[str, Any]]) -> None:
    """
    Add and/or update tasks for the capability.

    Incoming tasks replace existing tasks that share the same id. If any
    incoming task is a representative task, the capability class python
    file is rewritten with the updated repr_tasks dictionary; the
    capability json file is always rewritten with the merged task list,
    and the in-memory state is reloaded afterwards.

    Args
    ----
        tasks (List[Dict[str, Any]]): A list of dictionaries containing the tasks
            to be added. Each task dict consists of id, problem, and answer keys.

    Raises
    ------
        ValueError: If any task is missing the 'id', 'problem', or 'answer' key.
    """
    if not all(
        "id" in task and "problem" in task and "answer" in task for task in tasks
    ):
        raise ValueError(
            "Each task must contain 'id', 'problem', and 'answer' keys."
        )

    existing_tasks = self.get_tasks()
    # Ids present in both the existing data and the incoming batch;
    # computed once instead of rebuilding the intersection per task.
    overlapping_ids = {task["id"] for task in existing_tasks} & {
        task["id"] for task in tasks
    }
    # Keep new task for overlapping tasks
    # TODO: Add `overwrite` flag to update existing tasks
    tasks_to_keep = [
        task for task in existing_tasks if task["id"] not in overlapping_ids
    ] + tasks
    # Sort by task id
    tasks_to_keep.sort(key=lambda x: x["id"])

    # Check if the new task list consists of representative tasks
    # If yes, update the capability class python file
    repr_tasks = [
        task
        for task in tasks
        if task["id"] in self.capability_repr_class.repr_tasks()
    ]
    if repr_tasks:
        # Merge untouched representative tasks back in so the rewritten
        # repr_tasks dictionary stays complete.
        partial_repr_task_ids = [task["id"] for task in repr_tasks]
        missing_repr_tasks = {
            k: v
            for k, v in self.capability_repr_class.repr_tasks().items()
            if k not in partial_repr_task_ids
        }
        for task_id, task_data in missing_repr_tasks.items():
            repr_tasks.append({"id": task_id, **task_data})
        repr_tasks.sort(key=lambda x: x["id"])
        # Update the capability class python file
        # Extract str which contains the repr_tasks dictionary
        # TODO: Since these are hardcoded, update when the format changes
        prefix_str = f"def repr_tasks() -> dict[str, dict]:\n{TAB_W_SPACES}{TAB_W_SPACES}return "
        suffix_str = f"\n\n{TAB_W_SPACES}@staticmethod\n{TAB_W_SPACES}def get_instructions(t: dict) -> str:"
        # NOTE(review): assumes the class source contains exactly one
        # occurrence of this prefix/suffix pair; an IndexError here means
        # the capability class template changed — TODO confirm upstream.
        prev_repr_tasks_str = self.capability_repr_class_str.split(prefix_str)[
            1
        ].split(suffix_str)[0]
        # Restructure to match the original format
        repr_tasks_dict = {}
        for elm in repr_tasks:
            repr_tasks_dict[elm["id"]] = {k: v for k, v in elm.items() if k != "id"}
        # Replace the repr_tasks dictionary in the capability class string
        # with the updated one
        updated_repr_tasks_str = json.dumps(repr_tasks_dict, indent=4)
        # BUG FIX: str.lstrip/rstrip treat their argument as a character
        # *set* and would strip any run of those characters (e.g. leading
        # code characters), not just the markdown fence. removeprefix /
        # removesuffix strip the exact fence strings only.
        capability_repr_class_str = self.capability_repr_class_str.removeprefix(
            "```python\n"
        ).removesuffix("\n```")
        capability_repr_class_str = capability_repr_class_str.replace(
            prev_repr_tasks_str,
            updated_repr_tasks_str,
        )
        with open(os.path.join(self.source_dir, "capability.py"), "w") as f:
            f.write(capability_repr_class_str)

    # Update the capability data in the capability json file
    c_dict = {
        "capability_name": self.name,
        "capability_description": self.description,
        "capability_domain": self.domain,
        "capability_instructions": self.instructions,
        "capability_data": tasks_to_keep,
    }
    with open(os.path.join(self.source_dir, "capability.json"), "w") as f:
        json.dump(c_dict, f, indent=4)

    # Reload the capability class to reflect these changes
    self._load_capability_json()
    self._load_capability_repr_class()
326+
212327
def _to_dict(self) -> Dict[str, Any]:
213328
return {
214329
"name": self.name,
@@ -252,6 +367,105 @@ def encode(self, encoder_model: Any) -> None:
252367
self.encoding = None
253368
raise NotImplementedError
254369

370+
def _solve_task(
    self, task: Dict[str, Any], llm: Model, gen_cfg: Dict[str, Any]
) -> Tuple[str, Dict[str, Any]]:
    """
    Solve the task using the given LLM.

    Args
    ----
        task (Dict[str, Any]): The task dictionary containing the ID
            and the problem to solve.
        llm (Model): The LLM to use for solving the task.
        gen_cfg (Dict[str, Any]): The generation configuration for the LLM.

    Returns
    -------
    Tuple[str, Dict[str, Any]]: The extracted answer string and a metadata
        dict with the raw response and the API metadata (e.g. token usage).
    """
    # Generate answer using the LLM
    # TODO:
    # 1. Enable tool use
    # 2. How to link this function with the Inspect Solver
    #    to be used in _evaluate_using_inspect()?
    print(f"Solving task {task['id']} ...")
    system_prompt = TASK_SOLVER_SYSTEM_PROMPT.format(
        capability_name=self.name, capability_domain=self.domain
    )
    instructions = self.capability_repr_class.get_instructions(task)
    raw_response, api_metadata = llm.generate(
        sys_prompt=system_prompt,
        user_prompt=instructions,
        generation_config=gen_cfg,
    )
    # Extract answer from response
    # Borrowed from:
    # https://github.com/UKGovernmentBEIS/inspect_ai/blob/main/src/inspect_ai/_util/pattern.py#L3
    # TODO:
    # 1. Dynamically set pattern based on capability instructions
    # 2. For some capabilities the reasoning is the answer and the actual answer
    #    is only a final statement, how to handle this?
    # 3. How to gracefully handle cases where tokens are insufficient
    #    and the answer is incomplete?
    match = re.search(r"(?i)ANSWER\s*:\s*([^\n]+)", raw_response)
    if match:
        answer = match.group(1)
    else:
        answer = NO_ANSWER_STR
    return (
        answer,
        {"raw_response": raw_response, "api_metadata": api_metadata},
    )
421+
422+
def solve_tasks(
    self, tasks: List[Dict[str, Any]], llm: Model, gen_cfg: Dict[str, Any]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """
    Solve the tasks using the given LLM.

    Args
    ----
        tasks (List[Dict[str, Any]]): The list of tasks to solve.
        llm (Model): The LLM to use for solving the tasks.
        gen_cfg (Dict[str, Any]): The generation configuration for the LLM.

    Returns
    -------
    Tuple[List[Dict[str, Any]], Dict[str, Any]]: The solved tasks (each
        with id, problem, answer, and the raw response as reasoning), and
        a dict mapping task ids to their per-call API metadata.
    """
    results: List[Dict[str, Any]] = []
    per_task_metadata: Dict[str, Any] = {}
    for task in tasks:
        # Delegate each task to _solve_task and collect answer + metadata.
        answer, task_meta = self._solve_task(task=task, llm=llm, gen_cfg=gen_cfg)
        results.append(
            {
                "id": task["id"],
                "problem": task["problem"],
                "answer": answer,
                "reasoning": task_meta["raw_response"],
            }
        )
        per_task_metadata[task["id"]] = task_meta["api_metadata"]
    return (results, per_task_metadata)
458+
459+
def get_tasks(self) -> List[Dict[str, Any]]:
    """
    Get the existing tasks for the capability.

    Returns
    -------
    List[Dict[str, Any]]: The task dictionaries currently held by this
        capability (the underlying list, not a copy).
    """
    return self._data
468+
255469
def _create_inspect_file(self) -> None:
256470
"""
257471
Implement pipeline to evaluate the capability using the inspect framework.

src/cfg/run_cfg.yaml

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
scientist_llm:
22
name: gpt-4o-mini
3-
gen_cfg:
4-
temperature: 0.7
5-
max_tokens: 64
3+
generation_cfg:
4+
capability_generation:
5+
temperature: 0.7
6+
max_tokens: 64
7+
task_generation:
8+
temperature: 0.7
9+
max_tokens: 64
10+
task_solve:
11+
temperature: 0.7
12+
max_tokens: 64
613

714
subject_llm:
815
name: Meta-Llama-3.1-70B-Instruct
@@ -14,16 +21,24 @@ capabilities_cfg:
1421
capabilities_dir: /fs01/projects/aieng/public/ace/artifacts
1522
results_dir: /fs01/projects/aieng/public/ace/artifacts
1623
domain: math
24+
# Number of seed capabilities to use for initial capability generation
25+
# Set to -1 to use all seed capabilities
1726
num_seed_capabilities: -1
27+
# Number of initial capabilities to generate using the scientist LLM
1828
num_gen_capabilities: 4
29+
# Number of initial capabilities to generate per run
1930
num_gen_capabilities_per_run: 2
20-
num_gen_tasks_per_capability: 2
31+
# Number of tasks to generate for each capability
32+
num_gen_tasks_per_capability: 1
33+
# Set this flag to true to use representative tasks
34+
# as few shot examples for task generation
35+
task_gen_few_shot: true
2136

2237
lbo_cfg:
2338
# Number of capabilities to generate using LBO
2439
num_lbo_runs: 1
2540
# Type of LBO pipeline to use
26-
pipeline_id: "nearest_neighbor" # "nearest_neighbor" or "discover_new"
41+
pipeline_id: "discover_new" # "nearest_neighbor" or "discover_new"
2742
# Train args for 'nearest_neighbor' pipeline
2843
train_frac: 0.5
2944
min_train_size: 10

src/generate_capabilities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def generate_capabilities_using_llm(
191191
print(f"Metadata: {metadata}")
192192

193193
parsed_response = extract_and_parse_response(response)
194-
gen_capabilities = parsed_response["capabilities"]
194+
gen_capabilities = parsed_response["parsed_response"]
195195
gen_capabilities = [
196196
Capability.from_dict(capability_dict=capability, base_dir=base_capability_dir)
197197
for capability in gen_capabilities

0 commit comments

Comments
 (0)