1
1
import asyncio
2
2
import os
3
- import subprocess
4
3
from typing import Any
5
4
6
5
import logfire
14
13
15
14
from agents_mcp_usage .multi_mcp .mermaid_diagrams import (
16
15
invalid_mermaid_diagram_easy ,
16
+ invalid_mermaid_diagram_medium ,
17
+ invalid_mermaid_diagram_hard ,
17
18
valid_mermaid_diagram ,
18
19
)
20
+ from mcp_servers .mermaid_validator import validate_mermaid_diagram
19
21
20
22
load_dotenv ()
21
23
27
29
logfire .instrument_pydantic_ai ()
28
30
29
31
# Default model to use for both the agent and the LLM judge.
DEFAULT_MODEL = "gemini-2.5-pro-preview-05-06"
32
34
# Configure MCP servers
33
35
local_server = MCPServerStdio (
34
36
command = "uv" ,
39
41
],
40
42
)
41
43
# MCP server that validates mermaid diagrams, launched with `uv run`
# from the local mcp_servers package.
mermaid_server = MCPServerStdio(
    command="uv",
    args=["run", "mcp_servers/mermaid_validator.py"],
)
48
50
49
51
50
52
# Create Agent with MCP servers
def create_agent(
    model: str = DEFAULT_MODEL, model_settings: dict[str, Any] | None = None
):
    """Build an Agent wired to both MCP servers.

    Args:
        model: Model identifier to pass to the Agent.
        model_settings: Optional model settings forwarded to the Agent.
            Defaults to no settings (empty dict).

    Returns:
        A configured Agent instance.
    """
    # Use None instead of a mutable `{}` default: a shared dict default
    # would be the same object across every call and could be mutated.
    return Agent(
        model,
        mcp_servers=[local_server, mermaid_server],
        model_settings=model_settings or {},
    )
56
59
57
60
58
- agent = create_agent ()
59
- Agent .instrument_all ()
60
-
61
-
62
- async def main (
63
- query : str = "Hi!" , request_limit : int = 5 , model : str = DEFAULT_MODEL
64
- ) -> Any :
65
- """
66
- Main function to run the agent
67
-
68
- Args:
69
- query (str): The query to run the agent with
70
- request_limit (int): The number of requests to make to the MCP servers
71
- model (str): The model to use for the agent
72
-
73
- Returns:
74
- The result from the agent's execution
75
- """
76
- # Create a fresh agent with the specified model
77
- current_agent = create_agent (model )
78
-
79
- # Set a request limit for LLM calls
80
- usage_limits = UsageLimits (request_limit = request_limit )
81
-
82
- # Invoke the agent with the usage limits
83
- async with current_agent .run_mcp_servers ():
84
- result = await current_agent .run (query , usage_limits = usage_limits )
85
-
86
- return result
87
-
88
-
89
61
# Define input and output schema for evaluations
90
62
class MermaidInput (BaseModel ):
91
63
invalid_diagram : str
@@ -110,86 +82,35 @@ class MermaidDiagramValid(Evaluator[MermaidInput, MermaidOutput]):
110
82
async def evaluate(
    self, ctx: EvaluatorContext[MermaidInput, MermaidOutput]
) -> float:
    """Return 1.0 when the fixed diagram passes mermaid validation, else 0.0."""
    # Normalise the agent output: trim whitespace and strip any markdown
    # code-fence markers before handing it to the validator.
    candidate = ctx.output.fixed_diagram.strip()
    if candidate.startswith("```mermaid"):
        candidate = candidate[len("```mermaid"):].strip()
    if candidate.endswith("```"):
        candidate = candidate[: -len("```")].strip()
    # Drop any stray backticks that survived the fence stripping.
    candidate = candidate.replace("`", "")

    logfire.info(
        "Evaluating mermaid diagram validity",
        diagram_length=len(candidate),
        diagram_preview=candidate[:100],
    )

    # Delegate to the same validation function the MCP server exposes.
    result = await validate_mermaid_diagram(candidate)

    if result.is_valid:
        logfire.info("Mermaid diagram validation succeeded")
    else:
        logfire.warning(
            "Mermaid diagram validation failed", error_message=result.error_message
        )

    return 1.0 if result.is_valid else 0.0
193
114
194
115
195
116
async def fix_mermaid_diagram (
@@ -206,9 +127,15 @@ async def fix_mermaid_diagram(
206
127
"""
207
128
query = f"Add the current time and fix the mermaid diagram syntax using the validator: { inputs .invalid_diagram } . Return only the fixed mermaid diagram between backticks."
208
129
209
- result = await main (query , model = model )
130
+ # Create a fresh agent for each invocation to avoid concurrent usage issues
131
+ current_agent = create_agent (model )
132
+ usage_limits = UsageLimits (request_limit = 5 )
210
133
211
- # Extract the mermaid diagram from the output
134
+ # Use the agent's context manager directly in this function
135
+ async with current_agent .run_mcp_servers ():
136
+ result = await current_agent .run (query , usage_limits = usage_limits )
137
+
138
+ # Extract the mermaid diagram from the result output
212
139
output = result .output
213
140
214
141
# Logic to extract the diagram from between backticks
@@ -232,12 +159,25 @@ def create_evaluation_dataset(judge_model: str = DEFAULT_MODEL):
232
159
The evaluation dataset
233
160
"""
234
161
return Dataset [MermaidInput , MermaidOutput , Any ](
162
+ # Construct 3 tests, each asks the LLM to fix an invalid mermaid diagram of increasing difficulty
235
163
cases = [
236
164
Case (
237
- name = "fix_invalid_diagram_1 " ,
165
+ name = "fix_invalid_diagram_easy " ,
238
166
inputs = MermaidInput (invalid_diagram = invalid_mermaid_diagram_easy ),
239
167
expected_output = MermaidOutput (fixed_diagram = valid_mermaid_diagram ),
240
- metadata = {"test_type" : "mermaid_easy_fix" , "iteration" : 1 },
168
+ metadata = {"test_type" : "mermaid_easy_fix" },
169
+ ),
170
+ Case (
171
+ name = "fix_invalid_diagram_medium" ,
172
+ inputs = MermaidInput (invalid_diagram = invalid_mermaid_diagram_medium ),
173
+ expected_output = MermaidOutput (fixed_diagram = valid_mermaid_diagram ),
174
+ metadata = {"test_type" : "mermaid_medium_fix" },
175
+ ),
176
+ Case (
177
+ name = "fix_invalid_diagram_hard" ,
178
+ inputs = MermaidInput (invalid_diagram = invalid_mermaid_diagram_hard ),
179
+ expected_output = MermaidOutput (fixed_diagram = valid_mermaid_diagram ),
180
+ metadata = {"test_type" : "mermaid_hard_fix" },
241
181
),
242
182
],
243
183
evaluators = [
@@ -249,9 +189,9 @@ def create_evaluation_dataset(judge_model: str = DEFAULT_MODEL):
249
189
model = judge_model ,
250
190
),
251
191
LLMJudge (
252
- rubric = "The fixed diagram should maintain the same overall structure and intent as the expected output diagram while fixing any syntax errors."
192
+ rubric = "The output diagram should maintain the same overall structure and intent as the expected output diagram while fixing any syntax errors."
253
193
+ "Check if nodes, connections, and labels are preserved."
254
- + "The current time should be placeholder should be replace with a datetime" ,
194
+ + "The current time should be placeholder should be replace with a valid datetime" ,
255
195
include_input = False ,
256
196
model = judge_model ,
257
197
),
@@ -276,20 +216,24 @@ async def fix_with_model(inputs: MermaidInput) -> MermaidOutput:
276
216
return await fix_mermaid_diagram (inputs , model = model )
277
217
278
218
report = await dataset .evaluate (
279
- fix_with_model , name = f"{ model } -multi-mcp-mermaid-diagram-fix-evals"
219
+ fix_with_model ,
220
+ name = f"{ model } -multi-mcp-mermaid-diagram-fix-evals" ,
221
+ max_concurrency = 1 , # Run one evaluation at a time
280
222
)
281
223
282
- report .print (include_input = True , include_output = True )
224
+ report .print (include_input = False , include_output = False )
283
225
return report
284
226
285
227
286
228
if __name__ == "__main__":
    # You can use different models for the agent and the judge.
    # Keep the env-var override working instead of hardcoding the model
    # (the previous hardcoded assignment silently ignored AGENT_MODEL);
    # the former hardcoded value becomes the fallback default.
    agent_model = os.getenv("AGENT_MODEL", "gemini-2.5-flash-preview-04-17")
    judge_model = os.getenv("JUDGE_MODEL", DEFAULT_MODEL)

    async def run_all():
        await run_evaluations(model=agent_model, judge_model=judge_model)

    asyncio.run(run_all())
0 commit comments