Skip to content

Commit f65840e

Browse files
committed
Large refactor + added docstrings and typing throughout
1 parent 64249a9 commit f65840e

File tree

11 files changed

+343
-135
lines changed

11 files changed

+343
-135
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ ipython_config.py
1111
*.pyc
1212
__pycache__/
1313

14+
# Results
15+
bixbench_results/
1416

1517
.DS_Store
1618
# pyenv

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ This will:
8585
2. Preprocess each capsule in the dataset
8686
3. Generate and store trajectories including the final agent answer and Jupyter notebook in the directory specified in `config.yaml`
8787

88+
Trajectories are saved in the `bixbench_results/` directory as JSON files.
8889

8990
### Customization
9091

@@ -116,7 +117,7 @@ This script will:
116117
4. Compare model performance across different run groups defined in `config.py`
117118
5. Generate visualizations
118119

119-
Trajectories are saved in the `bixbench_results/` directory as json files.
120+
The script will save the evaluation dataframe as a CSV file, along with the generated plots, in the `bixbench_results/` directory.
120121

121122
## Zero-shot Evaluations & Grading
122123

bixbench/config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ agent:
77
rollout:
88
max_steps: 25
99
batch_size: 16
10+
rollout_type: "aviary"
1011

1112
notebook:
1213
name: "notebook.ipynb"
@@ -25,7 +26,7 @@ capsule:
2526

2627
paths:
2728
workspace_dir: "data/workspace"
28-
traces_dir: "data/traces"
29+
trajectories_dir: "data/trajectories"
2930
data_folder: "data/capsules"
3031
hf_repo_id: "futurehouse/bixbench-internal"
3132

Lines changed: 161 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import shutil
66
from pathlib import Path
7+
from typing import Any
78

89
import datasets
910
import yaml
@@ -12,23 +13,38 @@
1213
from fhda.data_analysis_env import DataAnalysisEnv
1314
from fhda.utils import NBLanguage, collect_notebook_stats, load_mcq
1415
from huggingface_hub import hf_hub_download
15-
from ldp.agent import AgentConfig
16+
from ldp.agent import Agent, AgentConfig
1617
from ldp.alg.rollout import RolloutManager
17-
from ldp.data_structures import Trajectory
18+
from ldp.data_structures import Trajectory, Transition
1819

1920
logger = logging.getLogger(__name__)
2021

2122

22-
class TraceGenerator:
23-
def __init__(self):
23+
class TrajectoryGenerator:
24+
"""
25+
Generator for creating and storing agent trajectories on data analysis tasks.
26+
27+
This class handles the full pipeline of loading benchmark capsules, setting up
28+
environments, running agents through these environments, and storing the resulting
29+
trajectories.
30+
"""
31+
32+
def __init__(self) -> None:
33+
"""Initialize the TrajectoryGenerator with config and create necessary directories."""
2434
self.config = self.load_config()
2535
# Create directories
2636
self.config["local_workspace_dir"].mkdir(parents=True, exist_ok=True)
27-
self.config["local_traces_dir"].mkdir(parents=True, exist_ok=True)
37+
self.config["local_trajectories_dir"].mkdir(parents=True, exist_ok=True)
2838
self.config["local_data_folder"].mkdir(parents=True, exist_ok=True)
2939

3040
# TODO: Move to utils and use a yaml loader package
31-
def load_config(self):
41+
def load_config(self) -> dict[str, Any]:
42+
"""
43+
Load and process configuration from the config.yaml file.
44+
45+
Returns:
46+
Dict[str, Any]: Processed configuration dictionary
47+
"""
3248
config_path = Path(__file__).parent / "config.yaml"
3349
with open(config_path, encoding="utf-8") as f:
3450
config = yaml.safe_load(f)
@@ -60,12 +76,22 @@ def load_config(self):
6076
"base_prompt": base_prompt,
6177
"eval_mode": EvalAnswerMode[config["capsule"]["eval_mode"]],
6278
"local_workspace_dir": Path(config["paths"]["workspace_dir"]).absolute(),
63-
"local_traces_dir": Path(config["paths"]["traces_dir"]).absolute(),
79+
"local_trajectories_dir": Path(
80+
config["paths"]["trajectories_dir"]
81+
).absolute(),
6482
"local_data_folder": Path(config["paths"]["data_folder"]).absolute(),
6583
"hf_repo_id": config["paths"]["hf_repo_id"],
84+
"rollout_type": config["rollout"].get("type", "vanilla"),
85+
"avoid_images": config["capsule"]["avoid_images"],
6686
}
6787

68-
async def process_capsule(self, capsule):
88+
async def process_capsule(self, capsule: dict[str, Any]) -> None:
89+
"""
90+
Process a single benchmark capsule by downloading and extracting necessary files.
91+
92+
Args:
93+
capsule: Dictionary containing capsule information
94+
"""
6995
zip_filename = capsule["data_folder"]
7096
extract_dir = self.config["local_data_folder"] / zip_filename.replace(
7197
".zip", ""
@@ -92,7 +118,13 @@ async def process_capsule(self, capsule):
92118
await asyncio.to_thread(self._extract_and_process_files, zip_path, extract_dir)
93119
capsule["local_data_folder"] = extract_dir
94120

95-
async def load_bixbench(self) -> datasets.Dataset:
121+
async def load_bixbench(self) -> list[dict[str, Any]]:
122+
"""
123+
Load BixBench dataset and process all capsules.
124+
125+
Returns:
126+
List[Dict[str, Any]]: List of processed benchmark capsules
127+
"""
96128
bixbench = datasets.load_dataset(
97129
self.config["hf_repo_id"], split="train"
98130
).to_list()
@@ -103,8 +135,14 @@ async def load_bixbench(self) -> datasets.Dataset:
103135

104136
return bixbench
105137

106-
def _extract_and_process_files(self, zip_path: Path, extract_dir: Path):
107-
"""Helper method to extract and process zip files."""
138+
def _extract_and_process_files(self, zip_path: Path, extract_dir: Path) -> None:
139+
"""
140+
Extract and process zip files for a capsule.
141+
142+
Args:
143+
zip_path: Path to the zip file
144+
extract_dir: Directory to extract files to
145+
"""
108146
# Extract the zip file
109147
shutil.unpack_archive(zip_path, extract_dir)
110148

@@ -140,6 +178,13 @@ def _extract_and_process_files(self, zip_path: Path, extract_dir: Path):
140178
async def store_trajectory(
141179
self, trajectory: Trajectory, env: DataAnalysisEnv
142180
) -> None:
181+
"""
182+
Store trajectory and environment information to disk.
183+
184+
Args:
185+
trajectory: The trajectory to store
186+
env: The environment that generated the trajectory
187+
"""
143188
extract = {
144189
"problem_id": env.problem_id,
145190
"agent_answer": env.state.answer,
@@ -160,7 +205,7 @@ async def store_trajectory(
160205
}
161206

162207
# Download run metadata
163-
with (self.config["local_traces_dir"] / f"{env.problem_id}.json").open(
208+
with (self.config["local_trajectories_dir"] / f"{env.problem_id}.json").open(
164209
"w"
165210
) as f:
166211
json.dump(
@@ -170,10 +215,19 @@ async def store_trajectory(
170215
)
171216
# Download run trajectory
172217
await trajectory.to_jsonl(
173-
self.config["local_traces_dir"] / f"{env.problem_id}.jsonl"
218+
self.config["local_trajectories_dir"] / f"{env.problem_id}.jsonl"
174219
)
175220

176-
def environment_factory(self, capsule: dict) -> DataAnalysisEnv:
221+
def environment_factory(self, capsule: dict[str, Any]) -> DataAnalysisEnv:
222+
"""
223+
Create a DataAnalysisEnv instance from a capsule.
224+
225+
Args:
226+
capsule: Dictionary containing capsule information
227+
228+
Returns:
229+
DataAnalysisEnv: Initialized environment
230+
"""
177231
raw_questions = ast.literal_eval(capsule["questions"])
178232
processed_questions = [
179233
load_mcq(i, open_question=True, question_id=i["id"]) for i in raw_questions
@@ -194,7 +248,7 @@ def environment_factory(self, capsule: dict) -> DataAnalysisEnv:
194248
if item.is_file():
195249
shutil.copy2(item, work_dir)
196250
elif item.is_dir():
197-
shutil.copytree(item, work_dir / item.name)
251+
shutil.copytree(item, work_dir / item.name, dirs_exist_ok=True)
198252
nb_path = work_dir / self.config["notebook_name"]
199253

200254
# Add some extra metadata from config
@@ -217,27 +271,108 @@ def environment_factory(self, capsule: dict) -> DataAnalysisEnv:
217271

218272
return DataAnalysisEnv(**env_args)
219273

274+
async def custom_rollout(
275+
self, agent: Agent, environment: DataAnalysisEnv
276+
) -> Trajectory:
277+
"""
278+
Custom implementation of rollout logic.
279+
280+
Args:
281+
agent: The agent to use for rollout
282+
environment: The environment to run the agent in
283+
284+
Returns:
285+
Trajectory: The generated trajectory
286+
287+
Raises:
288+
NotImplementedError: This method needs to be implemented by subclasses
289+
"""
290+
raise NotImplementedError("Custom rollout not implemented")
291+
292+
async def vanilla_rollout(
293+
self, agent: Agent, environment: DataAnalysisEnv
294+
) -> tuple[Trajectory, DataAnalysisEnv]:
295+
"""
296+
Standard implementation of rollout logic.
297+
298+
Args:
299+
agent: The agent to use for rollout
300+
environment: The environment to run the agent in
301+
302+
Returns:
303+
Tuple[Trajectory, DataAnalysisEnv]: The generated trajectory and updated environment
304+
"""
305+
obs, tools = await environment.reset()
306+
agent_state = await agent.init_state(tools)
307+
trajectory = Trajectory()
308+
309+
for timestep in range(self.config["max_rollout_steps"]):
310+
action, next_agent_state, value = await agent.get_asv(agent_state, obs)
311+
next_obs, reward, done, trunc = await environment.step(action.value)
312+
trajectory.steps.append(
313+
Transition(
314+
timestep=timestep,
315+
agent_state=agent_state,
316+
next_agent_state=next_agent_state,
317+
observation=obs,
318+
next_observation=next_obs,
319+
action=action,
320+
reward=reward,
321+
done=done,
322+
truncated=trunc,
323+
value=value,
324+
)
325+
)
326+
if done or trunc:
327+
break
328+
329+
agent_state = next_agent_state
330+
obs = next_obs
331+
332+
return trajectory, environment
333+
334+
async def batch_rollout(
335+
self, list_of_environments: list[DataAnalysisEnv]
336+
) -> list[Trajectory | tuple[Trajectory, DataAnalysisEnv]]:
337+
"""
338+
Run rollouts for a batch of environments.
339+
340+
Args:
341+
list_of_environments: List of environments to run rollouts in
342+
343+
Returns:
344+
List[Union[Trajectory, Tuple[Trajectory, DataAnalysisEnv]]]: List of trajectories or
345+
trajectory-environment pairs depending on rollout type
346+
"""
347+
if self.config["rollout_type"] == "aviary":
348+
agent = self.config["agent_config"].construct_agent()
349+
rollout = RolloutManager(agent=agent)
350+
return await rollout.sample_trajectories(
351+
environments=list_of_environments,
352+
max_steps=self.config["max_rollout_steps"],
353+
)
354+
355+
agent = self.config["agent_config"].construct_agent()
356+
rollout_manager = getattr(self, f"{self.config['rollout_type']}_rollout")
357+
358+
return await asyncio.gather(*[
359+
rollout_manager(agent, environment) for environment in list_of_environments
360+
])
361+
220362
async def run(self) -> None:
363+
"""Run the full trajectory generation pipeline."""
221364
bixbench = await self.load_bixbench()
222-
# Construct agent and rollout manager
223-
agent = self.config["agent_config"].construct_agent()
224-
rollout = RolloutManager(agent=agent)
225365

226366
# Process environments in batches
227367
for i in range(0, len(bixbench), self.config["batch_size"]):
228368
batch = bixbench[i : i + self.config["batch_size"]]
229369
environments = [self.environment_factory(capsule) for capsule in batch]
230-
231-
# TODO: Create simple rollout manager that does not use LDP
232-
trajectories = await rollout.sample_trajectories(
233-
environments=environments, max_steps=self.config["max_rollout_steps"]
234-
)
235-
370+
results = await self.batch_rollout(environments)
236371
# Store trajectories for each environment
237-
for trajectory, env in zip(trajectories, environments, strict=True):
372+
for trajectory, env in results:
238373
await self.store_trajectory(trajectory, env)
239374

240375

241376
if __name__ == "__main__":
242-
generator = TraceGenerator()
377+
generator = TrajectoryGenerator()
243378
asyncio.run(generator.run())

bixbench/graders.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ async def grade_open_ended_answer(question, target, predicted, llm_client):
5353

5454

5555
def compute_metrics(grades: list[bool], is_refued: list[bool]) -> dict:
56-
"""
56+
"""Calculate metrics for question answering evaluation.
57+
5758
Accuracy = (num correct) / (num questions)
5859
precision = (num correct) / ((num questions) - (num unsure)).
5960
"""

0 commit comments

Comments
 (0)