1+ import base64
12import gzip
23import importlib .metadata
34import json
1314from collections import defaultdict
1415from dataclasses import asdict , dataclass , field , is_dataclass
1516from datetime import datetime
17+ from io import BytesIO
1618from pathlib import Path
1719from typing import Optional
1820
1921import gymnasium as gym
2022import numpy as np
23+ import PIL .Image
2124from browsergym .core .chat import Chat
25+ from browsergym .core .hint_labeling import HintLabeling , HintLabelingInputs
2226from browsergym .experiments .agent import Agent
2327from browsergym .experiments .utils import count_tokens
2428from dataclasses_json import DataClassJsonMixin
@@ -404,7 +408,7 @@ def _make_dir(self, exp_root):
404408 def run (self ):
405409 """Run the experiment and save the results"""
406410 # start writing logs to run logfile
407- self ._set_logger ()
411+ # self._set_logger()
408412
409413 # log python environment info
410414 save_package_versions (Path (self .exp_dir ))
@@ -443,15 +447,18 @@ def run(self):
443447 # will end the episode after saving the step info.
444448 step_info .truncated = True
445449
446- step_info .save_step_info (
447- self .exp_dir , save_screenshot = self .save_screenshot , save_som = self .save_som
448- )
449- logger .debug ("Step info saved." )
450+ # step_info.save_step_info(
451+ # self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som
452+ # )
453+ # logger.debug("Step info saved.")
450454
451455 if hasattr (env .unwrapped , "chat" ) and isinstance (env .unwrapped .chat , Chat ):
452456 _send_chat_info (env .unwrapped .chat , action , step_info .agent_info )
453457 logger .debug ("Chat info sent." )
454458
459+ if hasattr (env .unwrapped , "hint_labeling" ) and isinstance (env .unwrapped .hint_labeling , HintLabeling ):
460+ _update_hint_labeling (env .unwrapped .hint_labeling , action , agent , step_info )
461+
455462 if action is None :
456463 logger .debug ("Agent returned None action. Ending episode." )
457464 break
@@ -481,10 +488,11 @@ def run(self):
481488
482489 finally :
483490 try :
484- if step_info is not None :
485- step_info .save_step_info (
486- self .exp_dir , save_screenshot = self .save_screenshot , save_som = self .save_som
487- )
491+ pass
492+ # if step_info is not None:
493+ # step_info.save_step_info(
494+ # self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som
495+ # )
488496 except Exception as e :
489497 logger .error (f"Error while saving step info in the finally block: { e } " )
490498 try :
@@ -508,7 +516,8 @@ def run(self):
508516 except Exception as e :
509517 logger .exception (f"Error while closing the environment: { e } " )
510518 try :
511- self ._unset_logger () # stop writing logs to run logfile
519+ # self._unset_logger() # stop writing logs to run logfile
520+ pass
512521 except Exception as e :
513522 logger .exception (f"Error while unsetting the logger: { e } " )
514523
@@ -942,6 +951,76 @@ def _send_chat_info(chat: Chat, action: str, agent_info: dict):
942951 logger .info (msg )
943952 chat .add_message (role = "info" , msg = msg )
944953
def _convert_np_array_to_base64(np_array: np.ndarray):
    """Encode an image array as a base64 PNG string.

    The array is turned into a PIL image, serialised as PNG into an in-memory
    buffer, and the PNG bytes are base64-encoded and returned as text.
    """
    png_buffer = BytesIO()
    PIL.Image.fromarray(np_array).save(png_buffer, format="PNG")
    encoded_png = base64.b64encode(png_buffer.getvalue())
    return encoded_png.decode("utf-8")
960+
def _update_hint_labeling(
    hint_labeling: HintLabeling,
    action: str,
    agent: Agent,
    step_info: StepInfo,
    n_suggestions: int = 5,
):
    """Let the Hint Labeling UI choose or steer the agent's next action.

    The current observation and the agent's proposed action are pushed to the
    UI, then the function blocks until the UI answers. Two answers are handled:

    * ``"reprompt"``: the hint from the payload is stored in
      ``agent.flags.extra_instructions``, ``step_info.from_action(agent)`` is
      called ``n_suggestions`` times to collect fresh suggestions, and the UI
      is refreshed with them. The loop then waits for the next answer.
    * ``"step"``: ``step_info.agent_info.think`` is overwritten with the
      payload's ``think`` and the payload's ``action`` is returned.

    Args:
        hint_labeling: UI handle exposing ``update_context`` and
            ``wait_for_response``.
        action: The action currently proposed by the agent; shown as the
            initial suggestion.
        agent: The agent to reprompt; its ``flags.extra_instructions`` is
            mutated when a hint is received and is not restored afterwards.
        step_info: Current step; ``obs`` must contain ``"screenshot"`` and
            ``"axtree_txt"``. Its ``agent_info.think`` is updated in place.
        n_suggestions: How many suggestions to collect per reprompt.

    Returns:
        The action selected in the UI. The caller must use this return value;
        the ``action`` argument itself is not modified.

    Raises:
        ValueError: If the UI response type is neither ``"reprompt"`` nor
            ``"step"``.
    """
    obs = step_info.obs
    context = HintLabelingInputs(
        goal=obs.get("goal", ""),  # TODO: is this goal deprecated?
        error_feedback=obs.get("last_action_error", ""),
        screenshot=_convert_np_array_to_base64(obs["screenshot"]),
        axtree=obs["axtree_txt"],
        history=[],  # TODO: add history
        hint="",
        # Only the agent's real proposal is offered; no placeholder entries.
        suggestions=[
            {"id": "1", "action": action, "think": step_info.agent_info.think},
        ],
    )

    while True:
        # update hint labeling ui context
        logger.info("Updating Hint Labeling UI context...")
        hint_labeling.update_context(context)

        # wait for hint labeling response
        logger.info("Waiting for Hint Labeling UI response...")
        response = hint_labeling.wait_for_response()
        response_type = response["type"]

        if response_type == "step":
            # The labeler picked (or edited) an action: record the reasoning
            # on the step and hand the action back to the caller.
            step_info.agent_info.think = response["payload"]["think"]
            return response["payload"]["action"]

        if response_type != "reprompt":
            raise ValueError(f"Unknown response type: {response_type}")

        # Reprompt: inject the hint and collect a fresh batch of suggestions.
        hint = response["payload"]["hint"]
        agent.flags.extra_instructions = hint
        suggestions = []
        for index in range(n_suggestions):
            # TODO: make this more optimal
            logger.info("Requesting suggestion %d/%d...", index + 1, n_suggestions)
            # NOTE(review): from_action presumably re-queries the agent and
            # refreshes step_info.agent_info as a side effect - confirm.
            suggested_action = step_info.from_action(agent)
            suggestions.append(
                {
                    "id": str(index + 1),
                    "action": suggested_action,
                    "think": step_info.agent_info.think,
                }
            )

        # Rebuild the context with the new hint and suggestions, carrying
        # every other field (including the real goal) over unchanged.
        context = HintLabelingInputs(
            goal=context.goal,
            error_feedback=context.error_feedback,
            screenshot=context.screenshot,
            axtree=context.axtree,
            history=context.history,
            hint=hint,
            suggestions=suggestions,
        )
1023+
9451024
9461025def _flatten_dict (d , parent_key = "" , sep = "." ):
9471026 """Recursively flatten a nested dictionary."""
0 commit comments