77from logging import warning
88from pathlib import Path
99
10+ from agentlab .llm .chat_api import (
11+ AzureChatModel ,
12+ OpenAIChatModel ,
13+ OpenRouterChatModel ,
14+ make_system_message ,
15+ make_user_message ,
16+ )
17+
1018import gradio as gr
1119import matplotlib .patches as patches
1220import matplotlib .pyplot as plt
1523from attr import dataclass
1624from browsergym .experiments .loop import StepInfo as BGymStepInfo
1725from langchain .schema import BaseMessage , HumanMessage
18- from openai import OpenAI
26+ from openai import AzureOpenAI
1927from openai .types .responses import ResponseFunctionToolCall
2028from PIL import Image
2129
@@ -399,14 +407,27 @@ def run_gradio(results_dir: Path):
399407 interactive = True ,
400408 elem_id = "prompt_tests_textbox" ,
401409 )
402- submit_button = gr .Button (value = "Submit" )
410+ with gr .Row ():
411+ num_queries_input = gr .Number (
412+ value = 3 ,
413+ label = "Number of model queries" ,
414+ minimum = 1 ,
415+ maximum = 10 ,
416+ step = 1 ,
417+ precision = 0 ,
418+ interactive = True ,
419+ )
420+ submit_button = gr .Button (value = "Submit" )
403421 result_box = gr .Textbox (
404- value = "" , label = "Result" , show_label = True , interactive = False
422+ value = "" , label = "Result" , show_label = True , interactive = False , max_lines = 20
405423 )
424+ with gr .Row ():
425+ # Add plot component for action distribution graph
426+ action_plot = gr .Plot (label = "Action Distribution" , show_label = True )
406427
407428 # Define the interaction
408429 submit_button .click (
409- fn = submit_action , inputs = prompt_tests_textbox , outputs = result_box
430+ fn = submit_action , inputs = [ prompt_tests_textbox , num_queries_input ], outputs = [ result_box , action_plot ]
410431 )
411432
412433 # Handle Events #
@@ -843,9 +864,11 @@ def _page_to_iframe(page: str):
843864 return page
844865
845866
846- def submit_action (input_text ):
867+ def submit_action (input_text , num_queries = 3 ):
847868 global info
848869 agent_info = info .exp_result .steps_info [info .step ].agent_info
870+ # Get the current step's action string for comparison
871+ step_action_str = info .exp_result .steps_info [info .step ].action
849872 chat_messages = deepcopy (agent_info .get ("chat_messages" , ["No Chat Messages" ])[:2 ])
850873 if isinstance (chat_messages [1 ], BaseMessage ): # TODO remove once langchain is deprecated
851874 assert isinstance (chat_messages [1 ], HumanMessage ), "Second message should be user"
@@ -858,14 +881,102 @@ def submit_action(input_text):
858881 else :
859882 raise ValueError ("Chat messages should be a list of BaseMessage or dict" )
860883
861- client = OpenAI ( )
884+ client = AzureChatModel ( model_name = "gpt-35-turbo" , deployment_name = "gpt-35-turbo" )
862885 chat_messages [1 ]["content" ] = input_text
863- completion = client .chat .completions .create (
864- model = "gpt-4o-mini" ,
865- messages = chat_messages ,
866- )
867- result_text = completion .choices [0 ].message .content
868- return result_text
886+
887+ # Query the model N times
888+ answers = []
889+ actions = []
890+ import re
891+
892+ for _ in range (num_queries ):
893+ answer = client (chat_messages )
894+ content = answer .get ("content" , "" )
895+ answers .append (content )
896+
897+ # Extract action part using regex
898+ action_match = re .search (r'<action>(.*?)</action>' , content , re .DOTALL )
899+ if action_match :
900+ actions .append (action_match .group (1 ).strip ())
901+
902+ # Prepare the aggregate result
903+ result = ""
904+
905+ # Include full responses first
906+ result += "\n \n ===== FULL MODEL RESPONSES =====\n \n "
907+ result += "\n \n ===== MODEL RESPONSE SEPARATION =====\n \n " .join (answers )
908+
909+ # Then add aggregated actions
910+ result += "\n \n ===== EXTRACTED ACTIONS =====\n \n "
911+
912+ # Create plot for action distribution
913+ import matplotlib .pyplot as plt
914+ import numpy as np
915+ from collections import Counter
916+
917+ # Create a figure for the action distribution
918+ fig = plt .figure (figsize = (10 , 6 ))
919+
920+ if actions :
921+ # Count unique actions
922+ action_counts = Counter (actions )
923+
924+ # Get actions in most_common order to ensure consistency between plot and text output
925+ most_common_actions = action_counts .most_common ()
926+
927+ # Prepare data for plotting (using most_common order)
928+ labels = [f"Action { i + 1 } " for i in range (len (most_common_actions ))]
929+ values = [count for _ , count in most_common_actions ]
930+ percentages = [(count / len (actions )) * 100 for count in values ]
931+
932+ # Create bar chart
933+ plt .bar (labels , percentages , color = 'skyblue' )
934+ plt .xlabel ('Actions' )
935+ plt .ylabel ('Percentage (%)' )
936+ plt .title (f'Action Distribution (from { num_queries } model queries)' )
937+ plt .ylim (0 , 100 ) # Set y-axis from 0 to 100%
938+
939+ # Add percentage labels on top of each bar
940+ for i , v in enumerate (percentages ):
941+ plt .text (i , v + 2 , f"{ v :.1f} %" , ha = 'center' )
942+
943+ # Add total counts as text annotation
944+ plt .figtext (0.5 , 0.01 ,
945+ f"Total actions extracted: { len (actions )} | Unique actions: { len (action_counts )} " ,
946+ ha = "center" , fontsize = 10 , bbox = {"facecolor" :"white" , "alpha" :0.5 , "pad" :5 })
947+
948+ # Display unique actions and their counts in text result
949+ for i , (action , count ) in enumerate (action_counts .most_common ()):
950+ percentage = (count / len (actions )) * 100
951+
952+ # Check if this action matches the current step's action
953+ matches_current_action = step_action_str and action .strip () == step_action_str .strip ()
954+
955+ # Highlight conditions:
956+ # 1. If it's the most common action (i==0)
957+ # 2. If it matches the current step's action
958+ if i == 0 and matches_current_action :
959+ result += f"** Predicted Action { i + 1 } (occurred { count } /{ len (actions )} times - { percentage :.1f} %) [MATCHES CURRENT ACTION]**:\n **{ action } **\n \n "
960+ elif i == 0 : # Just the most common
961+ result += f"** Predicted Action { i + 1 } (occurred { count } /{ len (actions )} times - { percentage :.1f} %)**:\n **{ action } **\n \n "
962+ elif matches_current_action : # Matches current action but not most common
963+ result += f"** Action { i + 1 } (occurred { count } /{ len (actions )} times - { percentage :.1f} %) [MATCHES CURRENT ACTION]**:\n **{ action } **\n \n "
964+ else : # Regular action
965+ result += f"Action { i + 1 } (occurred { count } /{ len (actions )} times - { percentage :.1f} %):\n { action } \n \n "
966+ else :
967+ result += "No actions found in any of the model responses.\n \n "
968+
969+ # Create empty plot with message
970+ plt .text (0.5 , 0.5 , "No actions found in model responses" ,
971+ ha = 'center' , va = 'center' , fontsize = 14 )
972+ plt .axis ('off' ) # Hide axes
973+
974+ plt .tight_layout ()
975+
976+ # Return both the text result and the figure
977+ return result , fig
978+
979+
869980
870981
871982def update_prompt_tests ():
0 commit comments