196 changes: 70 additions & 126 deletions agents_mcp_usage/multi_mcp/eval_multi_mcp/evals_pydantic_mcp.py
@@ -1,6 +1,5 @@
import asyncio
import os
import subprocess
from typing import Any

import logfire
@@ -14,8 +13,11 @@

from agents_mcp_usage.multi_mcp.mermaid_diagrams import (
invalid_mermaid_diagram_easy,
invalid_mermaid_diagram_medium,
invalid_mermaid_diagram_hard,
valid_mermaid_diagram,
)
from mcp_servers.mermaid_validator import validate_mermaid_diagram

load_dotenv()

@@ -27,8 +29,8 @@
logfire.instrument_pydantic_ai()

# Default model to use
DEFAULT_MODEL = "gemini-2.5-pro-preview-03-25"
# DEFAULT_MODEL = "openai:o4-mini"
DEFAULT_MODEL = "gemini-2.5-pro-preview-05-06"

# Configure MCP servers
local_server = MCPServerStdio(
command="uv",
@@ -39,53 +41,23 @@
],
)
mermaid_server = MCPServerStdio(
command="npx",
command="uv",
args=[
"-y",
"@rtuin/mcp-mermaid-validator@latest",
"run",
"mcp_servers/mermaid_validator.py",
],
)
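# Both MCP servers run as stdio subprocesses launched with `uv run`, so the
# agent talks to them over stdin/stdout rather than a network transport.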


# Create Agent with MCP servers
def create_agent(model: str = DEFAULT_MODEL):
def create_agent(model: str = DEFAULT_MODEL, model_settings: dict[str, Any] | None = None):
return Agent(
model,
mcp_servers=[local_server, mermaid_server],
model_settings=model_settings,
)
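# Example usage (hypothetical model and settings):
#   create_agent("openai:o4-mini", model_settings={"temperature": 0.0})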


agent = create_agent()
Agent.instrument_all()


async def main(
query: str = "Hi!", request_limit: int = 5, model: str = DEFAULT_MODEL
) -> Any:
"""
Main function to run the agent

Args:
query (str): The query to run the agent with
request_limit (int): The number of requests to make to the MCP servers
model (str): The model to use for the agent

Returns:
The result from the agent's execution
"""
# Create a fresh agent with the specified model
current_agent = create_agent(model)

# Set a request limit for LLM calls
usage_limits = UsageLimits(request_limit=request_limit)

# Invoke the agent with the usage limits
async with current_agent.run_mcp_servers():
result = await current_agent.run(query, usage_limits=usage_limits)

return result


# Define input and output schema for evaluations
class MermaidInput(BaseModel):
invalid_diagram: str
@@ -110,86 +82,35 @@ class MermaidDiagramValid(Evaluator[MermaidInput, MermaidOutput]):
async def evaluate(
self, ctx: EvaluatorContext[MermaidInput, MermaidOutput]
) -> float:
diagram = ctx.output.fixed_diagram

# Extract mermaid code from markdown code block if present
mermaid_code = diagram
if "```mermaid" in diagram and "```" in diagram:
start_idx = diagram.find("```mermaid") + len("```mermaid")
end_idx = diagram.rfind("```")
mermaid_code = diagram[start_idx:end_idx].strip()

# Validate using mmdc
is_valid, _ = self.validate_mermaid_string_via_mmdc(mermaid_code)
return 1.0 if is_valid else 0.0

def validate_mermaid_string_via_mmdc(
self, mermaid_code: str, mmdc_path: str = "mmdc"
) -> tuple[bool, str]:
"""
Validates a Mermaid string by attempting to compile it using the
Mermaid CLI (mmdc). Requires mmdc to be installed and in PATH,
or mmdc_path to be explicitly provided.

Args:
mermaid_code: The string containing the Mermaid diagram syntax.
mmdc_path: The command or path to the mmdc executable.

Returns:
A tuple (is_valid: bool, message: str).
'message' will contain stderr output if not valid, or a success message.
"""
# Define temporary file names
temp_mmd_file = "temp_mermaid_for_validation.mmd"
# mmdc requires an output file, even if we don't use its content for validation.
temp_output_file = "temp_mermaid_output.svg"

# Write the mermaid code to a temporary file
with open(temp_mmd_file, "w", encoding="utf-8") as f:
f.write(mermaid_code)

try:
# Construct the command to run mmdc
command = [mmdc_path, "-i", temp_mmd_file, "-o", temp_output_file]

# Execute the mmdc command
process = subprocess.run(
command,
capture_output=True, # Capture stdout and stderr
text=True, # Decode output as text
check=False, # Do not raise an exception for non-zero exit codes
encoding="utf-8",
# Strip whitespace, remove backticks and ```mermaid markers
input_str = ctx.output.fixed_diagram.strip()

# Remove ```mermaid and ``` markers
if input_str.startswith("```mermaid"):
input_str = input_str[len("```mermaid") :].strip()
if input_str.endswith("```"):
input_str = input_str[:-3].strip()

# Remove any remaining backticks
input_str = input_str.replace("`", "")

logfire.info(
"Evaluating mermaid diagram validity",
diagram_length=len(input_str),
diagram_preview=input_str[:100],
)

# Use the MCP server's validation function
result = await validate_mermaid_diagram(input_str)

if result.is_valid:
logfire.info("Mermaid diagram validation succeeded")
else:
logfire.warning(
"Mermaid diagram validation failed", error_message=result.error_message
)

if process.returncode == 0:
return True, "Syntax appears valid (compiled successfully by mmdc)."
else:
# mmdc usually prints errors to stderr.
error_message = process.stderr.strip()
# Sometimes, syntax errors might also appear in stdout for certain mmdc versions or error types
if not error_message and process.stdout.strip():
error_message = process.stdout.strip()
return (
False,
f"Invalid syntax or mmdc error (exit code {process.returncode}):\n{error_message}",
)
except FileNotFoundError:
return False, (
f"Validation failed: '{mmdc_path}' command not found. "
"Please ensure Mermaid CLI (mmdc) is installed and in your system's PATH, "
"or provide the full path to the executable."
)
except Exception as e:
return (
False,
f"Validation failed due to an unexpected error during mmdc execution: {e}",
)
finally:
# Clean up the temporary files
if os.path.exists(temp_mmd_file):
os.remove(temp_mmd_file)
if os.path.exists(temp_output_file):
os.remove(temp_output_file)
return 1.0 if result.is_valid else 0.0


async def fix_mermaid_diagram(
@@ -206,9 +127,15 @@ async def fix_mermaid_diagram(
"""
query = f"Add the current time and fix the mermaid diagram syntax using the validator: {inputs.invalid_diagram}. Return only the fixed mermaid diagram between backticks."

result = await main(query, model=model)
# Create a fresh agent for each invocation to avoid concurrent usage issues
current_agent = create_agent(model)
usage_limits = UsageLimits(request_limit=5)
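# Capping requests keeps a non-converging tool-call loop from consuming
# unbounded LLM quota during an eval run.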

# Extract the mermaid diagram from the output
# Use the agent's context manager directly in this function
async with current_agent.run_mcp_servers():
result = await current_agent.run(query, usage_limits=usage_limits)

# Extract the mermaid diagram from the result output
output = result.output

# Logic to extract the diagram from between backticks
@@ -232,12 +159,25 @@ def create_evaluation_dataset(judge_model: str = DEFAULT_MODEL):
The evaluation dataset
"""
return Dataset[MermaidInput, MermaidOutput, Any](
# Construct three test cases, each asking the LLM to fix an invalid mermaid diagram of increasing difficulty
cases=[
Case(
name="fix_invalid_diagram_1",
name="fix_invalid_diagram_easy",
inputs=MermaidInput(invalid_diagram=invalid_mermaid_diagram_easy),
expected_output=MermaidOutput(fixed_diagram=valid_mermaid_diagram),
metadata={"test_type": "mermaid_easy_fix", "iteration": 1},
metadata={"test_type": "mermaid_easy_fix"},
),
Case(
name="fix_invalid_diagram_medium",
inputs=MermaidInput(invalid_diagram=invalid_mermaid_diagram_medium),
expected_output=MermaidOutput(fixed_diagram=valid_mermaid_diagram),
metadata={"test_type": "mermaid_medium_fix"},
),
Case(
name="fix_invalid_diagram_hard",
inputs=MermaidInput(invalid_diagram=invalid_mermaid_diagram_hard),
expected_output=MermaidOutput(fixed_diagram=valid_mermaid_diagram),
metadata={"test_type": "mermaid_hard_fix"},
),
],
evaluators=[
@@ -249,9 +189,9 @@ def create_evaluation_dataset(judge_model: str = DEFAULT_MODEL):
model=judge_model,
),
LLMJudge(
rubric="The fixed diagram should maintain the same overall structure and intent as the expected output diagram while fixing any syntax errors."
rubric="The output diagram should maintain the same overall structure and intent as the expected output diagram while fixing any syntax errors."
+ "Check if nodes, connections, and labels are preserved."
+ "The current time should be placeholder should be replace with a datetime",
+ "The current time should be placeholder should be replace with a valid datetime",
include_input=False,
model=judge_model,
),
@@ -276,20 +216,24 @@ async def fix_with_model(inputs: MermaidInput) -> MermaidOutput:
return await fix_mermaid_diagram(inputs, model=model)

report = await dataset.evaluate(
fix_with_model, name=f"{model}-multi-mcp-mermaid-diagram-fix-evals"
fix_with_model,
name=f"{model}-multi-mcp-mermaid-diagram-fix-evals",
max_concurrency=1, # Run one evaluation at a time
)

report.print(include_input=True, include_output=True)
report.print(include_input=False, include_output=False)
return report


if __name__ == "__main__":
# You can use different models for the agent and the judge
agent_model = os.getenv("AGENT_MODEL", DEFAULT_MODEL)
# agent_model = os.getenv("AGENT_MODEL", DEFAULT_MODEL)
agent_model = "gemini-2.5-flash-preview-04-17"
# agent_model = "openai:o4-mini"
# agent_model = "gemini-2.5-flash-preview-04-17"
judge_model = os.getenv("JUDGE_MODEL", DEFAULT_MODEL)

async def run_all():
# Run evaluations
await run_evaluations(model=agent_model, judge_model=judge_model)

asyncio.run(run_all())
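
Below is a minimal sketch of driving these evals from another entry point. It assumes the module path shown in this diff and the `model` / `judge_model` signature that `run_evaluations` exposes above; the model names are just the ones that appear elsewhere in this PR:

```python
import asyncio

from agents_mcp_usage.multi_mcp.eval_multi_mcp.evals_pydantic_mcp import run_evaluations


async def main() -> None:
    # Pair a cheaper agent model with a stronger judge model.
    await run_evaluations(
        model="gemini-2.5-flash-preview-04-17",
        judge_model="gemini-2.5-pro-preview-05-06",
    )


asyncio.run(main())
```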
68 changes: 66 additions & 2 deletions agents_mcp_usage/multi_mcp/mermaid_diagrams.py
@@ -63,7 +63,7 @@
```
"""

invalid_mermaid_diagram_easy = """
invalid_mermaid_diagram_medium = """
```mermaid
graph LR
User((User)) --> |"Run script<br>(e.g., pydantic_mcp.py)"| Agent
@@ -127,7 +127,71 @@
```
"""

valid_mermaid_diagram = """
invalid_mermaid_diagram_easy = """
```mermaid
graph LR
User((User)) --> |"Run script<br>(e.g., pydantic_mcp.py)"| Agent

%% Agent Frameworks
subgraph "Agent Frameworks"
direction TB
Agent[Agent]
ADK["Google ADK<br>(adk_mcp.py)"]
LG["LangGraph<br>(langgraph_mcp.py)"]
OAI["OpenAI Agents<br>(oai-agent_mcp.py)"]
PYD["Pydantic-AI<br>(pydantic_mcp.py)"]

Agent --> ADK
Agent --> LG
Agent --> OAI
Agent --> PYD
end

%% MCP Server
subgraph "MCP Server"
direction TB
MCP["Model Context Protocol Server<br>(run_server.py)"]
Tools["Tools<br>- add(a, b)<br>- get_current_time() e.g. {current_time}"]
Resources["Resources<br>- greeting://{{name}}"]
MCPs --- Tools
MCPs --- Resources
end

subgraph "LLM Providers"
OAI_LLM["OpenAI Models"]
GEM["Google Gemini Models"]
OTHER["Other LLM Providers..."]
end

Logfire[("Logfire<br>Tracing")]

ADK --> MCP
LG --> MCP
OAI --> MCP
PYD --> MCP

MCP --> OAI_LLM
MCP --> GEM
MCP --> OTHER

ADK --> Logfire
LG --> Logfire
OAI --> Logfire
PYD --> Logfire

LLM_Response[("Response")] --> User
OAI_LLM --> LLM_Response
GEM --> LLM_Response
OTHER --> LLM_Response

style MCP fill:#f9f,stroke:#333,stroke-width:2px
style User fill:#bbf,stroke:#338,stroke-width:2px
style Logfire fill:#bfb,stroke:#383,stroke-width:2px
style LLM_Response fill:#fbb,stroke:#833,stroke-width:2px
```
"""

valid_mermaid_diagram = """`
```mermaid
graph LR
User((User)) --> |"Run script<br>(e.g., pydantic_mcp.py)"| Agent