Skip to content

Commit 31978fb

Browse files
committed
Add flag for solving sample tasks and fix start id
1 parent 6df1f35 commit 31978fb

File tree

3 files changed

+30
-11
lines changed

3 files changed

+30
-11
lines changed

src/capability.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def _load_capability_json(self) -> None:
170170
self.domain = _cfg["capability_domain"]
171171
self.instructions = _cfg["capability_instructions"]
172172
# TODO: Should this data be stored in the JSON or elsewhere?
173-
self._data = _cfg["capability_data"]
173+
self._data: List[Dict[str, Any]] = _cfg["capability_data"]
174174
# Check if the capability is a seed capability, use source_dataset as indicator
175175
self.is_seed = "source_dataset" in _cfg
176176

@@ -252,13 +252,14 @@ def add_and_update_tasks(self, tasks: List[Dict[str, Any]]) -> None:
252252
"Each task must contain 'id', 'problem', and 'answer' keys."
253253
)
254254

255-
existing_task_ids = [task["id"] for task in self._data]
255+
existing_tasks = self.get_tasks()
256+
existing_task_ids = [task["id"] for task in existing_tasks]
256257
new_task_ids = [task["id"] for task in tasks]
257258
# Keep the new task when its ID overlaps an existing one
258259
# TODO: Add `overwrite` flag to update existing tasks
259260
tasks_to_keep = [
260261
task
261-
for task in self._data
262+
for task in existing_tasks
262263
if task["id"]
263264
not in list(set.intersection(set(existing_task_ids), set(new_task_ids)))
264265
] + tasks
@@ -457,6 +458,16 @@ def solve_tasks(
457458
metadata[task["id"]] = _metadata["api_metadata"]
458459
return (solved_tasks, metadata)
459460

461+
def get_tasks(self) -> List[Dict[str, Any]]:
462+
"""
463+
Get the existing tasks for the capability.
464+
465+
Returns
466+
-------
467+
List[Dict[str, Any]]: A list of dictionaries containing the tasks.
468+
"""
469+
return self._data
470+
460471
def _create_inspect_file(self) -> None:
461472
"""
462473
Implement pipeline to evaluate the capability using the inspect framework.

src/generate_tasks.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def generate_tasks_using_llm(
7373
num_tasks: int,
7474
scientist_llm_gen_cfg_task_gen: Dict[str, Any],
7575
scientist_llm_gen_cfg_task_solve: Dict[str, Any],
76+
solve_sample_tasks: bool = False,
7677
**kwargs: Any,
7778
) -> None:
7879
"""
@@ -91,6 +92,7 @@ def generate_tasks_using_llm(
9192
for task generation using the scientist LLM.
9293
scientist_llm_gen_cfg_task_solve (Dict[str, Any]): The generation configuration
9394
for solving tasks using the scientist LLM.
95+
solve_sample_tasks (bool, optional): Whether to solve sample tasks.
9496
**kwargs (Any): Additional arguments for task generation.
9597
"""
9698
# TODO: Implement the function with the following components
@@ -122,9 +124,6 @@ def generate_tasks_using_llm(
122124
# Generate task problems
123125
# Extract sample tasks from representative tasks
124126
sample_tasks = capability.get_repr_tasks()
125-
for task in sample_tasks:
126-
# Remove the answer
127-
task.pop("answer", None)
128127

129128
# Generate new tasks using the scientist LLM
130129
sys_prompt, user_prompt = get_task_generation_prompt(
@@ -144,14 +143,20 @@ def generate_tasks_using_llm(
144143
print(f"Metadata: {task_gen_metadata}")
145144
parsed_response = extract_and_parse_response(response)
146145
new_tasks = parsed_response["parsed_response"]
147-
# Combine with sample tasks to get the full set of tasks
148-
start_id = len(sample_tasks) + 1
149-
all_tasks = sample_tasks + [
146+
147+
# Solve task and generate answers
148+
# Set starting ID for new tasks
149+
start_id = len(capability.get_tasks()) + 1
150+
all_tasks = [
150151
{"id": str(start_id + idx), "problem": new_tasks[idx]}
151152
for idx in range(len(new_tasks))
152153
]
153-
154-
# Solve task and generate answers
154+
# Add sample tasks if solving them
155+
if solve_sample_tasks:
156+
for task in sample_tasks:
157+
# Remove the answer
158+
task.pop("answer", None)
159+
all_tasks = sample_tasks + all_tasks
155160
solved_tasks, task_solver_metadata = capability.solve_tasks(
156161
tasks=all_tasks,
157162
llm=scientist_llm,

src/run.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def main(cfg: DictConfig) -> None:
115115
num_tasks=cfg.capabilities_cfg.num_gen_tasks_per_capability,
116116
scientist_llm_gen_cfg_task_gen=cfg.scientist_llm.gen_cfg.task_gen,
117117
scientist_llm_gen_cfg_task_solve=cfg.scientist_llm.gen_cfg.task_solve,
118+
solve_sample_tasks=True,
118119
few_shot=cfg.capabilities_cfg.task_gen_few_shot,
119120
)
120121
# # Evaluate subject LLM on each capability
@@ -143,6 +144,8 @@ def main(cfg: DictConfig) -> None:
143144
# num_tasks=cfg.capabilities_cfg.num_gen_tasks_per_capability,
144145
# scientist_llm_gen_cfg_task_gen=cfg.scientist_llm.gen_cfg.task_gen,
145146
# scientist_llm_gen_cfg_task_solve=cfg.scientist_llm.gen_cfg.task_solve,
147+
# solve_sample_tasks=True,
148+
# few_shot=cfg.capabilities_cfg.task_gen_few_shot,
146149
# )
147150
# # Evaluate subject LLM on new capability
148151
# new_capability.evaluate([subject_llm])

0 commit comments

Comments
 (0)