1818def _sample_seed_capabilities (
1919 seed_capability_dir : str ,
2020 num_seed_capabilities : int = - 1 ,
21- include_capabilities : List [str ] | None = None ,
21+ include_capability_names : List [str ] | None = None ,
2222 random_seed : int = 42 ,
2323) -> List [Capability ]:
2424 """
@@ -31,7 +31,8 @@ def _sample_seed_capabilities(
3131 ----
3232 seed_capability_dir (str): The directory containing the seed capabilities.
3333 num_seed_capabilities (int): The number of seed capabilities to sample.
34- include_capabilities (List[str] | None): A list of capability names to include.
34+ include_capability_names (List[str] | None): A list of
35+ capability names to include.
3536 random_seed (int): The seed for the random number generator.
3637
3738 Returns
@@ -46,21 +47,21 @@ def _sample_seed_capabilities(
4647 # Select all capabilities if num_seed_capabilities is -1
4748 if num_seed_capabilities == - 1 :
4849 num_seed_capabilities = len (all_seed_capability_paths )
49- include_capabilities = None
50+ include_capability_names = None
5051
5152 # Force include some capabilities
52- if include_capabilities is not None :
53- assert num_seed_capabilities >= len (include_capabilities ), (
53+ if include_capability_names is not None :
54+ assert num_seed_capabilities >= len (include_capability_names ), (
5455 "Number of seed capabilities is less than the number of capabilities to include."
5556 )
56- for capability_name in include_capabilities :
57+ for capability_name in include_capability_names :
5758 assert os .path .exists (os .path .join (seed_capability_dir , capability_name )), (
5859 f"{ capability_name } does not exist in { seed_capability_dir } ."
5960 )
6061 capability = Capability (os .path .join (seed_capability_dir , capability_name ))
6162 sampled_seed_capabilities .append (capability )
6263 all_seed_capability_paths .remove (capability_name )
63- num_seed_capabilities -= len (include_capabilities )
64+ num_seed_capabilities -= len (include_capability_names )
6465
6566 # TODO: Enhance the selection criterion
6667 for capability_path in random .sample (
@@ -121,10 +122,10 @@ def generate_capabilities_using_llm(
121122 sys_prompt : str ,
122123 user_prompt : str ,
123124 num_seed_capabilities : int ,
124- prev_capabilities : List [str ],
125+ prev_capabilities : List [Capability ],
125126 scientist_llm_gen_cfg : Dict [str , Any ],
126127 base_capability_dir : str ,
127- include_seed_capabilities : Optional [List [str ]] = None ,
128+ include_seed_capability_names : Optional [List [str ]] = None ,
128129 ** kwargs : Any ,
129130) -> Dict [str , Any ]:
130131 """
@@ -142,25 +143,27 @@ def generate_capabilities_using_llm(
142143 sys_prompt (str): The system prompt.
143144 user_prompt (str): The user prompt.
144145 num_seed_capabilities (int): The number of seed capabilities to use.
145- prev_capabilities (List[str ]): The list of previously
146- generated capability names .
146+ prev_capabilities (List[Capability ]): The list of previously
147+ generated capabilities .
147148 scientist_llm_gen_cfg (Dict[str, Any]): The generation configuration
148149 for the scientist LLM.
149150 base_capability_dir (str): The base directory to store
150151 the generated capabilities for the specified domain.
151- include_seed_capabilities (List[str] | None): A list of seed capability
152+ include_seed_capability_names (List[str] | None): A list of seed capability
152153 names to include in the generation process.
154+ **kwargs (Any): Additional keyword arguments.
153155
154156 Returns
155157 -------
156- List[str]: The generated capability names.
158+ Dict[str, Any]: A dictionary containing the generated capabilities
159+ and metadata about the generation process.
157160 """
158161 # Select seed capabilities
159162 seed_capability_dir = os .path .join (BASE_ARTIFACTS_DIR , "seed_capabilities" , domain )
160163 seed_capabilities = _sample_seed_capabilities (
161164 seed_capability_dir = seed_capability_dir ,
162165 num_seed_capabilities = num_seed_capabilities ,
163- include_capabilities = include_seed_capabilities ,
166+ include_capability_names = include_seed_capability_names ,
164167 )
165168 # Get capability JSON strings (without scores)
166169 seed_capabilities_repr = [
@@ -170,7 +173,7 @@ def generate_capabilities_using_llm(
170173 # LLM input
171174 user_prompt = user_prompt .format (
172175 seed_capabilities = "\n " .join (seed_capabilities_repr ),
173- prev_capabilities = "\n " .join (prev_capabilities ),
176+ prev_capabilities = "\n " .join ([ elm . name for elm in prev_capabilities ] ),
174177 domain = domain ,
175178 num_gen_capabilities = num_capabilities ,
176179 )
@@ -193,10 +196,9 @@ def generate_capabilities_using_llm(
193196 Capability .from_dict (capability_dict = capability , base_dir = base_capability_dir )
194197 for capability in gen_capabilities
195198 ]
196- gen_capabilities_names = [elm .name for elm in gen_capabilities ]
197199
198200 return {
199- "capabilities" : gen_capabilities_names ,
201+ "capabilities" : gen_capabilities ,
200202 "metadata" : {
201203 "model" : scientist_llm .get_model_name (),
202204 "thought" : parsed_response ["thought" ],
@@ -206,20 +208,20 @@ def generate_capabilities_using_llm(
206208
207209
208210def filter_capabilities (
209- capabilities : List [str ],
210- ) -> List [str ]:
211+ capabilities : List [Capability ],
212+ ) -> List [Capability ]:
211213 """
212214 Filter capabilities based on multiple criterion.
213215
214216 Remove repeated, irrelevant, and ill-formed capabilities.
215217
216218 Args
217219 ----
218- capabilities (List[str ]): The list of capabilities.
220+ capabilities (List[Capability ]): The list of capabilities.
219221
220222 Returns
221223 -------
222- List[str ]: The filtered capability names .
224+ List[Capability ]: The list of remaining capabilities .
223225 """
224226 # TODO: Implement capability filtering
225227 return capabilities
@@ -232,9 +234,9 @@ def generate_capabilities(
232234 scientist_llm : Model ,
233235 num_seed_capabilities : int ,
234236 scientist_llm_gen_cfg : Dict [str , Any ],
235- include_seed_capabilities : Optional [List [str ]] = None ,
237+ include_seed_capability_names : Optional [List [str ]] = None ,
236238 ** kwargs : Any ,
237- ) -> List [str ]:
239+ ) -> List [Capability ]:
238240 """
239241 Generate initial capabilities for the specified domain.
240242
@@ -247,12 +249,12 @@ def generate_capabilities(
247249 num_seed_capabilities (int): The number of seed capabilities to use.
248250 scientist_llm_gen_cfg (Dict[str, Any]): The generation configuration
249251 for the scientist LLM.
250- include_seed_capabilities (List[str] | None): A list of seed capability
252+ include_seed_capability_names (List[str] | None): A list of seed capability
251253 names to include in the generation process.
252254
253255 Returns
254256 -------
255- List[str ]: The generated capability names .
257+ List[Capability ]: The generated capabilities .
256258 """
257259 num_runs = int (np .ceil (num_capabilities / num_capabilities_per_run ))
258260 gen_capabilities = []
@@ -268,10 +270,7 @@ def generate_capabilities(
268270 base_capability_dir = os .path .join (BASE_ARTIFACTS_DIR , "capabilities" , domain )
269271
270272 # Fetch previously generated capabilities, if any
271- prev_capabilities = [
272- elm .name
273- for elm in _get_previous_capabilities (capability_dir = base_capability_dir )
274- ]
273+ prev_capabilities = _get_previous_capabilities (capability_dir = base_capability_dir )
275274
276275 for run_id in range (num_runs ):
277276 print ("Run ID:" , run_id )
@@ -286,7 +285,7 @@ def generate_capabilities(
286285 prev_capabilities = prev_capabilities ,
287286 scientist_llm_gen_cfg = scientist_llm_gen_cfg ,
288287 base_capability_dir = base_capability_dir ,
289- include_seed_capabilities = include_seed_capabilities ,
288+ include_seed_capability_names = include_seed_capability_names ,
290289 ** kwargs ,
291290 )
292291 gen_capabilities .extend (response ["capabilities" ])
0 commit comments