Skip to content

Commit 614d12d

Browse files
implement save feature to save traces and hints
1 parent 96263ae commit 614d12d

File tree

1 file changed

+67
-3
lines changed

1 file changed

+67
-3
lines changed

src/agentlab/analyze/agent_controller.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import base64
22
import copy
33
import importlib
4+
import json
45
import logging
6+
import os
57
from datetime import datetime
68
from io import BytesIO
79

@@ -48,6 +50,14 @@ def filter(self, record):
4850
streamlit_logger.setLevel(logging.ERROR)
4951

5052

53+
def is_json_serializable(value):
54+
try:
55+
json.dumps(value)
56+
return True
57+
except (TypeError, OverflowError):
58+
return False
59+
60+
5161
def get_import_path(obj):
5262
return f"{obj.__module__}.{obj.__qualname__}"
5363

@@ -596,14 +606,21 @@ def set_controller():
596606
st.rerun()
597607

598608

599-
def display_image(img_arr):
609+
def get_base64_serialized_image(img_arr):
600610
if isinstance(img_arr, list):
601611
img_arr = np.array(img_arr)
602612
if isinstance(img_arr, np.ndarray):
603613
im = PIL.Image.fromarray(img_arr)
604614
buffered = BytesIO()
605615
im.save(buffered, format="PNG")
606616
img_b64 = base64.b64encode(buffered.getvalue()).decode()
617+
return img_b64
618+
return None
619+
620+
621+
def display_image(img_arr):
622+
img_b64 = get_base64_serialized_image(img_arr)
623+
if img_b64:
607624
st.markdown(
608625
f'<div style="display: flex; justify-content: center;"><img src="data:image/png;base64,{img_b64}" style="max-width: 80vw; height: auto;" /></div>',
609626
unsafe_allow_html=True,
@@ -644,11 +661,56 @@ def set_previous_steps_tab():
644661
st.code(st.session_state.action_history[i], language=None, wrap_lines=True)
645662

646663

664+
def set_save_tab():
665+
# dump full session_state to json
666+
save_dir = st.text_input("Save Directory", value="~/Downloads")
667+
save_dir = os.path.expanduser(save_dir)
668+
if st.button("Save Session State for Current Run"):
669+
now_str = datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
670+
filename = f"agentlab_controller_state_{now_str}.json"
671+
672+
# prepare payload for saving
673+
payload = {}
674+
payload["timestamp"] = now_str
675+
payload["benchmark"] = st.session_state.benchmark
676+
payload["task"] = st.session_state.task
677+
payload["subtask"] = st.session_state.subtask
678+
payload["agent_args"] = {
679+
k: v for k, v in vars(st.session_state.agent_args).items() if is_json_serializable(v)
680+
}
681+
payload["agent_flags"] = {
682+
k: v for k, v in vars(st.session_state.agent.flags).items() if is_json_serializable(v)
683+
}
684+
payload["agent_flags"]["obs"] = {
685+
k: v
686+
for k, v in vars(st.session_state.agent.flags.obs).items()
687+
if is_json_serializable(v)
688+
}
689+
payload["agent_flags"]["action"] = {
690+
k: v
691+
for k, v in vars(st.session_state.agent.flags.action).items()
692+
if is_json_serializable(v)
693+
}
694+
payload["goal"] = st.session_state.last_obs["goal"]
695+
payload["steps"] = []
696+
for i in range(len(st.session_state.action_history)):
697+
step = {}
698+
step["action"] = st.session_state.action_history[i]
699+
step["thought"] = st.session_state.thought_history[i]
700+
step["prompt"] = st.session_state.prompt_history[i]
701+
step["screenshot"] = get_base64_serialized_image(st.session_state.screenshot_history[i])
702+
step["axtree"] = st.session_state.axtree_history[i]
703+
payload["steps"].append(step)
704+
705+
with open(os.path.join(save_dir, filename), "w") as f:
706+
json.dump(payload, f)
707+
708+
647709
def set_info_tabs():
648710
# Display only if everything is now ready
649711
if len(st.session_state.action_history) > 1:
650-
screenshot_tab, axtree_tab, prompt_tab, previous_steps_tab = st.tabs(
651-
["Screenshot", "AxTree", "Prompt", "Previous Steps"]
712+
screenshot_tab, axtree_tab, prompt_tab, previous_steps_tab, save_tab = st.tabs(
713+
["Screenshot", "AxTree", "Prompt", "Previous Steps", "Save"]
652714
)
653715
else:
654716
screenshot_tab, axtree_tab, prompt_tab = st.tabs(["Screenshot", "AxTree", "Prompt"])
@@ -662,6 +724,8 @@ def set_info_tabs():
662724
if len(st.session_state.action_history) > 1:
663725
with previous_steps_tab:
664726
set_previous_steps_tab()
727+
with save_tab:
728+
set_save_tab()
665729

666730

667731
def run_streamlit():

0 commit comments

Comments
 (0)