Skip to content

Commit 0a49191

Browse files
authored
Add hierarchical capability generation and task verification
2 parents 6529a5a + e9614cb commit 0a49191

File tree

8 files changed

+404
-75
lines changed

8 files changed

+404
-75
lines changed

src/capability.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def from_dict(cls, capability_dict: Dict[str, Any], base_dir: str) -> "Capabilit
183183
"capability_name": c_dict.pop("name"),
184184
"capability_description": c_dict.pop("description"),
185185
"capability_domain": c_dict.pop("domain"),
186+
"capability_area": c_dict.pop("area", None),
186187
"capability_instructions": template_instructions,
187188
"capability_data": initial_tasks,
188189
}
@@ -199,6 +200,7 @@ def _load_capability_json(self) -> None:
199200
self.description = _cfg["capability_description"]
200201
self.domain = _cfg["capability_domain"]
201202
self.instructions = _cfg["capability_instructions"]
203+
self.area = _cfg.get("capability_area", None)
202204
# TODO: Store data is stored in json or elsewhere?
203205
self._data: List[Dict[str, Any]] = _cfg["capability_data"]
204206
# Check if the capability is a seed capability, use source_dataset as indicator
@@ -266,14 +268,21 @@ def get_repr_tasks(self) -> List[Dict[str, Any]]:
266268
)
267269
return repr_tasks
268270

269-
def add_and_update_tasks(self, tasks: List[Dict[str, Any]]) -> None:
271+
def add_and_update_tasks(
272+
self,
273+
tasks: List[Dict[str, Any]],
274+
failed_tasks: List[Dict[str, Any]] | None = None,
275+
) -> None:
270276
"""
271277
Add and/or update tasks for the capability.
272278
273279
Args
274280
----
275281
tasks (List[Dict[str, Any]]): A list of dictionaries containing the tasks
276282
to be added. Each task dict consists of id, problem, and answer keys.
283+
failed_tasks (List[Dict[str, Any]]): A list of dictionaries
284+
containing the tasks that failed to be solved.
285+
Each task dict consists of id, problem, and answer keys.
277286
"""
278287
if not all(
279288
"id" in task and "problem" in task and "answer" in task for task in tasks
@@ -344,9 +353,17 @@ def add_and_update_tasks(self, tasks: List[Dict[str, Any]]) -> None:
344353
"capability_name": self.name,
345354
"capability_description": self.description,
346355
"capability_domain": self.domain,
356+
"capability_area": self.area,
347357
"capability_instructions": self.instructions,
348358
"capability_data": tasks_to_keep,
349359
}
360+
# TODO: Handle edge cases for failed tasks
361+
if failed_tasks:
362+
c_dict.update(
363+
{
364+
"capability_failed_data": failed_tasks,
365+
}
366+
)
350367
with open(os.path.join(self.source_dir, "capability.json"), "w") as f:
351368
json.dump(c_dict, f, indent=4)
352369

src/cfg/run_cfg.yaml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ scientist_llm:
1414
judge_llm:
1515
temperature: 1.0
1616
max_tokens: 64
17+
task_verify:
18+
temperature: 0.7
19+
max_tokens: 64
1720
local_launch_cfg:
1821
# Number of threads to use for local LLM
1922
max_num_seqs: 1
@@ -48,15 +51,21 @@ capabilities_cfg:
4851
results_dir: gs://ace-artifacts
4952
inspect_evals_dir: /fs01/projects/aieng/public/ace/inspect_evals/src/ace_evals
5053
domain: math
54+
# Method used to generate capabilities
55+
method: "hierarchical"
5156
# Number of seed capabilities to use for initial capability generation
5257
# Set to -1 to use all seed capabilities
5358
num_seed_capabilities: 1
5459
# Number of initial capabilities to generate using the scientist LLM
55-
num_gen_capabilities: 1
60+
num_gen_capabilities: 2
61+
# Number of capability areas to generate
62+
num_capability_areas: 2
5663
# Number of initial capabilities to generate per run
5764
num_gen_capabilities_per_run: 1
5865
# Number of tasks to generate for each capability
5966
num_gen_tasks_per_capability: 1
67+
# Buffer for task generation
68+
num_gen_tasks_buffer: 0.2
6069
# Set this flag to true to use representative tasks
6170
# as few shot examples for task generation
6271
task_gen_few_shot: true

src/generate_capabilities.py

Lines changed: 159 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@
1717
from src.utils import constants
1818
from src.utils.capability_utils import extract_and_parse_response
1919
from src.utils.prompts import (
20+
CAPABILITY_AREAS_GENERATION_RESPONSE_JSON_FORMAT,
2021
CAPABILITY_GENERATION_SYSTEM_PROMPT,
2122
CAPABILITY_GENERATION_USER_PROMPT,
23+
HIERARCHICAL_CAPABILITY_AREAS_GENERATION_USER_PROMPT,
24+
HIERARCHICAL_CAPABILITY_GENERATION_USER_PROMPT,
2225
)
2326

2427

@@ -104,6 +107,7 @@ def _sample_seed_capabilities(
104107

105108
def _get_previous_capabilities(
106109
capability_dir: str,
110+
capability_area: str | None = None,
107111
) -> List[Capability]:
108112
"""
109113
Get the previously generated capabilities for the specified domain.
@@ -121,6 +125,8 @@ def _get_previous_capabilities(
121125
prev_capabilities = []
122126
for capability_path in os.listdir(capability_dir):
123127
capability = Capability(os.path.join(capability_dir, capability_path))
128+
if capability_area is not None and capability.area != capability_area:
129+
continue
124130
prev_capabilities.append(capability)
125131
return prev_capabilities
126132

@@ -157,6 +163,7 @@ def generate_capabilities_using_llm(
157163
base_capability_dir: str,
158164
include_seed_capability_names: Optional[List[str]] = None,
159165
exclude_seed_capability_names: Optional[List[str]] = None,
166+
capability_area: str | None = None,
160167
**kwargs: Any,
161168
) -> Dict[str, Any]:
162169
"""
@@ -185,6 +192,7 @@ def generate_capabilities_using_llm(
185192
names to include in the generation process.
186193
exclude_seed_capability_names (List[str] | None): A list of seed capability
187194
names to exclude from the generation process.
195+
capability_area (str | None): The capability area for the generation
188196
**kwargs (Any): Additional keyword arguments.
189197
190198
Returns
@@ -226,6 +234,10 @@ def generate_capabilities_using_llm(
226234

227235
parsed_response = extract_and_parse_response(response)
228236
gen_capabilities = parsed_response["parsed_response"]
237+
if capability_area is not None:
238+
# Add the capability area to the generated capabilities
239+
for capability in gen_capabilities:
240+
capability["area"] = capability_area
229241
gen_capabilities = [
230242
Capability.from_dict(capability_dict=capability, base_dir=base_capability_dir)
231243
for capability in gen_capabilities
@@ -366,13 +378,72 @@ def filter_capabilities(
366378
return [capabilities[i] for i in remaining_indices]
367379

368380

381+
def generate_capability_areas(
382+
domain: str,
383+
num_areas: int,
384+
num_capabilities_per_area: int,
385+
scientist_llm: Model,
386+
user_prompt: str,
387+
scientist_llm_gen_cfg: Dict[str, Any],
388+
sys_prompt: str | None = None,
389+
) -> Dict[str, Any]:
390+
"""
391+
Generate capability areas for the specified domain.
392+
393+
Args
394+
----
395+
domain (str): The domain name.
396+
num_areas (int): The number of capability areas to generate.
397+
num_capabilities_per_area (int): The number of capabilities per area.
398+
scientist_llm (Model): The scientist LLM model.
399+
user_prompt (str): The user prompt for generating capability areas.
400+
scientist_llm_gen_cfg (Dict[str, Any]): The generation configuration
401+
for the scientist LLM.
402+
sys_prompt (str | None): The system prompt for the scientist LLM.
403+
404+
Returns
405+
-------
406+
Dict[str, Any]: A dictionary containing the generated capability areas
407+
and metadata about the generation process.
408+
"""
409+
# Generate output using the model with specified generation arguments
410+
user_prompt = user_prompt.format(
411+
num_areas=num_areas,
412+
num_capabilities_per_area=num_capabilities_per_area,
413+
domain=domain,
414+
response_json_format=CAPABILITY_AREAS_GENERATION_RESPONSE_JSON_FORMAT,
415+
)
416+
response, metadata = scientist_llm.generate(
417+
sys_prompt=sys_prompt if sys_prompt else "",
418+
user_prompt=user_prompt,
419+
generation_config=scientist_llm_gen_cfg,
420+
)
421+
422+
# Print the output
423+
print(f"Model: {scientist_llm.get_model_name()}")
424+
print(f"Output:\n\n{response}\n\n")
425+
print(f"Metadata: {metadata}")
426+
427+
parsed_response = extract_and_parse_response(response, has_thought=False)
428+
capability_areas = parsed_response["parsed_response"]
429+
430+
return {
431+
"capability_areas": capability_areas,
432+
"metadata": {
433+
"model": scientist_llm.get_model_name(),
434+
"api_metadata": metadata,
435+
},
436+
}
437+
438+
369439
def generate_capabilities(
370440
domain: str,
371441
num_capabilities: int,
372442
num_capabilities_per_run: int,
373443
scientist_llm: Model,
374444
num_seed_capabilities: int,
375445
scientist_llm_gen_cfg: Dict[str, Any],
446+
method: str = "flat",
376447
include_seed_capability_names: Optional[List[str]] = None,
377448
exclude_seed_capability_names: Optional[List[str]] = None,
378449
**kwargs: Any,
@@ -389,6 +460,8 @@ def generate_capabilities(
389460
num_seed_capabilities (int): The number of seed capabilities to use.
390461
scientist_llm_gen_cfg (Dict[str, Any]): The generation configuration
391462
for the scientist LLM.
463+
method (str): The method to use for generating capabilities.
464+
Choose from "flat" or "hierarchical".
392465
include_seed_capability_names (List[str] | None): A list of seed capability
393466
names to include in the generation process.
394467
exclude_seed_capability_names (List[str] | None): A list of seed capability
@@ -398,7 +471,6 @@ def generate_capabilities(
398471
-------
399472
List[Capability]: The generated capabilities.
400473
"""
401-
num_runs = int(np.ceil(num_capabilities / num_capabilities_per_run))
402474
gen_capabilities = []
403475
run_metadata = []
404476

@@ -413,42 +485,98 @@ def generate_capabilities(
413485
constants.BASE_ARTIFACTS_DIR, "capabilities", domain
414486
)
415487

416-
# Fetch previously generated capabilities, if any
417-
prev_capabilities = _get_previous_capabilities(capability_dir=base_capability_dir)
418-
419-
# Add all seed capabilities to the list of prev_capabilities
420-
seed_capability_dir = os.path.join(
421-
constants.BASE_ARTIFACTS_DIR, "seed_capabilities", domain
422-
)
423-
prev_capabilities.extend(
424-
_sample_seed_capabilities(
425-
seed_capability_dir=seed_capability_dir,
426-
num_seed_capabilities=-1,
488+
if method == "hierarchical":
489+
assert "num_capability_areas" in kwargs, (
490+
"`num_capability_areas` should be specified for hierarchical generation."
427491
)
428-
)
492+
num_capability_areas = kwargs["num_capability_areas"]
493+
assert num_capabilities >= num_capability_areas, (
494+
"Number of capabilities should be greater than or equal to the number of capability areas, "
495+
+ "so that each area can have at least one capability."
496+
)
497+
# Uniformly distribute num_capabilities across num_capability_areas
498+
num_capabilities_per_area = [
499+
num_capabilities // num_capability_areas
500+
] * num_capability_areas
501+
for i in range(num_capabilities % num_capability_areas):
502+
num_capabilities_per_area[i] += 1
503+
num_runs = [
504+
int(np.ceil(num / num_capabilities_per_run))
505+
for num in num_capabilities_per_area
506+
]
429507

430-
for run_id in range(num_runs):
431-
print("Run ID:", run_id)
432-
# Generate capabilities using the scientist LLM
433-
response = generate_capabilities_using_llm(
508+
# Generate capability areas for the specified domain
509+
response = generate_capability_areas(
434510
domain=domain,
435-
num_capabilities=num_capabilities_per_run,
511+
num_areas=kwargs["num_capability_areas"],
512+
num_capabilities_per_area=num_capabilities_per_area[0],
436513
scientist_llm=scientist_llm,
437-
sys_prompt=CAPABILITY_GENERATION_SYSTEM_PROMPT,
438-
user_prompt=CAPABILITY_GENERATION_USER_PROMPT,
439-
num_seed_capabilities=num_seed_capabilities,
440-
seed_capability_dir=seed_capability_dir,
441-
prev_capabilities=prev_capabilities,
514+
user_prompt=HIERARCHICAL_CAPABILITY_AREAS_GENERATION_USER_PROMPT,
442515
scientist_llm_gen_cfg=scientist_llm_gen_cfg,
443-
base_capability_dir=base_capability_dir,
444-
include_seed_capability_names=include_seed_capability_names,
445-
exclude_seed_capability_names=exclude_seed_capability_names,
446-
**kwargs,
447516
)
448-
gen_capabilities.extend(response["capabilities"])
449-
run_metadata.append(response["metadata"])
517+
capability_areas = response["capability_areas"]
518+
else:
519+
num_capabilities_per_area = [num_capabilities]
520+
num_runs = [int(np.ceil(num_capabilities / num_capabilities_per_run))]
521+
# No capability areas for flat generation, use the domain as the area
522+
capability_areas = [domain]
523+
524+
for idx, capability_area in enumerate(capability_areas):
525+
if method == "hierarchical":
526+
print(f"Generating capabilities for area: {capability_area}")
527+
# Fetch previously generated capabilities, if any
528+
prev_capabilities = _get_previous_capabilities(
529+
capability_dir=base_capability_dir, capability_area=capability_area
530+
)
531+
user_prompt = HIERARCHICAL_CAPABILITY_GENERATION_USER_PROMPT.format(
532+
capability_area=capability_area,
533+
)
534+
else:
535+
prev_capabilities = _get_previous_capabilities(
536+
capability_dir=base_capability_dir
537+
)
538+
user_prompt = CAPABILITY_GENERATION_USER_PROMPT
539+
540+
# Add all seed capabilities to the list of prev_capabilities
541+
seed_capability_dir = os.path.join(
542+
constants.BASE_ARTIFACTS_DIR, "seed_capabilities", domain
543+
)
544+
prev_capabilities.extend(
545+
_sample_seed_capabilities(
546+
seed_capability_dir=seed_capability_dir,
547+
num_seed_capabilities=-1,
548+
)
549+
)
550+
551+
num_capabilities_left = num_capabilities_per_area[idx]
552+
for run_id in range(num_runs[idx]):
553+
print("Run ID:", run_id)
554+
# Generate capabilities using the scientist LLM
555+
556+
response = generate_capabilities_using_llm(
557+
domain=domain,
558+
num_capabilities=min(
559+
num_capabilities_per_run,
560+
num_capabilities_left,
561+
),
562+
scientist_llm=scientist_llm,
563+
sys_prompt=CAPABILITY_GENERATION_SYSTEM_PROMPT,
564+
user_prompt=user_prompt,
565+
num_seed_capabilities=num_seed_capabilities,
566+
seed_capability_dir=seed_capability_dir,
567+
prev_capabilities=prev_capabilities,
568+
scientist_llm_gen_cfg=scientist_llm_gen_cfg,
569+
base_capability_dir=base_capability_dir,
570+
include_seed_capability_names=include_seed_capability_names,
571+
exclude_seed_capability_names=exclude_seed_capability_names,
572+
capability_area=capability_area if method == "hierarchical" else None,
573+
**kwargs,
574+
)
575+
gen_capabilities.extend(response["capabilities"])
576+
num_capabilities_left -= len(response["capabilities"])
577+
run_metadata.append(response["metadata"])
450578

451-
# Update the list of previously generated capabilities
452-
prev_capabilities.extend(response["capabilities"])
579+
# Update the list of previously generated capabilities
580+
prev_capabilities.extend(response["capabilities"])
453581

454582
return gen_capabilities

0 commit comments

Comments
 (0)