|
25 | 25 | ) |
26 | 26 | from agentlab.experiments.exp_utils import RESULTS_DIR |
27 | 27 | from agentlab.experiments.loop import ExpArgs, StepInfo, save_package_versions |
28 | | -from agentlab.llm.llm_utils import Discussion |
| 28 | +from agentlab.llm.response_api import LLMOutput |
29 | 29 | from bgym import DEFAULT_BENCHMARKS |
30 | 30 | from dotenv import load_dotenv |
31 | 31 | from transformers import AutoTokenizer |
@@ -188,7 +188,12 @@ def step_agent_history(action, action_info): |
188 | 188 | st.session_state.action_history.append(action) |
189 | 189 | st.session_state.action_info_history.append(action_info) |
190 | 190 | st.session_state.thought_history.append(action_info.think) |
191 | | - st.session_state.prompt_history.append(get_prompt(action_info)) |
| 191 | + if isinstance(st.session_state.agent, GenericAgent): |
| 192 | + st.session_state.prompt_history.append(get_prompt(action_info)) |
| 193 | + elif isinstance(st.session_state.agent, ToolUseAgent): |
| 194 | + st.session_state.prompt_history.append( |
| 195 | + "\n".join([elem.to_markdown() for elem in st.session_state.agent.discussion.flatten()]) |
| 196 | + ) |
192 | 197 |
|
193 | 198 | # HACK: memory history can only be obtained via the agent |
194 | 199 | if isinstance(st.session_state.agent, GenericAgent): |
@@ -229,10 +234,31 @@ def revert_agent_history(): |
229 | 234 |
|
230 | 235 | def revert_agent_state(): |
231 | 236 | logger.info("Reverting agent state") |
232 | | - st.session_state.agent.obs_history.pop() |
233 | | - st.session_state.agent.actions.pop() |
234 | | - st.session_state.agent.thoughts.pop() |
235 | | - st.session_state.agent.memories.pop() |
| 237 | + if isinstance(st.session_state.agent, GenericAgent): |
| 238 | + st.session_state.agent.obs_history.pop() |
| 239 | + st.session_state.agent.actions.pop() |
| 240 | + st.session_state.agent.thoughts.pop() |
| 241 | + st.session_state.agent.memories.pop() |
| 242 | + elif isinstance(st.session_state.agent, ToolUseAgent): |
| 243 | + num_groups = len(st.session_state.agent.discussion.groups) |
| 244 | + if num_groups == 3: |
| 245 | + # start from blank state |
| 246 | + st.session_state.agent.discussion.groups = [] |
| 247 | + st.session_state.agent.last_response = LLMOutput() |
| 248 | + st.session_state.agent._responses = [] |
| 249 | + elif num_groups > 3: |
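| 250 | + # drop the last group (the last action), then reset the new last group (summary cleared, all messages removed); NOTE(review): this comment originally said the action is kept, but messages[:0] keeps nothing -- confirm intent |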
| 251 | + st.session_state.agent.discussion.groups.pop() |
| 252 | + last_group = copy.deepcopy(st.session_state.agent.discussion.groups[-1]) |
| 253 | + last_group.summary = None |
| 254 | + last_group.messages = last_group.messages[:0] # remove everything from last group |
| 255 | + st.session_state.agent.discussion.groups[-1] = last_group |
| 256 | + st.session_state.agent._responses.pop() |
| 257 | + st.session_state.agent.last_response = copy.deepcopy( |
| 258 | + st.session_state.agent._responses[-1] |
| 259 | + ) |
| 260 | + else: |
| 261 | + raise Exception("Invalid number of groups") |
236 | 262 |
|
237 | 263 |
|
238 | 264 | def restore_env_history(step: int): |
@@ -534,9 +560,17 @@ def load_session(exp_files): |
534 | 560 | st.session_state.action_history.append(step_info.action) |
535 | 561 | st.session_state.action_info_history.append(step_info.agent_info) |
536 | 562 | st.session_state.thought_history.append(step_info.agent_info.get("think", None)) |
537 | | - st.session_state.prompt_history.append(get_prompt(step_info.agent_info)) |
538 | 563 | if isinstance(st.session_state.agent, GenericAgent): |
539 | 564 | st.session_state.memory_history.append(step_info.agent_info.get("memory", None)) |
| 565 | + st.session_state.prompt_history.append(get_prompt(step_info.agent_info)) |
| 566 | + elif isinstance(st.session_state.agent, ToolUseAgent): |
| 567 | + st.session_state.prompt_history.append( |
| 568 | + "\n".join( |
| 569 | + [elem.to_markdown() for elem in st.session_state.agent.discussion.flatten()] |
| 570 | + ) |
| 571 | + ) |
| 572 | + else: |
| 573 | + raise ValueError(f"Unknown agent type: {type(st.session_state.agent)}") |
540 | 574 | st.session_state.obs_history.append(step_info.obs) |
541 | 575 | st.session_state.reward_history.append(step_info.reward) |
542 | 576 | st.session_state.terminated_history.append(step_info.terminated) |
@@ -573,7 +607,8 @@ def clean_session(): |
573 | 607 | def prepare_agent(): |
574 | 608 | st.session_state.agent_args.prepare() |
575 | 609 | st.session_state.agent = st.session_state.agent_args.make_agent() |
576 | | - st.session_state.agent.set_task_name(st.session_state.task) |
| 610 | + if isinstance(st.session_state.agent, ToolUseAgent): |
| 611 | + st.session_state.agent.set_task_name(st.session_state.task) |
577 | 612 |
|
578 | 613 |
|
579 | 614 | def set_environment_info(): |
@@ -863,9 +898,9 @@ def set_prompt_modifier(): |
863 | 898 | st.session_state.agent.config.obs.use_tabs = st.checkbox( |
864 | 899 | "Use tabs", value=st.session_state.agent.config.obs.use_tabs |
865 | 900 | ) |
866 | | - st.session_state.agent.config.obs.add_mouse_pointer = st.checkbox( |
867 | | - "Add mouse pointer", value=st.session_state.agent.config.obs.add_mouse_pointer |
868 | | - ) |
| 901 | + # st.session_state.agent.config.obs.add_mouse_pointer = st.checkbox( |
| 902 | + # "Add mouse pointer", value=st.session_state.agent.config.obs.add_mouse_pointer |
| 903 | + # ) |
869 | 904 | st.session_state.agent.config.obs.use_zoomed_webpage = st.checkbox( |
870 | 905 | "Use zoomed webpage", value=st.session_state.agent.config.obs.use_zoomed_webpage |
871 | 906 | ) |
@@ -1107,7 +1142,14 @@ def set_axtree_tab(): |
1107 | 1142 |
|
1108 | 1143 |
|
1109 | 1144 | def set_prompt_tab(): |
1110 | | - st.code(st.session_state.prompt_history[-1], language=None, wrap_lines=True) |
| 1145 | + if isinstance(st.session_state.agent, GenericAgent): |
| 1146 | + st.code(st.session_state.prompt_history[-1], language=None, wrap_lines=True) |
| 1147 | + elif isinstance(st.session_state.agent, ToolUseAgent): |
| 1148 | + st.markdown(st.session_state.prompt_history[-1]) |
| 1149 | + |
| 1150 | + st.markdown(f"## Last summary:\n{st.session_state.agent.discussion.get_last_summary()}") |
| 1151 | + else: |
| 1152 | + raise ValueError(f"Unknown agent type: {type(st.session_state.agent)}") |
1111 | 1153 |
|
1112 | 1154 |
|
1113 | 1155 | def set_previous_steps_tab(): |
|
0 commit comments