Skip to content

Commit 7e17d66

Browse files
add button to go back to arbitrary past step
1 parent 250141f commit 7e17d66

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

src/agentlab/analyze/agent_controller.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,19 @@ def revert_agent_state():
127127
st.session_state.agent.memories.pop()
128128

129129

130+
def restore_env_history(step: int):
131+
st.session_state.obs_history = st.session_state.obs_history[:step]
132+
st.session_state.screenshot_history = st.session_state.screenshot_history[:step]
133+
st.session_state.axtree_history = st.session_state.axtree_history[:step]
134+
135+
136+
def restore_agent_history(step: int):
137+
st.session_state.action_history = st.session_state.action_history[:step]
138+
st.session_state.action_info_history = st.session_state.action_info_history[:step]
139+
st.session_state.thought_history = st.session_state.thought_history[:step]
140+
st.session_state.prompt_history = st.session_state.prompt_history[:step]
141+
142+
130143
def get_prompt(info):
131144
if info is not None and isinstance(info.chat_messages, Discussion):
132145
chat_messages = info.chat_messages.messages
@@ -612,6 +625,12 @@ def set_prompt_tab():
612625
def set_previous_steps_tab():
613626
for i in range(len(st.session_state.action_history) - 1):
614627
with st.expander(f"### Step {i + 1}", expanded=False):
628+
if st.button(f"Go back to step {i + 1}"):
629+
reset_agent_state()
630+
restore_agent_history(step=i + 1)
631+
restore_env_history(step=i + 1)
632+
restore_environment()
633+
st.rerun()
615634
screenshot_tab, axtree_tab, prompt_tab = st.tabs(["Screenshot", "AxTree", "Prompt"])
616635
with screenshot_tab:
617636
display_image(st.session_state.screenshot_history[i])
@@ -626,12 +645,6 @@ def set_previous_steps_tab():
626645

627646

628647
def set_info_tabs():
629-
print(len(st.session_state.action_history))
630-
print(len(st.session_state.screenshot_history))
631-
print(len(st.session_state.axtree_history))
632-
print(len(st.session_state.prompt_history))
633-
print(len(st.session_state.thought_history))
634-
print("---")
635648
# Display only if everything is now ready
636649
if len(st.session_state.action_history) > 1:
637650
screenshot_tab, axtree_tab, prompt_tab, previous_steps_tab = st.tabs(

0 commit comments

Comments
 (0)