Skip to content

Commit 4279d5c

Browse files
recursixgasse
authored andcommitted
fix tests
1 parent 5297157 commit 4279d5c

File tree

4 files changed

+27
-15
lines changed

4 files changed

+27
-15
lines changed

src/agentlab/agents/dynamic_prompting.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010

1111
import bgym
1212
from browsergym.core.action.base import AbstractActionSet
13-
from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, overlay_som, prune_html
13+
from browsergym.utils.obs import (
14+
flatten_axtree_to_str,
15+
flatten_dom_to_str,
16+
overlay_som,
17+
prune_html,
18+
)
1419

1520
from agentlab.llm.llm_utils import (
1621
BaseMessage,
@@ -385,7 +390,7 @@ def _prompt(self) -> str:
385390
URL: {page_url}
386391
"""
387392
prompt_pieces.append(prompt_piece)
388-
return "\n".join(prompt_pieces)
393+
return "\n".join(prompt_pieces)
389394

390395

391396
class Observation(Shrinkable):

src/agentlab/experiments/graph_execution_ray.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22

33
# # Disable Ray log deduplication
44
# os.environ["RAY_DEDUP_LOGS"] = "0"
5+
import logging
56
import time
6-
import ray
7+
78
import bgym
8-
from agentlab.experiments.exp_utils import run_exp, _episode_timeout
9+
import ray
910
from ray.util import state
10-
import logging
11+
12+
from agentlab.experiments.exp_utils import _episode_timeout, run_exp
1113

1214
logger = logging.getLogger(__name__)
1315

@@ -36,7 +38,7 @@ def get_task(exp_arg: bgym.ExpArgs):
3638
get_task(exp_arg)
3739

3840
max_timeout = max([_episode_timeout(exp_args, avg_step_timeout) for exp_args in exp_args_list])
39-
41+
4042
return poll_for_timeout(task_map, max_timeout, poll_interval=max_timeout * 0.1)
4143

4244

src/agentlab/experiments/study.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def get_report(self, ignore_cache=False, ignore_stale=False):
254254
return inspect_results.get_study_summary(
255255
self.dir, ignore_cache=ignore_cache, ignore_stale=ignore_stale
256256
)
257-
257+
258258
def override_max_steps(self, max_steps):
259259
for exp_args in self.exp_args_list:
260260
exp_args.env_args.max_steps = max_steps

tests/agents/test_generic_prompt.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55

66
from agentlab.agents import dynamic_prompting as dp
77
from agentlab.agents.generic_agent.agent_configs import FLAGS_GPT_3_5
8-
from agentlab.agents.generic_agent.generic_agent_prompt import GenericPromptFlags, MainPrompt
8+
from agentlab.agents.generic_agent.generic_agent_prompt import (
9+
GenericPromptFlags,
10+
MainPrompt,
11+
)
912
from agentlab.llm.llm_utils import count_tokens
1013

1114
html_template = """
@@ -32,15 +35,18 @@
3235
}
3336

3437
OBS_HISTORY = [
35-
base_obs | {
38+
base_obs
39+
| {
3640
"pruned_html": html_template.format(1),
3741
"last_action_error": "",
3842
},
39-
base_obs | {
43+
base_obs
44+
| {
4045
"pruned_html": html_template.format(2),
4146
"last_action_error": "Hey, this is an error in the past",
4247
},
43-
base_obs | {
48+
base_obs
49+
| {
4450
"pruned_html": html_template.format(3),
4551
"last_action_error": "Hey, there is an error now",
4652
},
@@ -102,7 +108,7 @@
102108
),
103109
(
104110
"obs.use_tabs",
105-
("Currently open tabs:","(active tab)"),
111+
("Currently open tabs:", "(active tab)"),
106112
),
107113
(
108114
"obs.use_focused_element",
@@ -165,7 +171,7 @@ def test_shrinking_observation():
165171
flags.obs.use_html = True
166172

167173
prompt_maker = MainPrompt(
168-
action_set=dp.HighLevelActionSet(),
174+
action_set=bgym.HighLevelActionSet(),
169175
obs_history=OBS_HISTORY,
170176
actions=ACTIONS,
171177
memories=MEMORIES,
@@ -231,7 +237,7 @@ def test_main_prompt_elements_present():
231237
# Initialize MainPrompt
232238
prompt = str(
233239
MainPrompt(
234-
action_set=dp.HighLevelActionSet(),
240+
action_set=bgym.HighLevelActionSet(),
235241
obs_history=OBS_HISTORY,
236242
actions=ACTIONS,
237243
memories=MEMORIES,
@@ -253,4 +259,3 @@ def test_main_prompt_elements_present():
253259
test_main_prompt_elements_present()
254260
# for flag, expected_prompts in FLAG_EXPECTED_PROMPT:
255261
# test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
256-

0 commit comments

Comments
 (0)