Skip to content

Commit 3eb1d03

Browse files
authored
Merge pull request #3 from andrewginns/add-levels-of-eval-difficulty
2 parents 07a5b70 + d07902e commit 3eb1d03

File tree

5 files changed

+417
-128
lines changed

5 files changed

+417
-128
lines changed
Lines changed: 70 additions & 126 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import os
3-
import subprocess
43
from typing import Any
54

65
import logfire
@@ -14,8 +13,11 @@
1413

1514
from agents_mcp_usage.multi_mcp.mermaid_diagrams import (
1615
invalid_mermaid_diagram_easy,
16+
invalid_mermaid_diagram_medium,
17+
invalid_mermaid_diagram_hard,
1718
valid_mermaid_diagram,
1819
)
20+
from mcp_servers.mermaid_validator import validate_mermaid_diagram
1921

2022
load_dotenv()
2123

@@ -27,8 +29,8 @@
2729
logfire.instrument_pydantic_ai()
2830

2931
# Default model to use
30-
DEFAULT_MODEL = "gemini-2.5-pro-preview-03-25"
31-
# DEFAULT_MODEL = "openai:o4-mini"
32+
DEFAULT_MODEL = "gemini-2.5-pro-preview-05-06"
33+
3234
# Configure MCP servers
3335
local_server = MCPServerStdio(
3436
command="uv",
@@ -39,53 +41,23 @@
3941
],
4042
)
4143
mermaid_server = MCPServerStdio(
42-
command="npx",
44+
command="uv",
4345
args=[
44-
"-y",
45-
"@rtuin/mcp-mermaid-validator@latest",
46+
"run",
47+
"mcp_servers/mermaid_validator.py",
4648
],
4749
)
4850

4951

5052
# Create Agent with MCP servers
51-
def create_agent(model: str = DEFAULT_MODEL):
53+
def create_agent(
    model: str = DEFAULT_MODEL,
    model_settings: "dict[str, Any] | None" = None,
):
    """Build an Agent wired to both MCP servers.

    Args:
        model: Model identifier passed as the first argument to ``Agent``.
        model_settings: Settings forwarded to ``Agent`` as ``model_settings``.
            When omitted, a fresh empty dict is created for this call.

    Returns:
        A new ``Agent`` configured with ``local_server`` and
        ``mermaid_server`` as its MCP servers.
    """
    # A literal ``{}`` default is evaluated once at definition time, so every
    # call (and every Agent built from it) would share the same dict object.
    # Allocate a new one per call instead. The annotation is quoted so the
    # ``|`` union is never evaluated on older interpreters.
    settings = {} if model_settings is None else model_settings
    return Agent(
        model,
        mcp_servers=[local_server, mermaid_server],
        model_settings=settings,
    )
5659

5760

58-
agent = create_agent()
59-
Agent.instrument_all()
60-
61-
62-
async def main(
63-
query: str = "Hi!", request_limit: int = 5, model: str = DEFAULT_MODEL
64-
) -> Any:
65-
"""
66-
Main function to run the agent
67-
68-
Args:
69-
query (str): The query to run the agent with
70-
request_limit (int): The number of requests to make to the MCP servers
71-
model (str): The model to use for the agent
72-
73-
Returns:
74-
The result from the agent's execution
75-
"""
76-
# Create a fresh agent with the specified model
77-
current_agent = create_agent(model)
78-
79-
# Set a request limit for LLM calls
80-
usage_limits = UsageLimits(request_limit=request_limit)
81-
82-
# Invoke the agent with the usage limits
83-
async with current_agent.run_mcp_servers():
84-
result = await current_agent.run(query, usage_limits=usage_limits)
85-
86-
return result
87-
88-
8961
# Define input and output schema for evaluations
9062
class MermaidInput(BaseModel):
9163
invalid_diagram: str
@@ -110,86 +82,35 @@ class MermaidDiagramValid(Evaluator[MermaidInput, MermaidOutput]):
11082
async def evaluate(
11183
self, ctx: EvaluatorContext[MermaidInput, MermaidOutput]
11284
) -> float:
113-
diagram = ctx.output.fixed_diagram
114-
115-
# Extract mermaid code from markdown code block if present
116-
mermaid_code = diagram
117-
if "```mermaid" in diagram and "```" in diagram:
118-
start_idx = diagram.find("```mermaid") + len("```mermaid")
119-
end_idx = diagram.rfind("```")
120-
mermaid_code = diagram[start_idx:end_idx].strip()
121-
122-
# Validate using mmdc
123-
is_valid, _ = self.validate_mermaid_string_via_mmdc(mermaid_code)
124-
return 1.0 if is_valid else 0.0
125-
126-
def validate_mermaid_string_via_mmdc(
127-
self, mermaid_code: str, mmdc_path: str = "mmdc"
128-
) -> tuple[bool, str]:
129-
"""
130-
Validates a Mermaid string by attempting to compile it using the
131-
Mermaid CLI (mmdc). Requires mmdc to be installed and in PATH,
132-
or mmdc_path to be explicitly provided.
133-
134-
Args:
135-
mermaid_code: The string containing the Mermaid diagram syntax.
136-
mmdc_path: The command or path to the mmdc executable.
137-
138-
Returns:
139-
A tuple (is_valid: bool, message: str).
140-
'message' will contain stderr output if not valid, or a success message.
141-
"""
142-
# Define temporary file names
143-
temp_mmd_file = "temp_mermaid_for_validation.mmd"
144-
# mmdc requires an output file, even if we don't use its content for validation.
145-
temp_output_file = "temp_mermaid_output.svg"
146-
147-
# Write the mermaid code to a temporary file
148-
with open(temp_mmd_file, "w", encoding="utf-8") as f:
149-
f.write(mermaid_code)
150-
151-
try:
152-
# Construct the command to run mmdc
153-
command = [mmdc_path, "-i", temp_mmd_file, "-o", temp_output_file]
154-
155-
# Execute the mmdc command
156-
process = subprocess.run(
157-
command,
158-
capture_output=True, # Capture stdout and stderr
159-
text=True, # Decode output as text
160-
check=False, # Do not raise an exception for non-zero exit codes
161-
encoding="utf-8",
85+
# Strip whitespace, remove backticks and ```mermaid markers
86+
input_str = ctx.output.fixed_diagram.strip()
87+
88+
# Remove ```mermaid and ``` markers
89+
if input_str.startswith("```mermaid"):
90+
input_str = input_str[len("```mermaid") :].strip()
91+
if input_str.endswith("```"):
92+
input_str = input_str[:-3].strip()
93+
94+
# Remove any remaining backticks
95+
input_str = input_str.replace("`", "")
96+
97+
logfire.info(
98+
"Evaluating mermaid diagram validity",
99+
diagram_length=len(input_str),
100+
diagram_preview=input_str[:100],
101+
)
102+
103+
# Use the MCP server's validation function
104+
result = await validate_mermaid_diagram(input_str)
105+
106+
if result.is_valid:
107+
logfire.info("Mermaid diagram validation succeeded")
108+
else:
109+
logfire.warning(
110+
"Mermaid diagram validation failed", error_message=result.error_message
162111
)
163112

164-
if process.returncode == 0:
165-
return True, "Syntax appears valid (compiled successfully by mmdc)."
166-
else:
167-
# mmdc usually prints errors to stderr.
168-
error_message = process.stderr.strip()
169-
# Sometimes, syntax errors might also appear in stdout for certain mmdc versions or error types
170-
if not error_message and process.stdout.strip():
171-
error_message = process.stdout.strip()
172-
return (
173-
False,
174-
f"Invalid syntax or mmdc error (exit code {process.returncode}):\n{error_message}",
175-
)
176-
except FileNotFoundError:
177-
return False, (
178-
f"Validation failed: '{mmdc_path}' command not found. "
179-
"Please ensure Mermaid CLI (mmdc) is installed and in your system's PATH, "
180-
"or provide the full path to the executable."
181-
)
182-
except Exception as e:
183-
return (
184-
False,
185-
f"Validation failed due to an unexpected error during mmdc execution: {e}",
186-
)
187-
finally:
188-
# Clean up the temporary files
189-
if os.path.exists(temp_mmd_file):
190-
os.remove(temp_mmd_file)
191-
if os.path.exists(temp_output_file):
192-
os.remove(temp_output_file)
113+
return 1.0 if result.is_valid else 0.0
193114

194115

195116
async def fix_mermaid_diagram(
@@ -206,9 +127,15 @@ async def fix_mermaid_diagram(
206127
"""
207128
query = f"Add the current time and fix the mermaid diagram syntax using the validator: {inputs.invalid_diagram}. Return only the fixed mermaid diagram between backticks."
208129

209-
result = await main(query, model=model)
130+
# Create a fresh agent for each invocation to avoid concurrent usage issues
131+
current_agent = create_agent(model)
132+
usage_limits = UsageLimits(request_limit=5)
210133

211-
# Extract the mermaid diagram from the output
134+
# Use the agent's context manager directly in this function
135+
async with current_agent.run_mcp_servers():
136+
result = await current_agent.run(query, usage_limits=usage_limits)
137+
138+
# Extract the mermaid diagram from the result output
212139
output = result.output
213140

214141
# Logic to extract the diagram from between backticks
@@ -232,12 +159,25 @@ def create_evaluation_dataset(judge_model: str = DEFAULT_MODEL):
232159
The evaluation dataset
233160
"""
234161
return Dataset[MermaidInput, MermaidOutput, Any](
162+
# Construct 3 tests, each asks the LLM to fix an invalid mermaid diagram of increasing difficulty
235163
cases=[
236164
Case(
237-
name="fix_invalid_diagram_1",
165+
name="fix_invalid_diagram_easy",
238166
inputs=MermaidInput(invalid_diagram=invalid_mermaid_diagram_easy),
239167
expected_output=MermaidOutput(fixed_diagram=valid_mermaid_diagram),
240-
metadata={"test_type": "mermaid_easy_fix", "iteration": 1},
168+
metadata={"test_type": "mermaid_easy_fix"},
169+
),
170+
Case(
171+
name="fix_invalid_diagram_medium",
172+
inputs=MermaidInput(invalid_diagram=invalid_mermaid_diagram_medium),
173+
expected_output=MermaidOutput(fixed_diagram=valid_mermaid_diagram),
174+
metadata={"test_type": "mermaid_medium_fix"},
175+
),
176+
Case(
177+
name="fix_invalid_diagram_hard",
178+
inputs=MermaidInput(invalid_diagram=invalid_mermaid_diagram_hard),
179+
expected_output=MermaidOutput(fixed_diagram=valid_mermaid_diagram),
180+
metadata={"test_type": "mermaid_hard_fix"},
241181
),
242182
],
243183
evaluators=[
@@ -249,9 +189,9 @@ def create_evaluation_dataset(judge_model: str = DEFAULT_MODEL):
249189
model=judge_model,
250190
),
251191
LLMJudge(
252-
rubric="The fixed diagram should maintain the same overall structure and intent as the expected output diagram while fixing any syntax errors."
192+
rubric="The output diagram should maintain the same overall structure and intent as the expected output diagram while fixing any syntax errors."
253193
+ "Check if nodes, connections, and labels are preserved."
254-
+ "The current time should be placeholder should be replace with a datetime",
194+
+ "The current time should be placeholder should be replace with a valid datetime",
255195
include_input=False,
256196
model=judge_model,
257197
),
@@ -276,20 +216,24 @@ async def fix_with_model(inputs: MermaidInput) -> MermaidOutput:
276216
return await fix_mermaid_diagram(inputs, model=model)
277217

278218
report = await dataset.evaluate(
279-
fix_with_model, name=f"{model}-multi-mcp-mermaid-diagram-fix-evals"
219+
fix_with_model,
220+
name=f"{model}-multi-mcp-mermaid-diagram-fix-evals",
221+
max_concurrency=1, # Run one evaluation at a time
280222
)
281223

282-
report.print(include_input=True, include_output=True)
224+
report.print(include_input=False, include_output=False)
283225
return report
284226

285227

286228
if __name__ == "__main__":
287229
# You can use different models for the agent and the judge
288-
agent_model = os.getenv("AGENT_MODEL", DEFAULT_MODEL)
230+
# agent_model = os.getenv("AGENT_MODEL", DEFAULT_MODEL)
231+
agent_model = "gemini-2.5-flash-preview-04-17"
232+
# agent_model = "openai:o4-mini"
233+
# agent_model = "gemini-2.5-flash-preview-04-17"
289234
judge_model = os.getenv("JUDGE_MODEL", DEFAULT_MODEL)
290235

291236
async def run_all():
292-
# Run evaluations
293237
await run_evaluations(model=agent_model, judge_model=judge_model)
294238

295239
asyncio.run(run_all())

agents_mcp_usage/multi_mcp/mermaid_diagrams.py

Lines changed: 66 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@
6363
```
6464
"""
6565

66-
invalid_mermaid_diagram_easy = """
66+
invalid_mermaid_diagram_medium = """
6767
```mermaid
6868
graph LR
6969
User((User)) --> |"Run script<br>(e.g., pydantic_mcp.py)"| Agent
@@ -127,7 +127,71 @@
127127
```
128128
"""
129129

130-
valid_mermaid_diagram = """
130+
invalid_mermaid_diagram_easy = """
131+
```mermaid
132+
graph LR
133+
User((User)) --> |"Run script<br>(e.g., pydantic_mcp.py)"| Agent
134+
135+
%% Agent Frameworks
136+
subgraph "Agent Frameworks"
137+
direction TB
138+
Agent[Agent]
139+
ADK["Google ADK<br>(adk_mcp.py)"]
140+
LG["LangGraph<br>(langgraph_mcp.py)"]
141+
OAI["OpenAI Agents<br>(oai-agent_mcp.py)"]
142+
PYD["Pydantic-AI<br>(pydantic_mcp.py)"]
143+
144+
Agent --> ADK
145+
Agent --> LG
146+
Agent --> OAI
147+
Agent --> PYD
148+
end
149+
150+
%% MCP Server
151+
subgraph "MCP Server"
152+
direction TB
153+
MCP["Model Context Protocol Server<br>(run_server.py)"]
154+
Tools["Tools<br>- add(a, b)<br>- get_current_time() e.g. {current_time}"]
155+
Resources["Resources<br>- greeting://{{name}}"]
156+
MCPs --- Tools
157+
MCPs --- Resources
158+
end
159+
160+
subgraph "LLM Providers"
161+
OAI_LLM["OpenAI Models"]
162+
GEM["Google Gemini Models"]
163+
OTHER["Other LLM Providers..."]
164+
end
165+
166+
Logfire[("Logfire<br>Tracing")]
167+
168+
ADK --> MCP
169+
LG --> MCP
170+
OAI --> MCP
171+
PYD --> MCP
172+
173+
MCP --> OAI_LLM
174+
MCP --> GEM
175+
MCP --> OTHER
176+
177+
ADK --> Logfire
178+
LG --> Logfire
179+
OAI --> Logfire
180+
PYD --> Logfire
181+
182+
LLM_Response[("Response")] --> User
183+
OAI_LLM --> LLM_Response
184+
GEM --> LLM_Response
185+
OTHER --> LLM_Response
186+
187+
style MCP fill:#f9f,stroke:#333,stroke-width:2px
188+
style User fill:#bbf,stroke:#338,stroke-width:2px
189+
style Logfire fill:#bfb,stroke:#383,stroke-width:2px
190+
style LLM_Response fill:#fbb,stroke:#833,stroke-width:2px
191+
```
192+
"""
193+
194+
valid_mermaid_diagram = """`
131195
```mermaid
132196
graph LR
133197
User((User)) --> |"Run script<br>(e.g., pydantic_mcp.py)"| Agent

0 commit comments

Comments
 (0)