Skip to content

Commit 9b238bc

Browse files
formatting
1 parent daee88e commit 9b238bc

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

src/agentlab/experiments/loop.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -460,8 +460,12 @@ def run(self):
460460
_send_chat_info(env.unwrapped.chat, action, step_info.agent_info)
461461
logger.debug("Chat info sent.")
462462

463-
if hasattr(env.unwrapped, "hint_labeling") and isinstance(env.unwrapped.hint_labeling, HintLabeling):
464-
action = _update_hint_labeling(env.unwrapped.hint_labeling, action, agent, step_info)
463+
if hasattr(env.unwrapped, "hint_labeling") and isinstance(
464+
env.unwrapped.hint_labeling, HintLabeling
465+
):
466+
action = _update_hint_labeling(
467+
env.unwrapped.hint_labeling, action, agent, step_info
468+
)
465469

466470
if action is None:
467471
logger.debug("Agent returned None action. Ending episode.")
@@ -955,29 +959,33 @@ def _send_chat_info(chat: Chat, action: str, agent_info: dict):
955959
logger.info(msg)
956960
chat.add_message(role="info", msg=msg)
957961

962+
958963
def _convert_np_array_to_base64(np_array: np.ndarray):
959964
im = PIL.Image.fromarray(np_array)
960965
buffered = BytesIO()
961966
im.save(buffered, format="PNG")
962967
img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
963968
return img_b64
964969

965-
def _update_hint_labeling(hint_labeling: HintLabeling, action: str, agent: Agent, step_info: StepInfo):
970+
971+
def _update_hint_labeling(
972+
hint_labeling: HintLabeling, action: str, agent: Agent, step_info: StepInfo
973+
):
966974
"""Update the hint labeling with the action and agent info."""
967975
context = HintLabelingInputs(
968-
goal=step_info.obs.get("goal", ""), # TODO: is this goal deprecated?
976+
goal=step_info.obs.get("goal", ""), # TODO: is this goal deprecated?
969977
error_feedback=step_info.obs.get("last_action_error", ""),
970-
screenshot = _convert_np_array_to_base64(step_info.obs["screenshot"]),
971-
axtree = step_info.obs["axtree_txt"],
972-
history = [], # TODO: add history
973-
hint = "",
974-
suggestions = [
978+
screenshot=_convert_np_array_to_base64(step_info.obs["screenshot"]),
979+
axtree=step_info.obs["axtree_txt"],
980+
history=[], # TODO: add history
981+
hint="",
982+
suggestions=[
975983
{
976984
"id": "1",
977985
"action": action,
978986
"think": step_info.agent_info.think,
979987
}
980-
]
988+
],
981989
)
982990
while True:
983991
# update hint labeling ui context
@@ -1001,17 +1009,19 @@ def _update_hint_labeling(hint_labeling: HintLabeling, action: str, agent: Agent
10011009
think = step_info.agent_info.think
10021010
if action not in seen_actions:
10031011
seen_actions.add(action)
1004-
suggestions.append({"id": str(len(seen_actions)), "action": action, "think": think})
1012+
suggestions.append(
1013+
{"id": str(len(seen_actions)), "action": action, "think": think}
1014+
)
10051015

10061016
# update context
10071017
context = HintLabelingInputs(
10081018
goal=context.goal,
10091019
error_feedback=context.error_feedback,
1010-
screenshot = context.screenshot,
1011-
axtree = context.axtree,
1012-
history = context.history,
1013-
hint = hint,
1014-
suggestions = suggestions
1020+
screenshot=context.screenshot,
1021+
axtree=context.axtree,
1022+
history=context.history,
1023+
hint=hint,
1024+
suggestions=suggestions,
10151025
)
10161026
continue
10171027

@@ -1022,7 +1032,7 @@ def _update_hint_labeling(hint_labeling: HintLabeling, action: str, agent: Agent
10221032
return action
10231033
else:
10241034
raise ValueError(f"Unknown response type: {response['type']}")
1025-
1035+
10261036

10271037
def _flatten_dict(d, parent_key="", sep="."):
10281038
"""Recursively flatten a nested dictionary."""

0 commit comments

Comments
 (0)