Skip to content

Commit 7c9dc4c

Browse files
Added hint tuning tool to evaluate best actions at a given step
1 parent 0b1eb18 commit 7c9dc4c

File tree

1 file changed

+123
-12
lines changed

1 file changed

+123
-12
lines changed

src/agentlab/analyze/agent_xray.py

Lines changed: 123 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
from logging import warning
88
from pathlib import Path
99

10+
from agentlab.llm.chat_api import (
11+
AzureChatModel,
12+
OpenAIChatModel,
13+
OpenRouterChatModel,
14+
make_system_message,
15+
make_user_message,
16+
)
17+
1018
import gradio as gr
1119
import matplotlib.patches as patches
1220
import matplotlib.pyplot as plt
@@ -15,7 +23,7 @@
1523
from attr import dataclass
1624
from browsergym.experiments.loop import StepInfo as BGymStepInfo
1725
from langchain.schema import BaseMessage, HumanMessage
18-
from openai import OpenAI
26+
from openai import AzureOpenAI
1927
from openai.types.responses import ResponseFunctionToolCall
2028
from PIL import Image
2129

@@ -399,14 +407,27 @@ def run_gradio(results_dir: Path):
399407
interactive=True,
400408
elem_id="prompt_tests_textbox",
401409
)
402-
submit_button = gr.Button(value="Submit")
410+
with gr.Row():
411+
num_queries_input = gr.Number(
412+
value=3,
413+
label="Number of model queries",
414+
minimum=1,
415+
maximum=10,
416+
step=1,
417+
precision=0,
418+
interactive=True,
419+
)
420+
submit_button = gr.Button(value="Submit")
403421
result_box = gr.Textbox(
404-
value="", label="Result", show_label=True, interactive=False
422+
value="", label="Result", show_label=True, interactive=False, max_lines=20
405423
)
424+
with gr.Row():
425+
# Add plot component for action distribution graph
426+
action_plot = gr.Plot(label="Action Distribution", show_label=True)
406427

407428
# Define the interaction
408429
submit_button.click(
409-
fn=submit_action, inputs=prompt_tests_textbox, outputs=result_box
430+
fn=submit_action, inputs=[prompt_tests_textbox, num_queries_input], outputs=[result_box, action_plot]
410431
)
411432

412433
# Handle Events #
@@ -843,9 +864,11 @@ def _page_to_iframe(page: str):
843864
return page
844865

845866

846-
def submit_action(input_text):
867+
def submit_action(input_text, num_queries=3):
847868
global info
848869
agent_info = info.exp_result.steps_info[info.step].agent_info
870+
# Get the current step's action string for comparison
871+
step_action_str = info.exp_result.steps_info[info.step].action
849872
chat_messages = deepcopy(agent_info.get("chat_messages", ["No Chat Messages"])[:2])
850873
if isinstance(chat_messages[1], BaseMessage): # TODO remove once langchain is deprecated
851874
assert isinstance(chat_messages[1], HumanMessage), "Second message should be user"
@@ -858,14 +881,102 @@ def submit_action(input_text):
858881
else:
859882
raise ValueError("Chat messages should be a list of BaseMessage or dict")
860883

861-
client = OpenAI()
884+
client = AzureChatModel(model_name="gpt-35-turbo", deployment_name="gpt-35-turbo")
862885
chat_messages[1]["content"] = input_text
863-
completion = client.chat.completions.create(
864-
model="gpt-4o-mini",
865-
messages=chat_messages,
866-
)
867-
result_text = completion.choices[0].message.content
868-
return result_text
886+
887+
# Query the model N times
888+
answers = []
889+
actions = []
890+
import re
891+
892+
for _ in range(num_queries):
893+
answer = client(chat_messages)
894+
content = answer.get("content", "")
895+
answers.append(content)
896+
897+
# Extract action part using regex
898+
action_match = re.search(r'<action>(.*?)</action>', content, re.DOTALL)
899+
if action_match:
900+
actions.append(action_match.group(1).strip())
901+
902+
# Prepare the aggregate result
903+
result = ""
904+
905+
# Include full responses first
906+
result += "\n\n===== FULL MODEL RESPONSES =====\n\n"
907+
result += "\n\n===== MODEL RESPONSE SEPARATION =====\n\n".join(answers)
908+
909+
# Then add aggregated actions
910+
result += "\n\n===== EXTRACTED ACTIONS =====\n\n"
911+
912+
# Create plot for action distribution
913+
import matplotlib.pyplot as plt
914+
import numpy as np
915+
from collections import Counter
916+
917+
# Create a figure for the action distribution
918+
fig = plt.figure(figsize=(10, 6))
919+
920+
if actions:
921+
# Count unique actions
922+
action_counts = Counter(actions)
923+
924+
# Get actions in most_common order to ensure consistency between plot and text output
925+
most_common_actions = action_counts.most_common()
926+
927+
# Prepare data for plotting (using most_common order)
928+
labels = [f"Action {i+1}" for i in range(len(most_common_actions))]
929+
values = [count for _, count in most_common_actions]
930+
percentages = [(count / len(actions)) * 100 for count in values]
931+
932+
# Create bar chart
933+
plt.bar(labels, percentages, color='skyblue')
934+
plt.xlabel('Actions')
935+
plt.ylabel('Percentage (%)')
936+
plt.title(f'Action Distribution (from {num_queries} model queries)')
937+
plt.ylim(0, 100) # Set y-axis from 0 to 100%
938+
939+
# Add percentage labels on top of each bar
940+
for i, v in enumerate(percentages):
941+
plt.text(i, v + 2, f"{v:.1f}%", ha='center')
942+
943+
# Add total counts as text annotation
944+
plt.figtext(0.5, 0.01,
945+
f"Total actions extracted: {len(actions)} | Unique actions: {len(action_counts)}",
946+
ha="center", fontsize=10, bbox={"facecolor":"white", "alpha":0.5, "pad":5})
947+
948+
# Display unique actions and their counts in text result
949+
for i, (action, count) in enumerate(action_counts.most_common()):
950+
percentage = (count / len(actions)) * 100
951+
952+
# Check if this action matches the current step's action
953+
matches_current_action = step_action_str and action.strip() == step_action_str.strip()
954+
955+
# Highlight conditions:
956+
# 1. If it's the most common action (i==0)
957+
# 2. If it matches the current step's action
958+
if i == 0 and matches_current_action:
959+
result += f"** Predicted Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%) [MATCHES CURRENT ACTION]**:\n**{action}**\n\n"
960+
elif i == 0: # Just the most common
961+
result += f"** Predicted Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%)**:\n**{action}**\n\n"
962+
elif matches_current_action: # Matches current action but not most common
963+
result += f"** Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%) [MATCHES CURRENT ACTION]**:\n**{action}**\n\n"
964+
else: # Regular action
965+
result += f"Action {i+1} (occurred {count}/{len(actions)} times - {percentage:.1f}%):\n{action}\n\n"
966+
else:
967+
result += "No actions found in any of the model responses.\n\n"
968+
969+
# Create empty plot with message
970+
plt.text(0.5, 0.5, "No actions found in model responses",
971+
ha='center', va='center', fontsize=14)
972+
plt.axis('off') # Hide axes
973+
974+
plt.tight_layout()
975+
976+
# Return both the text result and the figure
977+
return result, fig
978+
979+
869980

870981

871982
def update_prompt_tests():

0 commit comments

Comments
 (0)