Skip to content

Commit 858ee33

Browse files
committed
feat: Multi-mcp Agent run and evals
1 parent 5637414 commit 858ee33

File tree

5 files changed

+676
-1
lines changed

5 files changed

+676
-1
lines changed

agents_mcp_usage/multi_mcp/__init__.py

Whitespace-only changes.
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
import asyncio
import os
import subprocess
import tempfile
from typing import Any

import logfire
from dotenv import load_dotenv
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio
from pydantic_ai.usage import UsageLimits
from pydantic_evals import Case, Dataset
from pydantic_evals.evaluators import Evaluator, EvaluatorContext, IsInstance, LLMJudge

from agents_mcp_usage.multi_mcp.mermaid_diagrams import (
    invalid_mermaid_diagram_easy,
    valid_mermaid_diagram,
)
19+
20+
# Load environment variables (API keys, LOGFIRE_TOKEN, ...) from a .env file.
load_dotenv()

# Configure logging to logfire if LOGFIRE_TOKEN is set in environment
logfire.configure(
    send_to_logfire="if-token-present", service_name="evals-pydantic-multi-mcp"
)
# Trace MCP traffic and pydantic-ai agent runs.
logfire.instrument_mcp()
logfire.instrument_pydantic_ai()

# Default model to use
DEFAULT_MODEL = "gemini-2.5-pro-preview-03-25"
# DEFAULT_MODEL = "openai:o4-mini"

# Configure MCP servers
# Local project MCP server, launched over stdio via `uv run run_server.py stdio`.
local_server = MCPServerStdio(
    command="uv",
    args=[
        "run",
        "run_server.py",
        "stdio",
    ],
)
# Third-party mermaid-validator MCP server, fetched and run with npx.
mermaid_server = MCPServerStdio(
    command="npx",
    args=[
        "-y",
        "@rtuin/mcp-mermaid-validator@latest",
    ],
)
48+
49+
50+
# Create Agent with MCP servers
def create_agent(model: str = DEFAULT_MODEL) -> Agent:
    """Build a pydantic-ai Agent wired to both MCP servers.

    Args:
        model: Model identifier passed straight through to Agent.

    Returns:
        An Agent configured with the local and mermaid MCP servers.
    """
    return Agent(
        model,
        mcp_servers=[local_server, mermaid_server],
    )
56+
57+
58+
# Module-level agent on the default model. NOTE(review): main() builds its own
# agent per call, so this instance appears intended for import/interactive use.
agent = create_agent()
# Enable logfire instrumentation for every Agent instance.
Agent.instrument_all()
60+
61+
62+
async def main(
    query: str = "Hi!", request_limit: int = 5, model: str = DEFAULT_MODEL
) -> Any:
    """Run the agent once against both MCP servers and return the result.

    Args:
        query: The prompt to send to the agent.
        request_limit: Maximum number of LLM requests allowed for this run.
        model: Model identifier used to build the agent.

    Returns:
        The result object from the agent's execution.
    """
    # Build a fresh agent so the caller-selected model is honoured.
    runner = create_agent(model)

    # Cap the number of LLM calls for this invocation.
    limits = UsageLimits(request_limit=request_limit)

    # The MCP servers must be running for the duration of the agent call.
    async with runner.run_mcp_servers():
        outcome = await runner.run(query, usage_limits=limits)

    return outcome
87+
88+
89+
# Define input and output schema for evaluations
class MermaidInput(BaseModel):
    # The broken mermaid diagram text the agent must repair.
    invalid_diagram: str
92+
93+
94+
class MermaidOutput(BaseModel):
    # The corrected diagram; may still be wrapped in ``` fences.
    fixed_diagram: str
96+
97+
98+
# Custom evaluator to check if both MCP tools were used
class UsedBothMCPTools(Evaluator[MermaidInput, MermaidOutput]):
    """Placeholder evaluator for verifying both MCP servers were exercised."""

    async def evaluate(
        self, ctx: EvaluatorContext[MermaidInput, MermaidOutput]
    ) -> float:
        """Return 1.0 when a non-empty fixed diagram was produced, else 0.0."""
        # In a real implementation, we would check logs to verify both servers were used
        # For now, we'll assume success if we get a valid diagram output
        return 1.0 if ctx.output and ctx.output.fixed_diagram else 0.0
106+
107+
108+
# Custom evaluator to check if the mermaid diagram is valid
class MermaidDiagramValid(Evaluator[MermaidInput, MermaidOutput]):
    """Scores 1.0 when the output diagram compiles under the mermaid CLI (mmdc)."""

    async def evaluate(
        self, ctx: EvaluatorContext[MermaidInput, MermaidOutput]
    ) -> float:
        """Return 1.0 if the fixed diagram is valid mermaid syntax, else 0.0."""
        # Guard: a missing/empty output scores 0 instead of raising AttributeError.
        if not ctx.output or not ctx.output.fixed_diagram:
            return 0.0
        diagram = ctx.output.fixed_diagram

        # Extract mermaid code from markdown code block if present.
        # (The presence of "```mermaid" implies a closing "```" search is safe.)
        mermaid_code = diagram
        if "```mermaid" in diagram:
            start_idx = diagram.find("```mermaid") + len("```mermaid")
            end_idx = diagram.rfind("```")
            mermaid_code = diagram[start_idx:end_idx].strip()

        # Validate using mmdc
        is_valid, _ = self.validate_mermaid_string_via_mmdc(mermaid_code)
        return 1.0 if is_valid else 0.0

    def validate_mermaid_string_via_mmdc(
        self, mermaid_code: str, mmdc_path: str = "mmdc"
    ) -> tuple[bool, str]:
        """
        Validates a Mermaid string by attempting to compile it using the
        Mermaid CLI (mmdc). Requires mmdc to be installed and in PATH,
        or mmdc_path to be explicitly provided.

        Args:
            mermaid_code: The string containing the Mermaid diagram syntax.
            mmdc_path: The command or path to the mmdc executable.

        Returns:
            A tuple (is_valid: bool, message: str).
            'message' will contain stderr output if not valid, or a success message.
        """
        # A unique temporary directory avoids filename collisions when several
        # evaluations run concurrently, and is removed automatically on every
        # exit path (the previous fixed CWD filenames were race-prone).
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_mmd_file = os.path.join(tmp_dir, "diagram.mmd")
            # mmdc requires an output file, even if we don't use its content for validation.
            temp_output_file = os.path.join(tmp_dir, "diagram.svg")

            # Write the mermaid code to a temporary file
            with open(temp_mmd_file, "w", encoding="utf-8") as f:
                f.write(mermaid_code)

            command = [mmdc_path, "-i", temp_mmd_file, "-o", temp_output_file]
            try:
                process = subprocess.run(
                    command,
                    capture_output=True,  # Capture stdout and stderr
                    text=True,  # Decode output as text
                    check=False,  # We inspect returncode ourselves
                    encoding="utf-8",
                )
            except FileNotFoundError:
                return False, (
                    f"Validation failed: '{mmdc_path}' command not found. "
                    "Please ensure Mermaid CLI (mmdc) is installed and in your system's PATH, "
                    "or provide the full path to the executable."
                )
            except Exception as e:
                return (
                    False,
                    f"Validation failed due to an unexpected error during mmdc execution: {e}",
                )

            if process.returncode == 0:
                return True, "Syntax appears valid (compiled successfully by mmdc)."

            # mmdc usually prints errors to stderr.
            # Sometimes, syntax errors might also appear in stdout for certain mmdc versions or error types
            error_message = process.stderr.strip() or process.stdout.strip()
            return (
                False,
                f"Invalid syntax or mmdc error (exit code {process.returncode}):\n{error_message}",
            )
193+
194+
195+
async def fix_mermaid_diagram(
    inputs: MermaidInput, model: str = DEFAULT_MODEL
) -> MermaidOutput:
    """Fix an invalid mermaid diagram using the agent with multiple MCP servers.

    Args:
        inputs: The input containing the invalid diagram
        model: The model to use for the agent

    Returns:
        MermaidOutput with the fixed diagram
    """
    query = f"Add the current time and fix the mermaid diagram syntax using the validator: {inputs.invalid_diagram}. Return only the fixed mermaid diagram between backticks."

    result = await main(query, model=model)
    raw_output = result.output

    # No code fences: pass the model output through unchanged.
    if "```" not in raw_output:
        return MermaidOutput(fixed_diagram=raw_output)

    # Keep the fenced span: first opening ``` through the final closing ```.
    fence_start = raw_output.find("```")
    fence_end = raw_output.rfind("```") + 3
    return MermaidOutput(fixed_diagram=raw_output[fence_start:fence_end])
223+
224+
225+
def create_evaluation_dataset(judge_model: str = DEFAULT_MODEL):
    """Create the dataset for evaluating mermaid diagram fixing.

    Args:
        judge_model: The model to use for LLM judging

    Returns:
        The evaluation dataset
    """
    return Dataset[MermaidInput, MermaidOutput, Any](
        cases=[
            Case(
                name="fix_invalid_diagram_1",
                inputs=MermaidInput(invalid_diagram=invalid_mermaid_diagram_easy),
                expected_output=MermaidOutput(fixed_diagram=valid_mermaid_diagram),
                metadata={"test_type": "mermaid_easy_fix", "iteration": 1},
            ),
        ],
        evaluators=[
            UsedBothMCPTools(),
            MermaidDiagramValid(),
            LLMJudge(
                rubric="The response only contains a mermaid diagram, no other text.",
                include_input=False,
                model=judge_model,
            ),
            LLMJudge(
                # Sentences end with an explicit trailing space so the implicit
                # concatenation reads correctly; the original rubric ran
                # sentences together and its last line was garbled English,
                # which degrades the judge prompt.
                rubric="The fixed diagram should maintain the same overall structure and intent as the expected output diagram while fixing any syntax errors. "
                "Check if nodes, connections, and labels are preserved. "
                "The current time placeholder should be replaced with a datetime.",
                include_input=False,
                model=judge_model,
            ),
        ],
    )
260+
261+
262+
async def run_evaluations(model: str = DEFAULT_MODEL, judge_model: str = DEFAULT_MODEL):
    """Run the evaluations on the mermaid diagram fixing task.

    Args:
        model: The model to use for the agent
        judge_model: The model to use for LLM judging

    Returns:
        The evaluation report
    """
    eval_dataset = create_evaluation_dataset(judge_model)

    # Bind the agent model into the single-argument task evaluate() expects.
    async def task(inputs: MermaidInput) -> MermaidOutput:
        return await fix_mermaid_diagram(inputs, model=model)

    report = await eval_dataset.evaluate(
        task, name=f"{model}-multi-mcp-mermaid-diagram-fix-evals"
    )
    report.print(include_input=True, include_output=True)
    return report
284+
285+
286+
if __name__ == "__main__":
    # You can use different models for the agent and the judge
    agent_model = os.getenv("AGENT_MODEL", DEFAULT_MODEL)
    judge_model = os.getenv("JUDGE_MODEL", DEFAULT_MODEL)

    # asyncio.run drives the evaluation coroutine directly; the intermediate
    # wrapper coroutine in the original added nothing.
    asyncio.run(run_evaluations(model=agent_model, judge_model=judge_model))

0 commit comments

Comments
 (0)