Skip to content

Commit e9625a1

Browse files
committed
Implement logic to add and update tasks
1 parent a8e6c5c commit e9625a1

File tree

3 files changed

+100
-3
lines changed

3 files changed

+100
-3
lines changed

src/capability.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
NO_ANSWER_STR,
1313
NON_SEED_CAPABILITIES_SCORE_DIR,
1414
SEED_CAPABILITIES_SCORE_DIR,
15+
TAB_W_SPACES,
1516
)
1617
from src.utils.data_utils import load_data
1718
from src.utils.prompts import TASK_SOLVER_SYSTEM_PROMPT
@@ -140,7 +141,10 @@ def from_dict(cls, capability_dict: Dict[str, Any], base_dir: str) -> "Capabilit
140141
f"capability_{c_dict['name']}", os.path.join(c_dir, "capability.py")
141142
)
142143
c_obj = c_module.Capability()
143-
initial_tasks = list(c_obj.repr_tasks().values())
144+
initial_tasks = [
145+
{"id": k, "problem": v["problem"], "answer": v["answer"]}
146+
for k, v in c_obj.repr_tasks().items()
147+
]
144148
template_instructions = c_obj.get_instructions({"problem": '{t["problem"]}'})
145149
template_instructions = f'f"""{template_instructions}"""'
146150

@@ -232,6 +236,95 @@ def get_repr_tasks(self) -> List[Dict[str, Any]]:
232236
)
233237
return repr_tasks
234238

239+
def add_and_update_tasks(self, tasks: List[Dict[str, Any]]) -> None:
    """
    Add and/or update tasks for the capability.

    New tasks are merged into the existing task list by ``id``; when an id
    already exists (or is repeated inside ``tasks``), the most recently
    supplied task wins. If any supplied task id is one of the capability's
    representative tasks, the ``repr_tasks`` dictionary inside
    ``capability.py`` is rewritten as well. Finally ``capability.json`` is
    rewritten and the capability is reloaded from disk.

    Args
    ----
        tasks (List[Dict[str, Any]]): A list of dictionaries containing the tasks
            to be added. Each task dict consists of id, problem, and answer keys.

    Raises
    ------
        ValueError: If a task is missing one of the required keys, or if the
            ``repr_tasks`` definition cannot be located in the capability
            class source when representative tasks need to be updated.
    """
    if not all(
        "id" in task and "problem" in task and "answer" in task for task in tasks
    ):
        raise ValueError(
            "Each task must contain 'id', 'problem', and 'answer' keys."
        )

    # Merge by id so that each id appears exactly once; later entries
    # (i.e. the new tasks) override earlier ones.
    # TODO: Add `overwrite` flag to update existing tasks
    merged_tasks = {task["id"]: task for task in self._data}
    merged_tasks.update({task["id"]: task for task in tasks})
    # Sort by task id (ids are compared as given, so string ids sort
    # lexicographically).
    tasks_to_keep = sorted(merged_tasks.values(), key=lambda x: x["id"])

    # Check if the new task list contains representative tasks.
    # If yes, update the capability class python file.
    current_repr_tasks = self.capability_repr_class.repr_tasks()
    updated_repr_tasks = {
        task["id"]: task for task in tasks if task["id"] in current_repr_tasks
    }
    if updated_repr_tasks:
        # Restructure to match the original format: {id: {<fields except id>}}.
        # Representative tasks that were not supplied keep their current value.
        repr_tasks_dict = {}
        for task_id in sorted(current_repr_tasks):
            fields = updated_repr_tasks.get(task_id, current_repr_tasks[task_id])
            repr_tasks_dict[task_id] = {
                k: v for k, v in fields.items() if k != "id"
            }

        # Remove an optional markdown code fence. This must be an exact
        # prefix/suffix removal: str.lstrip/rstrip treat their argument as a
        # *set of characters* and could strip real source characters.
        fence_open = "```python\n"
        fence_close = "\n```"
        class_str = self.capability_repr_class_str
        if class_str.startswith(fence_open):
            class_str = class_str[len(fence_open):]
        if class_str.endswith(fence_close):
            class_str = class_str[: -len(fence_close)]

        # Locate the text which contains the repr_tasks dictionary.
        # TODO: Since these are hardcoded, update when the format changes
        prefix_str = f"def repr_tasks() -> dict[str, dict]:\n{TAB_W_SPACES}{TAB_W_SPACES}return "
        suffix_str = f"\n\n{TAB_W_SPACES}@staticmethod\n{TAB_W_SPACES}def get_instructions(t: dict) -> str:"
        head, found_prefix, rest = class_str.partition(prefix_str)
        _old_repr_tasks_str, found_suffix, tail = rest.partition(suffix_str)
        if not found_prefix or not found_suffix:
            raise ValueError(
                "Could not locate the repr_tasks definition in the capability "
                "class source; the expected format may have changed."
            )

        # Splice the updated dictionary into the exact location found above
        # rather than doing a global text replace.
        # NOTE(review): json.dumps emits JSON literals (true/false/null), which
        # are not valid Python; this assumes task fields are plain strings or
        # numbers -- confirm.
        updated_repr_tasks_str = json.dumps(repr_tasks_dict, indent=4)
        capability_repr_class_str = "".join(
            (head, prefix_str, updated_repr_tasks_str, suffix_str, tail)
        )
        with open(os.path.join(self.source_dir, "capability.py"), "w") as f:
            f.write(capability_repr_class_str)

    # Update the capability data in the capability json file
    c_dict = {
        "capability_name": self.name,
        "capability_description": self.description,
        "capability_domain": self.domain,
        "capability_instructions": self.instructions,
        "capability_data": tasks_to_keep,
    }
    with open(os.path.join(self.source_dir, "capability.json"), "w") as f:
        json.dump(c_dict, f, indent=4)

    # Reload the capability class to reflect these changes
    self._load_capability_json()
    self._load_capability_repr_class()
327+
235328
def _to_dict(self) -> Dict[str, Any]:
236329
return {
237330
"name": self.name,

src/cfg/run_cfg.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ capabilities_cfg:
2424
num_seed_capabilities: -1
2525
num_gen_capabilities: 4
2626
num_gen_capabilities_per_run: 2
27-
num_gen_tasks_per_capability: 2
27+
num_gen_tasks_per_capability: 1
2828
task_gen_few_shot: true
2929

3030
lbo_cfg:

src/generate_tasks.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def generate_tasks_using_llm(
147147
# Combine with sample tasks to get the full set of tasks
148148
start_id = len(sample_tasks) + 1
149149
all_tasks = sample_tasks + [
150-
{"id": (start_id + idx), "problem": new_tasks[idx]}
150+
{"id": str(start_id + idx), "problem": new_tasks[idx]}
151151
for idx in range(len(new_tasks))
152152
]
153153

@@ -159,3 +159,7 @@ def generate_tasks_using_llm(
159159
)
160160
print(json.dumps(solved_tasks, indent=4))
161161
print(task_solver_metadata)
162+
163+
capability.add_and_update_tasks(
164+
tasks=solved_tasks,
165+
)

0 commit comments

Comments
 (0)