44import logging
55import shutil
66from pathlib import Path
7+ from typing import Any
78
89import datasets
910import yaml
1213from fhda .data_analysis_env import DataAnalysisEnv
1314from fhda .utils import NBLanguage , collect_notebook_stats , load_mcq
1415from huggingface_hub import hf_hub_download
15- from ldp .agent import AgentConfig
16+ from ldp .agent import Agent , AgentConfig
1617from ldp .alg .rollout import RolloutManager
17- from ldp .data_structures import Trajectory
18+ from ldp .data_structures import Trajectory , Transition
1819
1920logger = logging .getLogger (__name__ )
2021
2122
22- class TraceGenerator :
23- def __init__ (self ):
23+ class TrajectoryGenerator :
24+ """
25+ Generator for creating and storing agent trajectories on data analysis tasks.
26+
27+ This class handles the full pipeline of loading benchmark capsules, setting up
28+ environments, running agents through these environments, and storing the resulting
29+ trajectories.
30+ """
31+
32+ def __init__ (self ) -> None :
33+ """Initialize the TrajectoryGenerator with config and create necessary directories."""
2434 self .config = self .load_config ()
2535 # Create directories
2636 self .config ["local_workspace_dir" ].mkdir (parents = True , exist_ok = True )
27- self .config ["local_traces_dir " ].mkdir (parents = True , exist_ok = True )
37+ self .config ["local_trajectories_dir " ].mkdir (parents = True , exist_ok = True )
2838 self .config ["local_data_folder" ].mkdir (parents = True , exist_ok = True )
2939
3040 # TODO: Move to utils and use a yaml loader package
31- def load_config (self ):
41+ def load_config (self ) -> dict [str , Any ]:
42+ """
43+ Load and process configuration from the config.yaml file.
44+
45+ Returns:
46+ Dict[str, Any]: Processed configuration dictionary
47+ """
3248 config_path = Path (__file__ ).parent / "config.yaml"
3349 with open (config_path , encoding = "utf-8" ) as f :
3450 config = yaml .safe_load (f )
@@ -60,12 +76,22 @@ def load_config(self):
6076 "base_prompt" : base_prompt ,
6177 "eval_mode" : EvalAnswerMode [config ["capsule" ]["eval_mode" ]],
6278 "local_workspace_dir" : Path (config ["paths" ]["workspace_dir" ]).absolute (),
63- "local_traces_dir" : Path (config ["paths" ]["traces_dir" ]).absolute (),
79+ "local_trajectories_dir" : Path (
80+ config ["paths" ]["trajectories_dir" ]
81+ ).absolute (),
6482 "local_data_folder" : Path (config ["paths" ]["data_folder" ]).absolute (),
6583 "hf_repo_id" : config ["paths" ]["hf_repo_id" ],
84+ "rollout_type" : config ["rollout" ].get ("type" , "vanilla" ),
85+ "avoid_images" : config ["capsule" ]["avoid_images" ],
6686 }
6787
68- async def process_capsule (self , capsule ):
88+ async def process_capsule (self , capsule : dict [str , Any ]) -> None :
89+ """
90+ Process a single benchmark capsule by downloading and extracting necessary files.
91+
92+ Args:
93+ capsule: Dictionary containing capsule information
94+ """
6995 zip_filename = capsule ["data_folder" ]
7096 extract_dir = self .config ["local_data_folder" ] / zip_filename .replace (
7197 ".zip" , ""
@@ -92,7 +118,13 @@ async def process_capsule(self, capsule):
92118 await asyncio .to_thread (self ._extract_and_process_files , zip_path , extract_dir )
93119 capsule ["local_data_folder" ] = extract_dir
94120
95- async def load_bixbench (self ) -> datasets .Dataset :
121+ async def load_bixbench (self ) -> list [dict [str , Any ]]:
122+ """
123+ Load BixBench dataset and process all capsules.
124+
125+ Returns:
126+ List[Dict[str, Any]]: List of processed benchmark capsules
127+ """
96128 bixbench = datasets .load_dataset (
97129 self .config ["hf_repo_id" ], split = "train"
98130 ).to_list ()
@@ -103,8 +135,14 @@ async def load_bixbench(self) -> datasets.Dataset:
103135
104136 return bixbench
105137
106- def _extract_and_process_files (self , zip_path : Path , extract_dir : Path ):
107- """Helper method to extract and process zip files."""
138+ def _extract_and_process_files (self , zip_path : Path , extract_dir : Path ) -> None :
139+ """
140+ Extract and process zip files for a capsule.
141+
142+ Args:
143+ zip_path: Path to the zip file
144+ extract_dir: Directory to extract files to
145+ """
108146 # Extract the zip file
109147 shutil .unpack_archive (zip_path , extract_dir )
110148
@@ -140,6 +178,13 @@ def _extract_and_process_files(self, zip_path: Path, extract_dir: Path):
140178 async def store_trajectory (
141179 self , trajectory : Trajectory , env : DataAnalysisEnv
142180 ) -> None :
181+ """
182+ Store trajectory and environment information to disk.
183+
184+ Args:
185+ trajectory: The trajectory to store
186+ env: The environment that generated the trajectory
187+ """
143188 extract = {
144189 "problem_id" : env .problem_id ,
145190 "agent_answer" : env .state .answer ,
@@ -160,7 +205,7 @@ async def store_trajectory(
160205 }
161206
162207 # Download run metadata
163- with (self .config ["local_traces_dir " ] / f"{ env .problem_id } .json" ).open (
208+ with (self .config ["local_trajectories_dir " ] / f"{ env .problem_id } .json" ).open (
164209 "w"
165210 ) as f :
166211 json .dump (
@@ -170,10 +215,19 @@ async def store_trajectory(
170215 )
171216 # Download run trajectory
172217 await trajectory .to_jsonl (
173- self .config ["local_traces_dir " ] / f"{ env .problem_id } .jsonl"
218+ self .config ["local_trajectories_dir " ] / f"{ env .problem_id } .jsonl"
174219 )
175220
176- def environment_factory (self , capsule : dict ) -> DataAnalysisEnv :
221+ def environment_factory (self , capsule : dict [str , Any ]) -> DataAnalysisEnv :
222+ """
223+ Create a DataAnalysisEnv instance from a capsule.
224+
225+ Args:
226+ capsule: Dictionary containing capsule information
227+
228+ Returns:
229+ DataAnalysisEnv: Initialized environment
230+ """
177231 raw_questions = ast .literal_eval (capsule ["questions" ])
178232 processed_questions = [
179233 load_mcq (i , open_question = True , question_id = i ["id" ]) for i in raw_questions
@@ -194,7 +248,7 @@ def environment_factory(self, capsule: dict) -> DataAnalysisEnv:
194248 if item .is_file ():
195249 shutil .copy2 (item , work_dir )
196250 elif item .is_dir ():
197- shutil .copytree (item , work_dir / item .name )
251+ shutil .copytree (item , work_dir / item .name , dirs_exist_ok = True )
198252 nb_path = work_dir / self .config ["notebook_name" ]
199253
200254 # Add some extra metadata from config
@@ -217,27 +271,108 @@ def environment_factory(self, capsule: dict) -> DataAnalysisEnv:
217271
218272 return DataAnalysisEnv (** env_args )
219273
274+ async def custom_rollout (
275+ self , agent : Agent , environment : DataAnalysisEnv
276+ ) -> Trajectory :
277+ """
278+ Custom implementation of rollout logic.
279+
280+ Args:
281+ agent: The agent to use for rollout
282+ environment: The environment to run the agent in
283+
284+ Returns:
285+ Trajectory: The generated trajectory
286+
287+ Raises:
288+ NotImplementedError: This method needs to be implemented by subclasses
289+ """
290+ raise NotImplementedError ("Custom rollout not implemented" )
291+
292+ async def vanilla_rollout (
293+ self , agent : Agent , environment : DataAnalysisEnv
294+ ) -> tuple [Trajectory , DataAnalysisEnv ]:
295+ """
296+ Standard implementation of rollout logic.
297+
298+ Args:
299+ agent: The agent to use for rollout
300+ environment: The environment to run the agent in
301+
302+ Returns:
303+ Tuple[Trajectory, DataAnalysisEnv]: The generated trajectory and updated environment
304+ """
305+ obs , tools = await environment .reset ()
306+ agent_state = await agent .init_state (tools )
307+ trajectory = Trajectory ()
308+
309+ for timestep in range (self .config ["max_rollout_steps" ]):
310+ action , next_agent_state , value = await agent .get_asv (agent_state , obs )
311+ next_obs , reward , done , trunc = await environment .step (action .value )
312+ trajectory .steps .append (
313+ Transition (
314+ timestep = timestep ,
315+ agent_state = agent_state ,
316+ next_agent_state = next_agent_state ,
317+ observation = obs ,
318+ next_observation = next_obs ,
319+ action = action ,
320+ reward = reward ,
321+ done = done ,
322+ truncated = trunc ,
323+ value = value ,
324+ )
325+ )
326+ if done or trunc :
327+ break
328+
329+ agent_state = next_agent_state
330+ obs = next_obs
331+
332+ return trajectory , environment
333+
334+ async def batch_rollout (
335+ self , list_of_environments : list [DataAnalysisEnv ]
336+ ) -> list [Trajectory | tuple [Trajectory , DataAnalysisEnv ]]:
337+ """
338+ Run rollouts for a batch of environments.
339+
340+ Args:
341+ list_of_environments: List of environments to run rollouts in
342+
343+ Returns:
344+ List[Union[Trajectory, Tuple[Trajectory, DataAnalysisEnv]]]: List of trajectories or
345+ trajectory-environment pairs depending on rollout type
346+ """
347+ if self .config ["rollout_type" ] == "aviary" :
348+ agent = self .config ["agent_config" ].construct_agent ()
349+ rollout = RolloutManager (agent = agent )
350+ return await rollout .sample_trajectories (
351+ environments = list_of_environments ,
352+ max_steps = self .config ["max_rollout_steps" ],
353+ )
354+
355+ agent = self .config ["agent_config" ].construct_agent ()
356+ rollout_manager = getattr (self , f"{ self .config ['rollout_type' ]} _rollout" )
357+
358+ return await asyncio .gather (* [
359+ rollout_manager (agent , environment ) for environment in list_of_environments
360+ ])
361+
220362 async def run (self ) -> None :
363+ """Run the full trajectory generation pipeline."""
221364 bixbench = await self .load_bixbench ()
222- # Construct agent and rollout manager
223- agent = self .config ["agent_config" ].construct_agent ()
224- rollout = RolloutManager (agent = agent )
225365
226366 # Process environments in batches
227367 for i in range (0 , len (bixbench ), self .config ["batch_size" ]):
228368 batch = bixbench [i : i + self .config ["batch_size" ]]
229369 environments = [self .environment_factory (capsule ) for capsule in batch ]
230-
231- # TODO: Create simple rollout manager that does not use LDP
232- trajectories = await rollout .sample_trajectories (
233- environments = environments , max_steps = self .config ["max_rollout_steps" ]
234- )
235-
370+ results = await self .batch_rollout (environments )
236371 # Store trajectories for each environment
237- for trajectory , env in zip ( trajectories , environments , strict = True ) :
372+ for trajectory , env in results :
238373 await self .store_trajectory (trajectory , env )
239374
240375
if __name__ == "__main__":
    # Script entry point: build the generator and run the async pipeline.
    asyncio.run(TrajectoryGenerator().run())
0 commit comments