Skip to content

Commit 90c2468

Browse files
add hint labeling ui support in loop.py
1 parent 571b3f5 commit 90c2468

File tree

1 file changed

+89
-10
lines changed

1 file changed

+89
-10
lines changed

src/agentlab/experiments/loop.py

Lines changed: 89 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import gzip
23
import importlib.metadata
34
import json
@@ -13,12 +14,15 @@
1314
from collections import defaultdict
1415
from dataclasses import asdict, dataclass, field, is_dataclass
1516
from datetime import datetime
17+
from io import BytesIO
1618
from pathlib import Path
1719
from typing import Optional
1820

1921
import gymnasium as gym
2022
import numpy as np
23+
import PIL.Image
2124
from browsergym.core.chat import Chat
25+
from browsergym.core.hint_labeling import HintLabeling, HintLabelingInputs
2226
from browsergym.experiments.agent import Agent
2327
from browsergym.experiments.utils import count_tokens
2428
from dataclasses_json import DataClassJsonMixin
@@ -404,7 +408,7 @@ def _make_dir(self, exp_root):
404408
def run(self):
405409
"""Run the experiment and save the results"""
406410
# start writing logs to run logfile
407-
self._set_logger()
411+
# self._set_logger()
408412

409413
# log python environment info
410414
save_package_versions(Path(self.exp_dir))
@@ -443,15 +447,18 @@ def run(self):
443447
# will end the episode after saving the step info.
444448
step_info.truncated = True
445449

446-
step_info.save_step_info(
447-
self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som
448-
)
449-
logger.debug("Step info saved.")
450+
# step_info.save_step_info(
451+
# self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som
452+
# )
453+
# logger.debug("Step info saved.")
450454

451455
if hasattr(env.unwrapped, "chat") and isinstance(env.unwrapped.chat, Chat):
452456
_send_chat_info(env.unwrapped.chat, action, step_info.agent_info)
453457
logger.debug("Chat info sent.")
454458

459+
if hasattr(env.unwrapped, "hint_labeling") and isinstance(env.unwrapped.hint_labeling, HintLabeling):
460+
_update_hint_labeling(env.unwrapped.hint_labeling, action, agent, step_info)
461+
455462
if action is None:
456463
logger.debug("Agent returned None action. Ending episode.")
457464
break
@@ -481,10 +488,11 @@ def run(self):
481488

482489
finally:
483490
try:
484-
if step_info is not None:
485-
step_info.save_step_info(
486-
self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som
487-
)
491+
pass
492+
# if step_info is not None:
493+
# step_info.save_step_info(
494+
# self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som
495+
# )
488496
except Exception as e:
489497
logger.error(f"Error while saving step info in the finally block: {e}")
490498
try:
@@ -508,7 +516,8 @@ def run(self):
508516
except Exception as e:
509517
logger.exception(f"Error while closing the environment: {e}")
510518
try:
511-
self._unset_logger() # stop writing logs to run logfile
519+
# self._unset_logger() # stop writing logs to run logfile
520+
pass
512521
except Exception as e:
513522
logger.exception(f"Error while unsetting the logger: {e}")
514523

@@ -942,6 +951,76 @@ def _send_chat_info(chat: Chat, action: str, agent_info: dict):
942951
logger.info(msg)
943952
chat.add_message(role="info", msg=msg)
944953

954+
def _convert_np_array_to_base64(np_array: np.ndarray):
    """Encode an image array as a base64 PNG string.

    The array is turned into a PIL image with ``PIL.Image.fromarray``,
    written to an in-memory buffer in PNG format, and the resulting bytes
    are base64-encoded.

    Args:
        np_array: image data in a layout accepted by ``PIL.Image.fromarray``.

    Returns:
        The base64 encoding of the PNG bytes, as a UTF-8 decoded ``str``.
    """
    with BytesIO() as png_buffer:
        PIL.Image.fromarray(np_array).save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
960+
961+
def _update_hint_labeling(
    hint_labeling: HintLabeling,
    action: str,
    agent: Agent,
    step_info: StepInfo,
    n_suggestions: int = 5,
):
    """Show the current step in the hint labeling UI and return the chosen action.

    The UI context is built from ``step_info.obs`` (goal, last action error,
    screenshot, axtree) plus the agent's proposed ``action``. The function then
    loops, pushing the context to the UI and blocking on its response:

    * ``"reprompt"``: the hint in ``response["payload"]["hint"]`` is stored in
      ``agent.flags.extra_instructions``, the agent is queried
      ``n_suggestions`` times through ``step_info.from_action(agent)``, and the
      UI is refreshed with the new suggestions.
    * ``"step"``: ``step_info.agent_info.think`` is overwritten with the
      labeler's text and the selected action is returned.

    Args:
        hint_labeling: UI handle providing ``update_context`` and
            ``wait_for_response``.
        action: the action the agent initially proposed for this step.
        agent: the agent to re-query when the labeler submits a hint.
        step_info: current step; its ``obs`` and ``agent_info`` are read, and
            ``agent_info.think`` is written on a ``"step"`` response.
        n_suggestions: how many times to re-query the agent per reprompt.

    Returns:
        The action taken from ``response["payload"]["action"]``.

    Raises:
        ValueError: if the response ``"type"`` is neither ``"reprompt"`` nor
            ``"step"``.

    NOTE(review): ``agent.flags.extra_instructions`` is left set to the last
    hint when this function returns -- confirm that is intended.
    NOTE(review): the selected action is only returned, not stored on
    ``step_info``; callers must use the return value for it to take effect.
    """
    context = HintLabelingInputs(
        goal=step_info.obs.get("goal", ""),  # TODO: is this goal deprecated?
        error_feedback=step_info.obs.get("last_action_error", ""),
        screenshot=_convert_np_array_to_base64(step_info.obs["screenshot"]),
        axtree=step_info.obs["axtree_txt"],
        history=[],  # TODO: add history
        hint="",
        # Only the agent's real proposal is offered; no placeholder entries
        # that a labeler could accidentally select.
        suggestions=[
            {
                "id": "1",
                "action": action,
                "think": step_info.agent_info.think,
            },
        ],
    )
    while True:
        # update hint labeling ui context
        logger.info("Updating Hint Labeling UI context...")
        hint_labeling.update_context(context)

        # wait for hint labeling response
        logger.info("Waiting for Hint Labeling UI response...")
        response = hint_labeling.wait_for_response()
        response_type = response["type"]

        if response_type == "reprompt":
            # The labeler supplied a hint: re-query the agent with it and
            # replace the suggestions shown in the UI.
            hint = response["payload"]["hint"]
            agent.flags.extra_instructions = hint
            suggestions = []
            for i in tqdm(range(n_suggestions)):
                # TODO: make this more optimal
                # presumably from_action re-queries the agent and refreshes
                # step_info.agent_info -- verify against StepInfo.
                suggested_action = step_info.from_action(agent)
                suggestions.append(
                    {
                        "id": str(i + 1),
                        "action": suggested_action,
                        "think": step_info.agent_info.think,
                    }
                )

            # Rebuild the context; the observation-derived fields (including
            # the goal) are carried over unchanged from the previous context.
            context = HintLabelingInputs(
                goal=context.goal,
                error_feedback=context.error_feedback,
                screenshot=context.screenshot,
                axtree=context.axtree,
                history=context.history,
                hint=hint,
                suggestions=suggestions,
            )
        elif response_type == "step":
            # The labeler picked an action: record their reasoning and
            # hand the action back to the caller.
            step_info.agent_info.think = response["payload"]["think"]
            return response["payload"]["action"]
        else:
            raise ValueError(f"Unknown response type: {response_type}")
1023+
9451024

9461025
def _flatten_dict(d, parent_key="", sep="."):
9471026
"""Recursively flatten a nested dictionary."""

0 commit comments

Comments
 (0)