1717from src .utils import constants
1818from src .utils .capability_utils import extract_and_parse_response
1919from src .utils .prompts import (
20+ CAPABILITY_AREAS_GENERATION_RESPONSE_JSON_FORMAT ,
2021 CAPABILITY_GENERATION_SYSTEM_PROMPT ,
2122 CAPABILITY_GENERATION_USER_PROMPT ,
23+ HIERARCHICAL_CAPABILITY_AREAS_GENERATION_USER_PROMPT ,
24+ HIERARCHICAL_CAPABILITY_GENERATION_USER_PROMPT ,
2225)
2326
2427
@@ -104,6 +107,7 @@ def _sample_seed_capabilities(
104107
105108def _get_previous_capabilities (
106109 capability_dir : str ,
110+ capability_area : str | None = None ,
107111) -> List [Capability ]:
108112 """
109113 Get the previously generated capabilities for the specified domain.
@@ -121,6 +125,8 @@ def _get_previous_capabilities(
121125 prev_capabilities = []
122126 for capability_path in os .listdir (capability_dir ):
123127 capability = Capability (os .path .join (capability_dir , capability_path ))
128+ if capability_area is not None and capability .area != capability_area :
129+ continue
124130 prev_capabilities .append (capability )
125131 return prev_capabilities
126132
@@ -157,6 +163,7 @@ def generate_capabilities_using_llm(
157163 base_capability_dir : str ,
158164 include_seed_capability_names : Optional [List [str ]] = None ,
159165 exclude_seed_capability_names : Optional [List [str ]] = None ,
166+ capability_area : str | None = None ,
160167 ** kwargs : Any ,
161168) -> Dict [str , Any ]:
162169 """
@@ -185,6 +192,7 @@ def generate_capabilities_using_llm(
185192 names to include in the generation process.
186193 exclude_seed_capability_names (List[str] | None): A list of seed capability
187194 names to exclude from the generation process.
195+ capability_area (str | None): The capability area for the generation
188196 **kwargs (Any): Additional keyword arguments.
189197
190198 Returns
@@ -226,6 +234,10 @@ def generate_capabilities_using_llm(
226234
227235 parsed_response = extract_and_parse_response (response )
228236 gen_capabilities = parsed_response ["parsed_response" ]
237+ if capability_area is not None :
238+ # Add the capability area to the generated capabilities
239+ for capability in gen_capabilities :
240+ capability ["area" ] = capability_area
229241 gen_capabilities = [
230242 Capability .from_dict (capability_dict = capability , base_dir = base_capability_dir )
231243 for capability in gen_capabilities
@@ -366,13 +378,72 @@ def filter_capabilities(
366378 return [capabilities [i ] for i in remaining_indices ]
367379
368380
381+ def generate_capability_areas (
382+ domain : str ,
383+ num_areas : int ,
384+ num_capabilities_per_area : int ,
385+ scientist_llm : Model ,
386+ user_prompt : str ,
387+ scientist_llm_gen_cfg : Dict [str , Any ],
388+ sys_prompt : str | None = None ,
389+ ) -> Dict [str , Any ]:
390+ """
391+ Generate capability areas for the specified domain.
392+
393+ Args
394+ ----
395+ domain (str): The domain name.
396+ num_areas (int): The number of capability areas to generate.
397+ num_capabilities_per_area (int): The number of capabilities per area.
398+ scientist_llm (Model): The scientist LLM model.
399+ user_prompt (str): The user prompt for generating capability areas.
400+ scientist_llm_gen_cfg (Dict[str, Any]): The generation configuration
401+ for the scientist LLM.
402+ sys_prompt (str | None): The system prompt for the scientist LLM.
403+
404+ Returns
405+ -------
406+ Dict[str, Any]: A dictionary containing the generated capability areas
407+ and metadata about the generation process.
408+ """
409+ # Generate output using the model with specified generation arguments
410+ user_prompt = user_prompt .format (
411+ num_areas = num_areas ,
412+ num_capabilities_per_area = num_capabilities_per_area ,
413+ domain = domain ,
414+ response_json_format = CAPABILITY_AREAS_GENERATION_RESPONSE_JSON_FORMAT ,
415+ )
416+ response , metadata = scientist_llm .generate (
417+ sys_prompt = sys_prompt if sys_prompt else "" ,
418+ user_prompt = user_prompt ,
419+ generation_config = scientist_llm_gen_cfg ,
420+ )
421+
422+ # Print the output
423+ print (f"Model: { scientist_llm .get_model_name ()} " )
424+ print (f"Output:\n \n { response } \n \n " )
425+ print (f"Metadata: { metadata } " )
426+
427+ parsed_response = extract_and_parse_response (response , has_thought = False )
428+ capability_areas = parsed_response ["parsed_response" ]
429+
430+ return {
431+ "capability_areas" : capability_areas ,
432+ "metadata" : {
433+ "model" : scientist_llm .get_model_name (),
434+ "api_metadata" : metadata ,
435+ },
436+ }
437+
438+
369439def generate_capabilities (
370440 domain : str ,
371441 num_capabilities : int ,
372442 num_capabilities_per_run : int ,
373443 scientist_llm : Model ,
374444 num_seed_capabilities : int ,
375445 scientist_llm_gen_cfg : Dict [str , Any ],
446+ method : str = "flat" ,
376447 include_seed_capability_names : Optional [List [str ]] = None ,
377448 exclude_seed_capability_names : Optional [List [str ]] = None ,
378449 ** kwargs : Any ,
@@ -389,6 +460,8 @@ def generate_capabilities(
389460 num_seed_capabilities (int): The number of seed capabilities to use.
390461 scientist_llm_gen_cfg (Dict[str, Any]): The generation configuration
391462 for the scientist LLM.
463+ method (str): The method to use for generating capabilities.
464+ Choose from "flat" or "hierarchical".
392465 include_seed_capability_names (List[str] | None): A list of seed capability
393466 names to include in the generation process.
394467 exclude_seed_capability_names (List[str] | None): A list of seed capability
@@ -398,7 +471,6 @@ def generate_capabilities(
398471 -------
399472 List[Capability]: The generated capabilities.
400473 """
401- num_runs = int (np .ceil (num_capabilities / num_capabilities_per_run ))
402474 gen_capabilities = []
403475 run_metadata = []
404476
@@ -413,42 +485,98 @@ def generate_capabilities(
413485 constants .BASE_ARTIFACTS_DIR , "capabilities" , domain
414486 )
415487
416- # Fetch previously generated capabilities, if any
417- prev_capabilities = _get_previous_capabilities (capability_dir = base_capability_dir )
418-
419- # Add all seed capabilities to the list of prev_capabilities
420- seed_capability_dir = os .path .join (
421- constants .BASE_ARTIFACTS_DIR , "seed_capabilities" , domain
422- )
423- prev_capabilities .extend (
424- _sample_seed_capabilities (
425- seed_capability_dir = seed_capability_dir ,
426- num_seed_capabilities = - 1 ,
488+ if method == "hierarchical" :
489+ assert "num_capability_areas" in kwargs , (
490+ "`num_capability_areas` should be specified for hierarchical generation."
427491 )
428- )
492+ num_capability_areas = kwargs ["num_capability_areas" ]
493+ assert num_capabilities >= num_capability_areas , (
494+ "Number of capabilities should be greater than or equal to the number of capability areas, "
495+ + "so that each area can have at least one capability."
496+ )
497+ # Uniformly distribute num_capabilities across num_capability_areas
498+ num_capabilities_per_area = [
499+ num_capabilities // num_capability_areas
500+ ] * num_capability_areas
501+ for i in range (num_capabilities % num_capability_areas ):
502+ num_capabilities_per_area [i ] += 1
503+ num_runs = [
504+ int (np .ceil (num / num_capabilities_per_run ))
505+ for num in num_capabilities_per_area
506+ ]
429507
430- for run_id in range (num_runs ):
431- print ("Run ID:" , run_id )
432- # Generate capabilities using the scientist LLM
433- response = generate_capabilities_using_llm (
508+ # Generate capability areas for the specified domain
509+ response = generate_capability_areas (
434510 domain = domain ,
435- num_capabilities = num_capabilities_per_run ,
511+ num_areas = kwargs ["num_capability_areas" ],
512+ num_capabilities_per_area = num_capabilities_per_area [0 ],
436513 scientist_llm = scientist_llm ,
437- sys_prompt = CAPABILITY_GENERATION_SYSTEM_PROMPT ,
438- user_prompt = CAPABILITY_GENERATION_USER_PROMPT ,
439- num_seed_capabilities = num_seed_capabilities ,
440- seed_capability_dir = seed_capability_dir ,
441- prev_capabilities = prev_capabilities ,
514+ user_prompt = HIERARCHICAL_CAPABILITY_AREAS_GENERATION_USER_PROMPT ,
442515 scientist_llm_gen_cfg = scientist_llm_gen_cfg ,
443- base_capability_dir = base_capability_dir ,
444- include_seed_capability_names = include_seed_capability_names ,
445- exclude_seed_capability_names = exclude_seed_capability_names ,
446- ** kwargs ,
447516 )
448- gen_capabilities .extend (response ["capabilities" ])
449- run_metadata .append (response ["metadata" ])
517+ capability_areas = response ["capability_areas" ]
518+ else :
519+ num_capabilities_per_area = [num_capabilities ]
520+ num_runs = [int (np .ceil (num_capabilities / num_capabilities_per_run ))]
521+ # No capability areas for flat generation, use the domain as the area
522+ capability_areas = [domain ]
523+
524+ for idx , capability_area in enumerate (capability_areas ):
525+ if method == "hierarchical" :
526+ print (f"Generating capabilities for area: { capability_area } " )
527+ # Fetch previously generated capabilities, if any
528+ prev_capabilities = _get_previous_capabilities (
529+ capability_dir = base_capability_dir , capability_area = capability_area
530+ )
531+ user_prompt = HIERARCHICAL_CAPABILITY_GENERATION_USER_PROMPT .format (
532+ capability_area = capability_area ,
533+ )
534+ else :
535+ prev_capabilities = _get_previous_capabilities (
536+ capability_dir = base_capability_dir
537+ )
538+ user_prompt = CAPABILITY_GENERATION_USER_PROMPT
539+
540+ # Add all seed capabilities to the list of prev_capabilities
541+ seed_capability_dir = os .path .join (
542+ constants .BASE_ARTIFACTS_DIR , "seed_capabilities" , domain
543+ )
544+ prev_capabilities .extend (
545+ _sample_seed_capabilities (
546+ seed_capability_dir = seed_capability_dir ,
547+ num_seed_capabilities = - 1 ,
548+ )
549+ )
550+
551+ num_capabilities_left = num_capabilities_per_area [idx ]
552+ for run_id in range (num_runs [idx ]):
553+ print ("Run ID:" , run_id )
554+ # Generate capabilities using the scientist LLM
555+
556+ response = generate_capabilities_using_llm (
557+ domain = domain ,
558+ num_capabilities = min (
559+ num_capabilities_per_run ,
560+ num_capabilities_left ,
561+ ),
562+ scientist_llm = scientist_llm ,
563+ sys_prompt = CAPABILITY_GENERATION_SYSTEM_PROMPT ,
564+ user_prompt = user_prompt ,
565+ num_seed_capabilities = num_seed_capabilities ,
566+ seed_capability_dir = seed_capability_dir ,
567+ prev_capabilities = prev_capabilities ,
568+ scientist_llm_gen_cfg = scientist_llm_gen_cfg ,
569+ base_capability_dir = base_capability_dir ,
570+ include_seed_capability_names = include_seed_capability_names ,
571+ exclude_seed_capability_names = exclude_seed_capability_names ,
572+ capability_area = capability_area if method == "hierarchical" else None ,
573+ ** kwargs ,
574+ )
575+ gen_capabilities .extend (response ["capabilities" ])
576+ num_capabilities_left -= len (response ["capabilities" ])
577+ run_metadata .append (response ["metadata" ])
450578
451- # Update the list of previously generated capabilities
452- prev_capabilities .extend (response ["capabilities" ])
579+ # Update the list of previously generated capabilities
580+ prev_capabilities .extend (response ["capabilities" ])
453581
454582 return gen_capabilities
0 commit comments