Skip to content

Commit b2c1ac8

Browse files
darglint and black
1 parent ed0f1bd commit b2c1ac8

File tree

5 files changed

+71
-56
lines changed

5 files changed

+71
-56
lines changed

src/agentlab/agents/hilt_agent/base_multi_candidate_agent.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing_extensions import Protocol
2+
23
from agentlab.agents.agent_args import AgentArgs
34

45

@@ -12,9 +13,10 @@ class MultiCandidateAgent(Protocol):
1213

1314
def get_candidate_generations(
1415
self, obs: dict, hint: list[str] | None = None, n_candidates: int = 3
15-
) -> list[dict]:
16+
) -> "list[dict]":
1617
"""
1718
Generate multiple candidate actions for the given observation.
19+
1820
You can pass extra info in agent_info to update internal state of the
1921
agent based on the selected candidate. Your internal state management
2022
should be robust to multiple calls to the get_candidate_generations method
@@ -24,11 +26,6 @@ def get_candidate_generations(
2426
obs: The current observation dictionary containing environment state
2527
hint: Optional list of hint strings to guide candidate generation
2628
n_candidates: Number of candidate actions to generate
27-
28-
Returns:
29-
List of dictionaries, each containing:
30-
- 'action': The candidate action to be executed
31-
- 'agent_info': Additional information about the action generation
3229
"""
3330
...
3431

@@ -37,8 +34,10 @@ def update_agent_state_from_selected_candidate(self, output: dict):
3734
Update the agent's internal state based on the selected candidate.
3835
This can include any memory or planning updates.
3936
37+
Args:
38+
output: The selected candidate action dictionary
4039
"""
41-
...
40+
pass
4241

4342

4443
class MultiCandidateAgentArgs(AgentArgs):
@@ -47,5 +46,5 @@ def make_agent(self) -> MultiCandidateAgent: ...
4746
def __post_init__(self):
4847
"""Prefix subagent name with 'MC-'."""
4948
super().__post_init__()
50-
if hasattr(self, 'agent_name') and self.agent_name:
49+
if hasattr(self, "agent_name") and self.agent_name:
5150
self.agent_name = "MC-" + self.agent_name

src/agentlab/agents/hilt_agent/generic_human_guided_agent.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import bgym
99
import numpy as np
10+
from browsergym.experiments.agent import AgentInfo
1011
from PIL import Image
1112

1213
from agentlab.agents import dynamic_prompting as dp
@@ -23,7 +24,6 @@
2324
SystemMessage,
2425
)
2526
from agentlab.llm.tracking import cost_tracker_decorator
26-
from browsergym.experiments.agent import AgentInfo
2727

2828

2929
class CandidatesGeneration(dp.PromptElement):
@@ -87,15 +87,14 @@ def __init__(self, hint: list[str] | None = None, n_candidates=3) -> None:
8787
)
8888

8989
def _parse_answer(self, text_answer: str) -> Dict[str, Dict[str, str]]:
90-
"""
91-
Extract up to n_candidates candidates, using numbered tags only.
90+
"""Extract up to n_candidates candidates, using numbered tags only.
91+
92+
Args:
93+
text_answer: The text response containing candidate generation tags.
9294
9395
Returns:
94-
{
95-
"candidate_generation_1": {"think": "...", "action": "..."},
96-
"candidate_generation_2": {"think": "...", "action": "..."},
97-
...
98-
}
96+
Dictionary mapping candidate names to their think and action content.
97+
Format: {"candidate_generation_1": {"think": "...", "action": "..."}, ...}
9998
"""
10099
result = {
101100
f"candidate_generation_{i+1}": {"think": "", "action": ""}
@@ -145,11 +144,11 @@ def make_agent(self):
145144
return MultipleProposalGenericAgent(
146145
chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry
147146
)
148-
147+
149148
def __post_init__(self):
150149
"""Prefix subagent name with 'HILT-'."""
151150
super().__post_init__()
152-
if hasattr(self, 'agent_name') and self.agent_name:
151+
if hasattr(self, "agent_name") and self.agent_name:
153152
self.agent_name = "HILT-" + self.agent_name
154153

155154

@@ -363,13 +362,11 @@ def get_base_agent(llm_config):
363362
agent_configs = [HUMAN_GUIDED_GENERIC_AGENT]
364363
benchmark = bgym.DEFAULT_BENCHMARKS["miniwob"]()
365364
benchmark = benchmark.subset_from_glob("task_name", "*book*")
366-
benchmark.env_args_list = benchmark.env_args_list[2:3]
365+
benchmark.env_args_list = benchmark.env_args_list[3:4]
367366

368367
for env_args in benchmark.env_args_list:
369368
env_args.max_steps = 100 # max human steps
370-
env_args.headless = False
371-
# env_args.use_chat_ui = False
372-
# env_args.use_hint_labeling_ui = True
369+
env_args.headless = True
373370

374371
Study(agent_configs, benchmark, logging_level=logging.WARNING).run(
375372
n_jobs=1,

src/agentlab/agents/hilt_agent/hilt_agent.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,19 @@
66

77
import bgym
88
import numpy as np
9-
from PIL import Image
109
import playwright
10+
from browsergym.experiments.agent import Agent
11+
from PIL import Image
1112

13+
from agentlab.agents.agent_args import AgentArgs
14+
from agentlab.agents.hilt_agent.base_multi_candidate_agent import MultiCandidateAgent
1215
from agentlab.agents.hilt_agent.hint_labelling import (
1316
HintLabeling,
1417
HintLabelingInputs,
1518
)
16-
from agentlab.llm.tracking import cost_tracker_decorator
1719
from agentlab.analyze import overlay_utils
18-
from browsergym.experiments.agent import Agent
19-
from agentlab.agents.agent_args import AgentArgs
20-
from agentlab.agents.hilt_agent.base_multi_candidate_agent import MultiCandidateAgent
20+
from agentlab.llm.tracking import cost_tracker_decorator
21+
2122

2223
class HumanInTheLoopAgent(Agent):
2324

@@ -58,7 +59,7 @@ def get_action(self, obs):
5859
# Generate first candidates
5960
candidates = self.subagent.get_candidate_generations(obs, hint=None, n_candidates=3)
6061
step_n_human_intervention_rounds += 1
61-
suggestions = [{ 'action': c['action'], 'think': c['agent_info'].think} for c in candidates]
62+
suggestions = [{"action": c["action"], "think": c["agent_info"].think} for c in candidates]
6263
# List of Images as base64 - create overlay screenshots for each suggested action
6364
screenshots = [overlay_action(obs, choice["action"]) for choice in suggestions]
6465

@@ -90,11 +91,11 @@ def get_action(self, obs):
9091
hint = response["payload"]["hint"]
9192
step_hint.append(hint)
9293
candidates = self.subagent.get_candidate_generations(
93-
obs,
94-
hint=step_hint if step_hint else None,
95-
n_candidates=3
94+
obs, hint=step_hint if step_hint else None, n_candidates=3
9695
)
97-
suggestions = [{'action': c['action'], 'think': c['agent_info'].think} for c in candidates]
96+
suggestions = [
97+
{"action": c["action"], "think": c["agent_info"].think} for c in candidates
98+
]
9899
screenshots = [overlay_action(obs, choice["action"]) for choice in suggestions]
99100

100101
elif response["type"] == "step":
@@ -135,7 +136,6 @@ def get_action(self, obs):
135136
@dataclass
136137
class HumanInTheLoopAgentArgs(AgentArgs):
137138
subagent_args: Optional[AgentArgs] = None # args for the underlying multiple proposal agent
138-
139139

140140
def make_agent(self):
141141
assert self.subagent_args is not None
@@ -146,15 +146,15 @@ def __post_init__(self):
146146
super().__post_init__()
147147
if self.subagent_args and self.subagent_args.agent_name:
148148
self.agent_name = "HILT-" + self.subagent_args.agent_name
149-
149+
150150
def set_benchmark(self, benchmark, demo_mode):
151151
"""Delegate set_benchmark to the subagent if it has the method."""
152-
if hasattr(self.subagent_args, 'set_benchmark'):
152+
if hasattr(self.subagent_args, "set_benchmark"):
153153
self.subagent_args.set_benchmark(benchmark, demo_mode)
154-
154+
155155
def set_reproducibility_mode(self):
156156
"""Delegate set_reproducibility_mode to the subagent if it has the method."""
157-
if hasattr(self.subagent_args, 'set_reproducibility_mode'):
157+
if hasattr(self.subagent_args, "set_reproducibility_mode"):
158158
self.subagent_args.set_reproducibility_mode()
159159

160160

@@ -175,16 +175,17 @@ def img_to_base_64(image: Image.Image | np.ndarray) -> str:
175175
b64_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
176176
return b64_str
177177

178+
178179
def get_base_human_in_the_loop_genericagent(llm_config):
179180
from agentlab.agents.generic_agent.tmlr_config import BASE_FLAGS
180-
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
181181
from agentlab.agents.hilt_agent.hilt_agent import HumanInTheLoopAgentArgs
182182
from agentlab.agents.hilt_agent.multi_candidate_generic_agent import (
183183
MultiCandidateGenericAgentArgs,
184184
)
185+
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
185186

186187
return HumanInTheLoopAgentArgs(
187-
subagent_args = MultiCandidateGenericAgentArgs(
188+
subagent_args=MultiCandidateGenericAgentArgs(
188189
chat_model_args=CHAT_MODEL_ARGS_DICT[llm_config],
189190
flags=BASE_FLAGS,
190191
)
@@ -210,7 +211,6 @@ def get_base_human_in_the_loop_genericagent(llm_config):
210211
env_args.max_steps = 100 # max human steps
211212
env_args.headless = False
212213

213-
214214
Study(agent_configs, benchmark, logging_level=logging.WARNING).run(
215215
n_jobs=1,
216216
parallel_backend="sequential",

src/agentlab/agents/hilt_agent/hint_labelling.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from typing import Dict, List, Optional
66

77
import playwright.sync_api
8+
from browsergym.core import _get_global_playwright
89
from pydantic import BaseModel, Field
910

1011
from agentlab.agents.hilt_agent import hint_labelling_ui_files
11-
from browsergym.core import _get_global_playwright
1212

1313
logger = logging.getLogger(__name__)
1414

@@ -109,6 +109,16 @@ def wait_for_response(self, timeout: Optional[float] = 600) -> dict:
109109
"""
110110
Wait until the page makes a request to /api/reprompt or /api/submit,
111111
then parse the request body and return it as a dict with 'type' and 'payload' keys.
112+
113+
Args:
114+
timeout (Optional[float]): Maximum time to wait for the request in seconds. If None or 0,
115+
waits indefinitely. Defaults to 600 seconds.
116+
117+
Returns:
118+
dict: A dictionary containing the parsed response with 'type' and 'payload' keys.
119+
For /api/reprompt: {'type': 'reprompt', 'payload': {'hint': str}}
120+
For /api/submit: {'type': 'step', 'payload': {'think': str, 'action': str}}
121+
112122
"""
113123
logger.info("Waiting for response from Hint Labeling UI...")
114124

src/agentlab/agents/hilt_agent/multi_candidate_generic_agent.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from dataclasses import asdict, dataclass
33
from typing import Dict, List
44

5+
from browsergym.experiments.agent import AgentInfo
6+
57
from agentlab.agents import dynamic_prompting as dp
68
from agentlab.agents.generic_agent.generic_agent import GenericAgent, GenericAgentArgs
79
from agentlab.agents.generic_agent.generic_agent_prompt import MainPrompt
810
from agentlab.llm.llm_utils import Discussion, HumanMessage, SystemMessage
9-
from browsergym.experiments.agent import AgentInfo
1011

1112

1213
class CandidatesGeneration(dp.PromptElement):
@@ -70,15 +71,14 @@ def __init__(self, hint: list[str] | None = None, n_candidates=3) -> None:
7071
)
7172

7273
def _parse_answer(self, text_answer: str) -> Dict[str, Dict[str, str]]:
73-
"""
74-
Extract up to n_candidates candidates, using numbered tags only.
74+
"""Extract up to n_candidates candidates, using numbered tags only.
75+
76+
Args:
77+
text_answer: The text response containing candidate generation tags.
7578
7679
Returns:
77-
{
78-
"candidate_generation_1": {"think": "...", "action": "..."},
79-
"candidate_generation_2": {"think": "...", "action": "..."},
80-
...
81-
}
80+
Dictionary mapping candidate names to their think and action content.
81+
Format: {"candidate_generation_1": {"think": "...", "action": "..."}, ...}
8282
"""
8383
result = {
8484
f"candidate_generation_{i+1}": {"think": "", "action": ""}
@@ -123,7 +123,6 @@ def get_candidate_generations(
123123
# Important to handle cases where get_candidate_generations is called multiple times in a single step.
124124
if not self.obs_history or self.obs_history[-1] is not obs:
125125
self.obs_history.append(obs)
126-
127126

128127
main_prompt = MainPrompt(
129128
action_set=self.action_set,
@@ -180,8 +179,12 @@ def get_candidate_generations(
180179
return output
181180

182181
def update_agent_state_from_selected_candidate(self, output):
183-
"""Updates the agent's internal state based on the selected candidate from human feedback."""
184-
action, agent_info = output['action'], output['agent_info']
182+
"""Updates the agent's internal state based on the selected candidate from human feedback.
183+
184+
Args:
185+
output: Dictionary containing 'action' and 'agent_info' keys from selected candidate.
186+
"""
187+
action, agent_info = output["action"], output["agent_info"]
185188
self.plan = agent_info.extra_info.get("plan", self.plan)
186189
self.plan_step = agent_info.extra_info.get("step", self.plan_step)
187190
self.memories.append(agent_info.extra_info.get("memory", None))
@@ -191,11 +194,17 @@ def update_agent_state_from_selected_candidate(self, output):
191194
def get_action(self, obs):
192195
"""Generates multiple candidates and always returns the first one.
193196
This allows this agent to be used as a drop-in replacement for a single-candidate agent.
197+
198+
Args:
199+
obs: The observation from the environment.
200+
201+
Returns:
202+
tuple: A tuple containing (action, agent_info).
194203
"""
195-
candidates = self.get_candidate_generations(obs, hint=None, n_candidates=2)
196-
selection = candidates[0] # always select the first option.
204+
candidates = self.get_candidate_generations(obs, hint=None, n_candidates=2)
205+
selection = candidates[0] # always select the first option.
197206
self.update_agent_state_from_selected_candidate(selection)
198-
action, agent_info = selection['action'], selection['agent_info']
207+
action, agent_info = selection["action"], selection["agent_info"]
199208

200209
return action, agent_info
201210

@@ -212,5 +221,5 @@ def make_agent(self):
212221
def __post_init__(self):
213222
"""Prefix subagent name with 'MC-'."""
214223
super().__post_init__()
215-
if hasattr(self, 'agent_name') and self.agent_name:
224+
if hasattr(self, "agent_name") and self.agent_name:
216225
self.agent_name = "MC-" + self.agent_name

0 commit comments

Comments
 (0)