Skip to content

Commit 9d7d718

Browse files
committed
darglinting everything
1 parent b225ef3 commit 9d7d718

File tree

12 files changed

+83
-31
lines changed

12 files changed

+83
-31
lines changed

src/agentlab/agents/agent_args.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from bgym import AbstractAgentArgs
21
import bgym
2+
from bgym import AbstractAgentArgs
33

44

55
class AgentArgs(AbstractAgentArgs):
@@ -28,6 +28,9 @@ def set_reproducibility_mode(self):
2828
as possible e.g. setting the temperature of the model to 0.
2929
3030
This is only called when reproducibility is requested.
31+
32+
Raises:
33+
NotImplementedError: If the agent does not support reproducibility.
3134
"""
3235
raise NotImplementedError(
3336
f"set_reproducibility_mode is not implemented for agent_args {self.__class__.__name__}"

src/agentlab/agents/generic_agent/agent_configs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,19 @@
193193
add_missparsed_messages=True,
194194
)
195195

196+
196197
AGENT_8B = GenericAgentArgs(
197198
chat_model_args=CHAT_MODEL_ARGS_DICT["meta-llama/Meta-Llama-3-8B-Instruct"],
198199
flags=FLAGS_8B,
199200
)
200201

201202

203+
AGENT_LLAMA31_8B = GenericAgentArgs(
204+
chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/meta-llama/llama-3.1-8b-instruct"],
205+
flags=FLAGS_8B,
206+
)
207+
208+
202209
# GPT-4o default config
203210
FLAGS_GPT_4o = GenericPromptFlags(
204211
obs=dp.ObsFlags(

src/agentlab/agents/generic_agent/reproducibility_agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ def make_repro_agent(agent_args: AgentArgs, exp_dir: Path | str):
199199
agent_args (AgentArgs): The original agent args.
200200
exp_dir (Path | str): The directory where the experiment was saved.
201201
202+
Returns:
203+
ReproAgentArgs: The new agent args.
202204
"""
203205
exp_dir = Path(exp_dir)
204206
assert isinstance(agent_args, GenericAgentArgs)

src/agentlab/agents/visualwebarena/agent.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,12 @@ def get_action(self, obs: Any) -> tuple[str, dict]:
188188
Replica of VisualWebArena agent
189189
https://github.com/web-arena-x/visualwebarena/blob/89f5af29305c3d1e9f97ce4421462060a70c9a03/agent/prompts/prompt_constructor.py#L211
190190
https://github.com/web-arena-x/visualwebarena/blob/89f5af29305c3d1e9f97ce4421462060a70c9a03/agent/prompts/prompt_constructor.py#L272
191+
192+
Args:
193+
obs (Any): Observation from the environment
194+
195+
Returns:
196+
tuple[str, dict]: Action and AgentInfo
191197
"""
192198
user_messages = []
193199

src/agentlab/analyze/inspect_results.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from IPython.display import display
1717
from tqdm import tqdm
1818

19-
2019
from agentlab.analyze.error_categorization import (
2120
ERR_CLASS_MAP,
2221
is_critical_server_error,
@@ -83,9 +82,7 @@ def set_index_from_variables(
8382
index_black_list: List of wildcard patterns to match variables that
8483
should be excluded from the index.
8584
task_key: The key to use as the first level of the index.
86-
force_at_leaste_one_variable: If True, force at least one variable in the
87-
index. If no variable is found, the index will be set to
88-
task_key + "agent.agent_name".
85+
add_agent_and_benchmark: If True, add agent.agent_name and env.benchmark
8986
"""
9087
df.reset_index(inplace=True)
9188
constants, variables, _ = get_constants_and_variables(df)
@@ -141,6 +138,7 @@ def load_result_df(
141138
should be included in the index.
142139
index_black_list: List of wildcard patterns to match variables that
143140
should be excluded from the index.
141+
remove_args_suffix: If True, remove the _args suffix from the columns
144142
145143
Returns:
146144
pd.DataFrame: The result dataframe
@@ -777,17 +775,13 @@ def _categorize_error(row):
777775

778776

779777
def _benchmark_from_task_name(task_name: str):
780-
"""Extract the benchmark from the task name.
781-
TODO should be more robust, e.g. handle workarena.L1, workarena.L2, etc.
782-
"""
778+
"""Extract the benchmark from the task name."""
779+
# TODO should be more robust, e.g. handle workarena.L1, workarena.L2, etc.
783780
return task_name.split(".")[0]
784781

785782

786783
def summarize_study(result_df: pd.DataFrame) -> pd.DataFrame:
787-
"""Create a summary of the study.
788-
789-
Similar to global report, but handles single agent differently.
790-
"""
784+
"""Create a summary of the study. Similar to global report, but handles single agent differently."""
791785

792786
levels = list(range(result_df.index.nlevels))
793787
return result_df.groupby(level=levels[1:]).apply(summarize)

src/agentlab/experiments/exp_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def add_dependencies(exp_args_list: list[ExpArgs], task_dependencies: dict[str,
9595
Returns:
9696
list[ExpArgs]
9797
The modified exp_args_list with dependencies added.
98+
99+
Raises:
100+
ValueError: If the task_dependencies are not valid.
98101
"""
99102

100103
if task_dependencies is None or all([len(dep) == 0 for dep in task_dependencies.values()]):

src/agentlab/experiments/graph_execution_ray.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,17 @@ def poll_for_timeout(tasks: dict[str, ray.ObjectRef], timeout: float, poll_inter
4747
4848
I tried various methods for killing a job that hangs. So far it's
4949
the only one that seems to work reliably (hopefully)
50+
51+
Args:
52+
tasks: dict[str, ray.ObjectRef]
53+
Dictionary of task_id: task_ref
54+
timeout: float
55+
Timeout in seconds
56+
poll_interval: float
57+
Polling interval in seconds
58+
59+
Returns:
60+
dict[str, Any]: Dictionary of task_id: result
5061
"""
5162
task_list = list(tasks.values())
5263
task_ids = list(tasks.keys())

src/agentlab/experiments/launch_exp.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import bgym
66
from browsergym.experiments.loop import ExpArgs, yield_all_exp_results
7+
78
from agentlab.experiments.exp_utils import run_exp
89

910

@@ -24,13 +25,16 @@ def run_experiments(
2425
Number of parallel jobs.
2526
exp_args_list: list[ExpArgs]
2627
List of ExpArgs objects.
27-
exp_dir: Path
28+
study_dir: Path
2829
Directory where the experiments will be saved.
2930
parallel_backend: str
3031
Parallel backend to use. Either "joblib", "ray" or "sequential".
3132
The only backend that supports webarena graph dependencies correctly is ray or sequential.
3233
avg_step_timeout: int
3334
Will raise a TimeoutError if the episode is not finished after env_args.max_steps * avg_step_timeout seconds.
35+
36+
Raises:
37+
ValueError: If the parallel_backend is not recognized.
3438
"""
3539

3640
if len(exp_args_list) == 0:
@@ -110,6 +114,13 @@ def find_incomplete(study_dir: str | Path, include_errors=True):
110114
Find all incomplete experiments and relaunch them.
111115
- "incomplete_only": relaunch only the incomplete experiments.
112116
- "incomplete_or_error": relaunch incomplete or errors.
117+
118+
Returns:
119+
list[ExpArgs]
120+
List of ExpArgs objects to relaunch.
121+
122+
Raises:
123+
ValueError: If the study_dir does not exist.
113124
"""
114125
study_dir = Path(study_dir)
115126

@@ -152,6 +163,16 @@ def _hide_completed(exp_result: bgym.ExpResult, include_errors: bool = True):
152163
153164
This little hack allows an elegant way to keep the task dependencies for e.g. webarena
154165
while skipping the tasks that are completed when relaunching.
166+
167+
Args:
168+
exp_result: bgym.ExpResult
169+
The experiment result to hide.
170+
include_errors: bool
171+
If True, include experiments that errored.
172+
173+
Returns:
174+
ExpArgs
175+
The ExpArgs object hidden if the experiment is completed.
155176
"""
156177

157178
hide = False

src/agentlab/experiments/reproducibility_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def _get_git_username(repo: Repo) -> str:
5858
5. Environment variables (GIT_AUTHOR_NAME and GIT_COMMITTER_NAME)
5959
6060
Args:
61-
repo (git.Repo): A GitPython Repo object representing the Git repository.
61+
repo (Repo): A GitPython Repo object representing the Git repository.
6262
6363
Returns:
6464
str: The first non-None username found, or None if no username is found.

src/agentlab/experiments/study.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from abc import ABC, abstractmethod
21
import gzip
32
import logging
43
import pickle
@@ -18,7 +17,6 @@
1817
from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies
1918
from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments
2019

21-
2220
logger = logging.getLogger(__name__)
2321

2422

@@ -186,7 +184,15 @@ def load_exp_args_list(self):
186184
def set_reproducibility_info(self, strict_reproducibility=False, comment=None):
187185
"""Gather relevant information that may affect the reproducibility of the experiment
188186
189-
e.g.: versions of BrowserGym, benchmark, AgentLab..."""
187+
e.g.: versions of BrowserGym, benchmark, AgentLab...
188+
189+
Args:
190+
strict_reproducibility: bool
191+
If True, all modifications have to be committed before running the experiments.
192+
Also, if relaunching a study, it will not be possible if the code has changed.
193+
comment: str
194+
Extra comment to add to the reproducibility information.
195+
"""
190196
agent_names = [a.agent_name for a in self.agent_args]
191197
info = repro.get_reproducibility_info(
192198
agent_names,
@@ -252,13 +258,14 @@ def _run(self, n_jobs=1, parallel_backend="joblib", strict_reproducibility=False
252258
Args:
253259
n_jobs: int
254260
Number of parallel jobs.
255-
256261
parallel_backend: str
257262
Parallel backend to use. Either "joblib", "dask" or "sequential".
258-
259263
strict_reproducibility: bool
260264
If True, all modifications have to be committed before running the experiments.
261265
Also, if relaunching a study, it will not be possible if the code has changed.
266+
267+
Raises:
268+
ValueError: If the exp_args_list is None.
262269
"""
263270

264271
if self.exp_args_list is None:
@@ -276,10 +283,6 @@ def append_to_journal(self, strict_reproducibility=True):
276283
Args:
277284
strict_reproducibility: bool
278285
If True, incomplete experiments will raise an error.
279-
280-
Raises:
281-
ValueError: If the reproducibility information is not compatible
282-
with the report.
283286
"""
284287
_, summary_df, _ = self.get_results()
285288
repro.append_to_journal(
@@ -447,9 +450,16 @@ def _agents_on_benchmark(
447450
If True, the experiments will be run in demo mode.
448451
logging_level: int
449452
The logging level for individual jobs.
453+
logging_level_stdout: int
454+
The logging level for the stdout.
455+
ignore_dependencies: bool
456+
If True, the dependencies will be ignored and all experiments can be run in parallel.
450457
451458
Returns:
452459
list[ExpArgs]: The list of experiments to run.
460+
461+
Raises:
462+
ValueError: If multiple agents are run on a benchmark that requires manual reset.
453463
"""
454464

455465
if not isinstance(agents, (list, tuple)):

0 commit comments

Comments
 (0)