Skip to content

Commit ecab759

Browse files
minor refactoring
1 parent 4163ab6 commit ecab759

File tree

1 file changed

+125
-101
lines changed

1 file changed

+125
-101
lines changed

src/agentlab/analyze/agent_controller.py

Lines changed: 125 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class Constants:
3636
STATUS = "status"
3737
STATUS_SUCCESS = "success"
3838
STATUS_ERROR = "error"
39+
MESSAGE = "message"
3940

4041
OBS = "obs"
4142
SCREENSHOT = "screenshot"
@@ -378,7 +379,7 @@ def reset_environment():
378379
if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS:
379380
logger.error(resp.status_code)
380381
logger.error(resp.json()[Constants.STATUS])
381-
logger.error(resp.json()["message"])
382+
logger.error(resp.json()[Constants.MESSAGE])
382383
response_json = resp.json()
383384
response_json = deserialize_response(response_json)
384385
obs = response_json[Constants.OBS]
@@ -394,7 +395,7 @@ def reload_task():
394395
if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS:
395396
logger.error(resp.status_code)
396397
logger.error(resp.json()[Constants.STATUS])
397-
logger.error(resp.json()["message"])
398+
logger.error(resp.json()[Constants.MESSAGE])
398399
end = datetime.now()
399400
logger.info(f"Done in {end - start}")
400401

@@ -407,7 +408,7 @@ def step_environment(action):
407408
if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS:
408409
logger.error(resp.status_code)
409410
logger.error(resp.json()[Constants.STATUS])
410-
logger.error(resp.json()["message"])
411+
logger.error(resp.json()[Constants.MESSAGE])
411412
response_json = resp.json()
412413
response_json = deserialize_response(response_json)
413414
if st.session_state.agent.obs_preprocessor:
@@ -577,119 +578,142 @@ def set_prompt_modifier():
577578
st.session_state.agent.flags.extra_instructions = extra_instructions
578579

579580

581+
def set_go_back_to_step_k_section():
582+
with st.container(border=True):
583+
st.markdown("**Go Back to Step K**")
584+
col1, col2 = st.columns([1, 1], vertical_alignment="bottom")
585+
is_go_back_to_step_k_disabled = len(st.session_state.action_history) <= 1
586+
with col1:
587+
step = st.number_input(
588+
"Step",
589+
value=1,
590+
min_value=1,
591+
max_value=len(st.session_state.action_history),
592+
disabled=is_go_back_to_step_k_disabled,
593+
)
594+
with col2:
595+
if st.button(
596+
"Go Back",
597+
help="Go back to step K",
598+
use_container_width=True,
599+
disabled=is_go_back_to_step_k_disabled,
600+
):
601+
reset_agent_state()
602+
restore_agent_history(step=step)
603+
restore_env_history(step=step)
604+
restore_environment()
605+
st.rerun()
606+
607+
608+
def set_reprompt_k_times_section():
609+
with st.container(border=True):
610+
st.markdown("**Reprompt K Times**")
611+
col1, col2 = st.columns([1, 1], vertical_alignment="bottom")
612+
with col1:
613+
k = st.number_input(
614+
"Number of Generations",
615+
value=5,
616+
min_value=1,
617+
max_value=25,
618+
)
619+
with col2:
620+
has_clicked_reprompt = st.button(
621+
"Reprompt",
622+
help="Reprompt the agent K times to get a distribution of actions to take",
623+
use_container_width=True,
624+
)
625+
if has_clicked_reprompt:
626+
reprompt_actions = []
627+
with st.spinner(f"Reprompting {k} times"):
628+
for i in range(k):
629+
revert_agent_history()
630+
revert_agent_state()
631+
get_action()
632+
reprompt_actions.append(st.session_state.action)
633+
# show all unique actions found in reprompt actions along with their probability
634+
unique_actions_counter = Counter(reprompt_actions)
635+
unique_actions = sorted(
636+
unique_actions_counter.items(), key=lambda x: x[1], reverse=True
637+
)
638+
st.markdown("**Unique Actions**")
639+
for action, count in unique_actions:
640+
selected_action = st.button(f"`{action}` ({count / k * 100:.2f}%)")
641+
if selected_action:
642+
step_environment(action)
643+
st.rerun()
644+
645+
646+
def set_act_k_times_section():
647+
with st.container(border=True):
648+
st.markdown("**Act K Times**")
649+
col1, col2 = st.columns([1, 1], vertical_alignment="bottom")
650+
with col1:
651+
k = st.number_input("Number of Steps", value=5, min_value=1, max_value=10)
652+
with col2:
653+
has_clicked_act = st.button(
654+
"Act",
655+
help="Let the agent autonomously perform actions for K steps",
656+
use_container_width=True,
657+
)
658+
if has_clicked_act:
659+
with st.spinner(f"Acting {k} times"):
660+
for _ in range(k):
661+
get_action()
662+
step_environment(st.session_state.action)
663+
st.rerun()
664+
665+
580666
def set_advanced_controller():
581667
with st.expander("**Advanced**", expanded=False):
582668
col_go_back_to, col_reprompt_k, col_act_k = st.columns([1, 1, 1])
583669
with col_go_back_to:
584-
with st.container(border=True):
585-
st.markdown("**Go Back to Step K**")
586-
col1, col2 = st.columns([1, 1], vertical_alignment="bottom")
587-
is_go_back_to_step_k_disabled = len(st.session_state.action_history) <= 1
588-
with col1:
589-
step = st.number_input(
590-
"Step",
591-
value=1,
592-
min_value=1,
593-
max_value=len(st.session_state.action_history),
594-
disabled=is_go_back_to_step_k_disabled,
595-
)
596-
with col2:
597-
if st.button(
598-
"Go Back",
599-
help="Go back to step K",
600-
use_container_width=True,
601-
disabled=is_go_back_to_step_k_disabled,
602-
):
603-
reset_agent_state()
604-
restore_agent_history(step=step)
605-
restore_env_history(step=step)
606-
restore_environment()
607-
st.rerun()
670+
set_go_back_to_step_k_section()
608671
with col_reprompt_k:
609-
with st.container(border=True):
610-
st.markdown("**Reprompt K Times**")
611-
col1, col2 = st.columns([1, 1], vertical_alignment="bottom")
612-
with col1:
613-
k = st.number_input(
614-
"Number of Generations",
615-
value=5,
616-
min_value=1,
617-
max_value=25,
618-
)
619-
with col2:
620-
has_clicked_reprompt = st.button(
621-
"Reprompt",
622-
help="Reprompt the agent K times to get a distribution of actions to take",
623-
use_container_width=True,
624-
)
625-
if has_clicked_reprompt:
626-
reprompt_actions = []
627-
with st.spinner(f"Reprompting {k} times"):
628-
for i in range(k):
629-
revert_agent_history()
630-
revert_agent_state()
631-
get_action()
632-
reprompt_actions.append(st.session_state.action)
633-
# show all unique actions found in reprompt actions along with their probability
634-
unique_actions_counter = Counter(reprompt_actions)
635-
unique_actions = sorted(
636-
unique_actions_counter.items(), key=lambda x: x[1], reverse=True
637-
)
638-
st.markdown("**Unique Actions**")
639-
for action, count in unique_actions:
640-
selected_action = st.button(f"`{action}` ({count / k * 100:.2f}%)")
641-
if selected_action:
642-
step_environment(action)
643-
st.rerun()
644-
672+
set_reprompt_k_times_section()
645673
with col_act_k:
646-
with st.container(border=True):
647-
st.markdown("**Act K Times**")
648-
col1, col2 = st.columns([1, 1], vertical_alignment="bottom")
649-
with col1:
650-
k = st.number_input("Number of Steps", value=5, min_value=1, max_value=10)
651-
with col2:
652-
has_clicked_act = st.button(
653-
"Act",
654-
help="Let the agent autonomously perform actions for K steps",
655-
use_container_width=True,
656-
)
657-
if has_clicked_act:
658-
with st.spinner(f"Acting {k} times"):
659-
for _ in range(k):
660-
get_action()
661-
step_environment(st.session_state.action)
662-
st.rerun()
674+
set_act_k_times_section()
675+
676+
677+
def set_previous_step_section():
678+
prev_disabled = len(st.session_state.action_history) <= 1
679+
if st.button("⬅️ Previous Step", disabled=prev_disabled, use_container_width=True):
680+
if not prev_disabled:
681+
st.session_state.action = (
682+
None
683+
if len(st.session_state.action_history) == 0
684+
else st.session_state.action_history[-1]
685+
)
686+
reset_agent_state()
687+
revert_agent_history()
688+
revert_env_history()
689+
restore_environment()
690+
st.rerun()
691+
692+
693+
def set_regenerate_action_section():
694+
if st.button("🔄 Regenerate Action", use_container_width=True):
695+
revert_agent_history()
696+
revert_agent_state()
697+
get_action()
698+
st.rerun()
699+
700+
701+
def set_next_step_section():
702+
if st.button("➡️ Next Step", use_container_width=True):
703+
step_environment(st.session_state.action)
704+
st.rerun()
663705

664706

665707
def set_controller():
666708
set_agent_state_box()
667709
set_prompt_modifier()
668710
col_prev, col_redo, col_next = st.columns([1, 1, 1])
669711
with col_prev:
670-
prev_disabled = len(st.session_state.action_history) <= 1
671-
if st.button("⬅️ Previous Step", disabled=prev_disabled, use_container_width=True):
672-
if not prev_disabled:
673-
st.session_state.action = (
674-
None
675-
if len(st.session_state.action_history) == 0
676-
else st.session_state.action_history[-1]
677-
)
678-
reset_agent_state()
679-
revert_agent_history()
680-
revert_env_history()
681-
restore_environment()
682-
st.rerun()
712+
set_previous_step_section()
683713
with col_redo:
684-
if st.button("🔄 Regenerate Action", use_container_width=True):
685-
revert_agent_history()
686-
revert_agent_state()
687-
get_action()
688-
st.rerun()
714+
set_regenerate_action_section()
689715
with col_next:
690-
if st.button("➡️ Next Step", use_container_width=True):
691-
step_environment(st.session_state.action)
692-
st.rerun()
716+
set_next_step_section()
693717
set_advanced_controller()
694718

695719

0 commit comments

Comments
 (0)