Skip to content

Commit 9633275

Browse files
Move the overlay_action to utils
1 parent 3b07fe9 commit 9633275

File tree

4 files changed

+29
-43
lines changed

4 files changed

+29
-43
lines changed

src/agentlab/agents/agent_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import copy
12
from PIL import Image, ImageDraw
23
from playwright.sync_api import Page
4+
from agentlab.llm.llm_utils import img_to_base_64
5+
from agentlab.analyze import overlay_utils
36

47

58
def draw_mouse_pointer(image: Image.Image, x: int, y: int) -> Image.Image:
@@ -128,3 +131,14 @@ def zoom_webpage(page: Page, zoom_factor: float = 1.5):
128131

129132
page.evaluate(f"document.documentElement.style.zoom='{zoom_factor*100}%'")
130133
return page
134+
135+
136+
def overlay_action(obs, action):
137+
"""Overlays actions on screenshot in-place"""
138+
act_img = copy.deepcopy(obs["screenshot"])
139+
act_img = Image.fromarray(act_img)
140+
overlay_utils.annotate_action(act_img, action, properties=obs["extra_element_properties"])
141+
return img_to_base_64(act_img)
142+
143+
144+

src/agentlab/agents/hitl_agent/generic_human_guided_agent.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
SystemMessage,
2525
)
2626
from agentlab.llm.tracking import cost_tracker_decorator
27-
27+
from agentlab.agents.agent_utils import overlay_action
28+
from agentlab.llm.llm_utils import img_to_base_64
2829

2930
class CandidatesGeneration(dp.PromptElement):
3031
# Ask for multiple alternatives; each candidate must contain <think> and <action>.
@@ -119,24 +120,6 @@ def _parse_answer(self, text_answer: str) -> Dict[str, Dict[str, str]]:
119120
return result
120121

121122

122-
def overlay_action(obs, action):
123-
"""Overlays actions on screenshot in-place"""
124-
act_img = copy.deepcopy(obs["screenshot"])
125-
act_img = Image.fromarray(act_img)
126-
overlay_utils.annotate_action(act_img, action, properties=obs["extra_element_properties"])
127-
return img_to_base_64(act_img)
128-
129-
130-
def img_to_base_64(image: Image.Image | np.ndarray) -> str:
131-
"""Converts a PIL Image or NumPy array to a base64-encoded string."""
132-
if isinstance(image, np.ndarray):
133-
image = Image.fromarray(image)
134-
buffer = io.BytesIO()
135-
image.save(buffer, format="PNG")
136-
b64_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
137-
return b64_str
138-
139-
140123
@dataclass
141124
class MultipleProposalGenericAgentArgs(GenericAgentArgs):
142125

src/agentlab/agents/hitl_agent/hitl_agent.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,19 @@
1-
import base64
2-
import copy
3-
import io
41
from dataclasses import dataclass
52
from typing import Optional
63

74
import bgym
8-
import numpy as np
95
import playwright
106
from browsergym.experiments.agent import Agent
11-
from PIL import Image
7+
128

139
from agentlab.agents.agent_args import AgentArgs
1410
from agentlab.agents.hitl_agent.base_multi_candidate_agent import MultiCandidateAgent
1511
from agentlab.agents.hitl_agent.hint_labelling import (
1612
HintLabeling,
1713
HintLabelingInputs,
1814
)
19-
from agentlab.analyze import overlay_utils
15+
from agentlab.agents.agent_utils import overlay_action
16+
from agentlab.llm.llm_utils import img_to_base_64
2017
from agentlab.llm.tracking import cost_tracker_decorator
2118

2219

@@ -158,24 +155,6 @@ def set_reproducibility_mode(self):
158155
self.subagent_args.set_reproducibility_mode()
159156

160157

161-
def overlay_action(obs, action):
162-
"""Overlays actions on screenshot in-place"""
163-
act_img = copy.deepcopy(obs["screenshot"])
164-
act_img = Image.fromarray(act_img)
165-
overlay_utils.annotate_action(act_img, action, properties=obs["extra_element_properties"])
166-
return img_to_base_64(act_img)
167-
168-
169-
def img_to_base_64(image: Image.Image | np.ndarray) -> str:
170-
"""Converts a PIL Image or NumPy array to a base64-encoded string."""
171-
if isinstance(image, np.ndarray):
172-
image = Image.fromarray(image)
173-
buffer = io.BytesIO()
174-
image.save(buffer, format="PNG")
175-
b64_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
176-
return b64_str
177-
178-
179158
def get_base_human_in_the_loop_genericagent(llm_config):
180159
from agentlab.agents.generic_agent.tmlr_config import BASE_FLAGS
181160
from agentlab.agents.hitl_agent.hitl_agent import HumanInTheLoopAgentArgs

src/agentlab/llm/llm_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,16 @@ def image_to_png_base64_url(image: np.ndarray | Image.Image):
727727
return f"data:image/png;base64,{image_base64}"
728728

729729

730+
def img_to_base_64(image: Image.Image | np.ndarray) -> str:
731+
"""Converts a PIL Image or NumPy array to a base64-encoded string."""
732+
if isinstance(image, np.ndarray):
733+
image = Image.fromarray(image)
734+
buffer = io.BytesIO()
735+
image.save(buffer, format="PNG")
736+
b64_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
737+
return b64_str
738+
739+
730740
class BaseMessage(dict):
731741
def __init__(self, role: str, content: Union[str, list[dict]], **kwargs):
732742
allowed_attrs = {"log_probs"}

0 commit comments

Comments
 (0)