@@ -460,8 +460,12 @@ def run(self):
460460 _send_chat_info (env .unwrapped .chat , action , step_info .agent_info )
461461 logger .debug ("Chat info sent." )
462462
463- if hasattr (env .unwrapped , "hint_labeling" ) and isinstance (env .unwrapped .hint_labeling , HintLabeling ):
464- action = _update_hint_labeling (env .unwrapped .hint_labeling , action , agent , step_info )
463+ if hasattr (env .unwrapped , "hint_labeling" ) and isinstance (
464+ env .unwrapped .hint_labeling , HintLabeling
465+ ):
466+ action = _update_hint_labeling (
467+ env .unwrapped .hint_labeling , action , agent , step_info
468+ )
465469
466470 if action is None :
467471 logger .debug ("Agent returned None action. Ending episode." )
@@ -955,29 +959,33 @@ def _send_chat_info(chat: Chat, action: str, agent_info: dict):
955959 logger .info (msg )
956960 chat .add_message (role = "info" , msg = msg )
957961
962+
def _convert_np_array_to_base64(np_array: np.ndarray):
    """Encode an image array as a base64 string of its PNG representation.

    The array is turned into a PIL image, written as PNG into an in-memory
    buffer, and the buffer's bytes are base64-encoded.

    Args:
        np_array: Image data accepted by ``PIL.Image.fromarray``.

    Returns:
        The base64 encoding of the PNG bytes, decoded to a UTF-8 ``str``.
    """
    # Serialize to PNG entirely in memory; nothing is written to disk.
    png_buffer = BytesIO()
    PIL.Image.fromarray(np_array).save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
964969
965- def _update_hint_labeling (hint_labeling : HintLabeling , action : str , agent : Agent , step_info : StepInfo ):
970+
971+ def _update_hint_labeling (
972+ hint_labeling : HintLabeling , action : str , agent : Agent , step_info : StepInfo
973+ ):
966974 """Update the hint labeling with the action and agent info."""
967975 context = HintLabelingInputs (
968- goal = step_info .obs .get ("goal" , "" ), # TODO: is this goal deprecated?
976+ goal = step_info .obs .get ("goal" , "" ), # TODO: is this goal deprecated?
969977 error_feedback = step_info .obs .get ("last_action_error" , "" ),
970- screenshot = _convert_np_array_to_base64 (step_info .obs ["screenshot" ]),
971- axtree = step_info .obs ["axtree_txt" ],
972- history = [], # TODO: add history
973- hint = "" ,
974- suggestions = [
978+ screenshot = _convert_np_array_to_base64 (step_info .obs ["screenshot" ]),
979+ axtree = step_info .obs ["axtree_txt" ],
980+ history = [], # TODO: add history
981+ hint = "" ,
982+ suggestions = [
975983 {
976984 "id" : "1" ,
977985 "action" : action ,
978986 "think" : step_info .agent_info .think ,
979987 }
980- ]
988+ ],
981989 )
982990 while True :
983991 # update hint labeling ui context
@@ -1001,17 +1009,19 @@ def _update_hint_labeling(hint_labeling: HintLabeling, action: str, agent: Agent
10011009 think = step_info .agent_info .think
10021010 if action not in seen_actions :
10031011 seen_actions .add (action )
1004- suggestions .append ({"id" : str (len (seen_actions )), "action" : action , "think" : think })
1012+ suggestions .append (
1013+ {"id" : str (len (seen_actions )), "action" : action , "think" : think }
1014+ )
10051015
10061016 # update context
10071017 context = HintLabelingInputs (
10081018 goal = context .goal ,
10091019 error_feedback = context .error_feedback ,
1010- screenshot = context .screenshot ,
1011- axtree = context .axtree ,
1012- history = context .history ,
1013- hint = hint ,
1014- suggestions = suggestions
1020+ screenshot = context .screenshot ,
1021+ axtree = context .axtree ,
1022+ history = context .history ,
1023+ hint = hint ,
1024+ suggestions = suggestions ,
10151025 )
10161026 continue
10171027
@@ -1022,7 +1032,7 @@ def _update_hint_labeling(hint_labeling: HintLabeling, action: str, agent: Agent
10221032 return action
10231033 else :
10241034 raise ValueError (f"Unknown response type: { response ['type' ]} " )
1025-
1035+
10261036
10271037def _flatten_dict (d , parent_key = "" , sep = "." ):
10281038 """Recursively flatten a nested dictionary."""
0 commit comments