|
4 | 4 | import json |
5 | 5 | import logging |
6 | 6 | import os |
| 7 | +from collections import Counter |
7 | 8 | from datetime import datetime |
8 | 9 | from io import BytesIO |
9 | 10 |
|
@@ -576,6 +577,91 @@ def set_prompt_modifier(): |
576 | 577 | st.session_state.agent.flags.extra_instructions = extra_instructions |
577 | 578 |
|
578 | 579 |
|
| 580 | +def set_advanced_controller(): |
| 581 | + with st.expander("**Advanced**", expanded=False): |
| 582 | + col_go_back_to, col_reprompt_k, col_act_k = st.columns([1, 1, 1]) |
| 583 | + with col_go_back_to: |
| 584 | + with st.container(border=True): |
| 585 | + st.markdown("**Go Back to Step K**") |
| 586 | + col1, col2 = st.columns([1, 1], vertical_alignment="bottom") |
| 587 | + is_go_back_to_step_k_disabled = len(st.session_state.action_history) <= 1 |
| 588 | + with col1: |
| 589 | + step = st.number_input( |
| 590 | + "Step", |
| 591 | + value=1, |
| 592 | + min_value=1, |
| 593 | + max_value=len(st.session_state.action_history), |
| 594 | + disabled=is_go_back_to_step_k_disabled, |
| 595 | + ) |
| 596 | + with col2: |
| 597 | + if st.button( |
| 598 | + "Go Back", |
| 599 | + help="Go back to step K", |
| 600 | + use_container_width=True, |
| 601 | + disabled=is_go_back_to_step_k_disabled, |
| 602 | + ): |
| 603 | + reset_agent_state() |
| 604 | + restore_agent_history(step=step) |
| 605 | + restore_env_history(step=step) |
| 606 | + restore_environment() |
| 607 | + st.rerun() |
| 608 | + with col_reprompt_k: |
| 609 | + with st.container(border=True): |
| 610 | + st.markdown("**Reprompt K Times**") |
| 611 | + col1, col2 = st.columns([1, 1], vertical_alignment="bottom") |
| 612 | + with col1: |
| 613 | + k = st.number_input( |
| 614 | + "Number of Generations", |
| 615 | + value=5, |
| 616 | + min_value=1, |
| 617 | + max_value=25, |
| 618 | + ) |
| 619 | + with col2: |
| 620 | + has_clicked_reprompt = st.button( |
| 621 | + "Reprompt", |
| 622 | + help="Reprompt the agent K times to get a distribution of actions to take", |
| 623 | + use_container_width=True, |
| 624 | + ) |
| 625 | + if has_clicked_reprompt: |
| 626 | + reprompt_actions = [] |
| 627 | + with st.spinner(f"Reprompting {k} times"): |
| 628 | + for i in range(k): |
| 629 | + revert_agent_history() |
| 630 | + revert_agent_state() |
| 631 | + get_action() |
| 632 | + reprompt_actions.append(st.session_state.action) |
| 633 | + # show all unique actions found in reprompt actions along with their probability |
| 634 | + unique_actions_counter = Counter(reprompt_actions) |
| 635 | + unique_actions = sorted( |
| 636 | + unique_actions_counter.items(), key=lambda x: x[1], reverse=True |
| 637 | + ) |
| 638 | + st.markdown("**Unique Actions**") |
| 639 | + for action, count in unique_actions: |
| 640 | + selected_action = st.button(f"`{action}` ({count / k * 100:.2f}%)") |
| 641 | + if selected_action: |
| 642 | + step_environment(action) |
| 643 | + st.rerun() |
| 644 | + |
| 645 | + with col_act_k: |
| 646 | + with st.container(border=True): |
| 647 | + st.markdown("**Act K Times**") |
| 648 | + col1, col2 = st.columns([1, 1], vertical_alignment="bottom") |
| 649 | + with col1: |
| 650 | + k = st.number_input("Number of Steps", value=5, min_value=1, max_value=10) |
| 651 | + with col2: |
| 652 | + has_clicked_act = st.button( |
| 653 | + "Act", |
| 654 | + help="Let the agent autonomously perform actions for K steps", |
| 655 | + use_container_width=True, |
| 656 | + ) |
| 657 | + if has_clicked_act: |
| 658 | + with st.spinner(f"Acting {k} times"): |
| 659 | + for _ in range(k): |
| 660 | + get_action() |
| 661 | + step_environment(st.session_state.action) |
| 662 | + st.rerun() |
| 663 | + |
| 664 | + |
579 | 665 | def set_controller(): |
580 | 666 | set_agent_state_box() |
581 | 667 | set_prompt_modifier() |
@@ -604,6 +690,7 @@ def set_controller(): |
604 | 690 | if st.button("➡️ Next Step", use_container_width=True): |
605 | 691 | step_environment(st.session_state.action) |
606 | 692 | st.rerun() |
| 693 | + set_advanced_controller() |
607 | 694 |
|
608 | 695 |
|
609 | 696 | def get_base64_serialized_image(img_arr): |
|
0 commit comments