Skip to content

Commit 09f700a

Browse files
mcp: refactor and extend agent loop (google#14135)
- Use a stdio server instead of http, simplifying how we run the server - Enable multiple runs for fixing builds if first chat run didn't work - Remove misleading comments Signed-off-by: David Korczynski <[email protected]>
1 parent 37d1fcd commit 09f700a

File tree

2 files changed

+51
-54
lines changed

2 files changed

+51
-54
lines changed

infra/experimental/mcp/client.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,20 @@
2323
import subprocess
2424
import argparse
2525
import time
26-
26+
import sys
2727
import httpx
28-
from pydantic_ai import Agent
29-
from pydantic_ai.mcp import MCPServerSSE
28+
import pathlib
3029

31-
import logfire
30+
from pydantic_ai import Agent
31+
from pydantic_ai.mcp import MCPServerStdio
3232

3333
import config as oss_fuzz_mcp_config
3434

35-
logfire.configure(send_to_logfire='if-token-present')
36-
logfire.instrument_pydantic_ai()
37-
3835
# Configure logging
3936
logging.basicConfig(
4037
level=logging.INFO,
41-
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
38+
format="[CLIENT] %(asctime)s - %(name)s - %(levelname)s - %(message)s",
39+
stream=sys.stderr)
4240
logger = logging.getLogger("mcp-server")
4341

4442
MCP_SERVER_URL = "http://localhost:8000/sse"
@@ -178,9 +176,9 @@
178176
"""
179177

180178

181-
async def chat_with_agent(prompt: str) -> list:
179+
async def run_agent_loop(prompt: str) -> list:
182180
"""
183-
Send a message to the LLM with access to the MCP tools.
181+
Performs a run with the LLM.
184182
185183
Args:
186184
prompt: The user's message
@@ -190,23 +188,24 @@ async def chat_with_agent(prompt: str) -> list:
190188
"""
191189
nodes = []
192190
try:
193-
server = MCPServerSSE(url=MCP_SERVER_URL,
194-
timeout=5200.0,
195-
read_timeout=5000.0)
191+
server = MCPServerStdio(
192+
'python3',
193+
[str(pathlib.Path(__file__).parent.resolve()) + '/oss_fuzz_server.py'],
194+
timeout=5200.0)
196195

197-
agent = Agent(model="openai:gpt-4o", toolsets=[server], retries=30)
196+
agent = Agent(model="openai:gpt-4", toolsets=[server], retries=30)
198197

199198
# Run the agent with the MCP server context
200199
logger.info('Starting agent run')
201200
async with agent.iter(prompt) as agent_run:
202201
logger.info('Agent run started')
203-
204202
async for node in agent_run:
205-
logger.info('Running node %s', node.__class__.__name__)
203+
logger.info('Running node [%d] %s', len(nodes), node.__class__.__name__)
206204
time.sleep(3)
207205
nodes.append(node)
208206
except Exception as e:
209207
logger.info('Error during agent run: %s', e)
208+
sys.exit(1)
210209

211210
return nodes
212211

@@ -365,7 +364,7 @@ async def does_project_build(project: str) -> bool:
365364
return True
366365

367366

368-
async def fix_project_build(project: str):
367+
async def fix_project_build(project: str, max_tries: int = 3):
369368
"""Runs an agent to fix the build of an OSS-Fuzz project."""
370369

371370
project_language = _detect_language(project)
@@ -382,8 +381,11 @@ async def fix_project_build(project: str):
382381
if oss_fuzz_filetree:
383382
extra_project_text += f'The files in the OSS-Fuzz project for {project} are:\n{oss_fuzz_filetree}\n'
384383

385-
nodes = await chat_with_agent(
386-
f"""Fix the OSS-Fuzz project {project} that currently has a broken build.
384+
nodes = []
385+
for _attempt in range(max_tries):
386+
logger.info('Attempt %d to fix project %s', _attempt + 1, project)
387+
nodes += await run_agent_loop(
388+
f"""Fix the OSS-Fuzz project {project} that currently has a broken build.
387389
Use the build logs from OSS-Fuzz's project {project} and determine why it fails, then
388390
proceed to adjust Dockerfile and build.sh scripts until the project builds.
389391
@@ -410,7 +412,10 @@ async def fix_project_build(project: str):
410412
- Continue adjusting the files in {oss_fuzz_mcp_config.BASE_OSS_FUZZ_DIR}/projects/{project}/ until "fuzzer-check" passes.
411413
""")
412414

413-
fix_success = await does_project_build(project)
415+
fix_success = await does_project_build(project)
416+
if fix_success:
417+
logger.info('Project %s build fixed successfully.', project)
418+
break
414419
return nodes, fix_success
415420

416421

@@ -501,7 +506,7 @@ async def add_run_tests_command(project_name: str):
501506

502507
os.chmod(run_tests_path, 0o755)
503508

504-
await chat_with_agent(f"""
509+
await run_agent_loop(f"""
505510
You are an expert software security engineer that is specialized in OSS-Fuzz.
506511
You are tasked with adding a run_tests.sh script to an OSS-Fuzz project.
507512
This script should run the tests of the project, and ensure that the project is working correctly.
@@ -530,7 +535,7 @@ async def expand_existing_project(project_name: str):
530535
logger.info('Failed to prepare %s. Exiting.', project_name)
531536
return
532537

533-
nodes = await chat_with_agent(
538+
nodes = await run_agent_loop(
534539
f"""You are a security engineer that is an expert in fuzzing development, and your goal is to expand on the
535540
fuzzing harnesses of OSS-Fuzz project {project_name}.
536541
Use the tools to understand the fuzzing harnesses of the {project_name}'s OSS-Fuzz integration.
@@ -571,7 +576,8 @@ def _log_nodes(logfile, nodes, header_text=''):
571576

572577
async def fix_oss_fuzz_projects(projects_to_fix=None,
573578
max_projects_to_fix=4,
574-
language=''):
579+
language='',
580+
max_tries=3):
575581
"""Fixes the build of a list of OSS-Fuzz projects."""
576582

577583
if projects_to_fix is None:
@@ -598,7 +604,7 @@ async def fix_oss_fuzz_projects(projects_to_fix=None,
598604
continue
599605
except:
600606
continue
601-
nodes, fix_success = await fix_project_build(project)
607+
nodes, fix_success = await fix_project_build(project, max_tries)
602608
responses.append({'project': project, 'fix_success': fix_success})
603609
if nodes:
604610
_log_nodes(f'responses-fix-build-{project}.json',
@@ -707,7 +713,7 @@ async def initiate_project_creation(project: str, project_repo: str,
707713
available to extract code coverage of the project when you're creating the harness, and either
708714
add more fuzzing harnesses to the project or extend the harness to cover more functions."""
709715

710-
nodes = await chat_with_agent(
716+
nodes = await run_agent_loop(
711717
f"""You are an expert software security engineer and you are tasked with creating an OSS-Fuzz project.
712718
I have set up an initial project structure at {oss_fuzz_mcp_config.BASE_OSS_FUZZ_DIR}/projects/{project}/. This structure
713719
includes a Dockerfile, build.sh, and project.yaml file. The Dockerfile clones the target
@@ -882,6 +888,11 @@ def parse_arguments():
882888
fix_builds = subparsers.add_parser(
883889
'fix-builds',
884890
help='Fix the builds of OSS-Fuzz projects that are currently broken.')
891+
fix_builds.add_argument(
892+
'--max_attempts',
893+
type=int,
894+
default=3,
895+
help='Maximum number of attempts to fix each project (default: 3)')
885896

886897
fix_builds.add_argument('--max-projects',
887898
type=int,
@@ -939,7 +950,8 @@ async def main():
939950

940951
initialize_oss_fuzz()
941952
if args.command == 'fix-builds':
942-
await fix_oss_fuzz_projects(args.projects, args.max_projects, args.language)
953+
await fix_oss_fuzz_projects(args.projects, args.max_projects, args.language,
954+
args.max_attempts)
943955
elif args.command == 'create-project':
944956
logger.info('Creating OSS-Fuzz project for URL: %s', args.project_url)
945957
await create_oss_fuzz_integration_for_project(args.project_url,

infra/experimental/mcp/oss_fuzz_server.py

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,24 @@
1818
import logging
1919
import asyncio
2020
import os
21+
import sys
2122
import json
2223
import time
2324
import subprocess
24-
25-
import httpx
2625
from mcp.server.fastmcp import FastMCP
2726

2827
import config as oss_fuzz_mcp_config
2928

3029
# Configure logging
3130
logging.basicConfig(
3231
level=logging.INFO,
33-
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
32+
format=
33+
"[SERVER] %(asctime)s - %(name)s - '%(module)s - %(funcName)s - %(levelname)s - %(message)s",
34+
stream=sys.stderr)
3435
logger = logging.getLogger("mcp-server")
3536

36-
# Create a shared HTTP client for API requests
37-
http_client = httpx.AsyncClient(timeout=10.0)
38-
3937
# Create an MCP server with a name
40-
mcp = FastMCP("OSS-Fuzz tools with relevant file system utilities.",
41-
host="0.0.0.0",
42-
port=8000)
38+
mcp = FastMCP("OSS-Fuzz tools with relevant file system utilities.")
4339

4440
FILE_ACCESS_ERROR = f"""Error: Cannot access directories outside of the base directory.
4541
Remember, all paths accessible by you must be prefixed with {oss_fuzz_mcp_config.BASE_DIR}.
@@ -96,14 +92,16 @@ async def check_if_oss_fuzz_project_builds(project_name: str) -> bool:
9692
project_name,
9793
cwd=oss_fuzz_mcp_config.BASE_OSS_FUZZ_DIR,
9894
shell=True,
95+
stdout=subprocess.DEVNULL,
96+
stderr=subprocess.STDOUT,
9997
timeout=60 * 20)
10098
return True
10199
except subprocess.CalledProcessError as e:
102100
logger.info("Build failed for project '%s': {%s}", project_name, str(e))
103101
return False
104102
except subprocess.TimeoutExpired:
105103
logger.info(f"Building project {project_name} timed out.")
106-
return False
104+
return False
107105

108106

109107
def shorten_logs_if_needed(log_string: str) -> str:
@@ -227,7 +225,6 @@ async def check_run_tests(
227225
Returns:
228226
The logs from building the project with custom artifacts.
229227
"""
230-
logger.info('Running test 1')
231228
clone_oss_fuzz_if_it_does_not_exist()
232229
logger.info(
233230
"Checking if OSS-Fuzz project '%s' builds with custom artifacts...",
@@ -236,12 +233,10 @@ async def check_run_tests(
236233
os.makedirs(oss_fuzz_mcp_config.BASE_TMP_LOGS, exist_ok=True)
237234
target_logs = os.path.join(oss_fuzz_mcp_config.BASE_TMP_LOGS,
238235
'check-fuzz-run-tests-log.txt')
239-
logger.info('Running test 2')
240236
if os.path.isfile(target_logs):
241237
os.remove(target_logs)
242238
log_stdout = open(target_logs, 'w', encoding='utf-8')
243239
try:
244-
logger.info('Running test 3')
245240
subprocess.check_call(
246241
f'infra/experimental/chronos/check_tests.sh {project_name} c++',
247242
cwd=oss_fuzz_mcp_config.BASE_OSS_FUZZ_DIR,
@@ -251,7 +246,6 @@ async def check_run_tests(
251246
timeout=60 * 20)
252247

253248
except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
254-
logger.info('Running test 4')
255249
logger.info("Build failed for project '%s': {%s}", project_name, str(e))
256250
log_stdout.write("\n\nrun-tests.sh failed!!\n")
257251
with open(target_logs, 'r', encoding='utf-8') as f:
@@ -260,14 +254,12 @@ async def check_run_tests(
260254
logger.info("run-tests.sh logs for project '%s': {%s}", project_name,
261255
logs_to_return)
262256
return logs_to_return
263-
logger.info('Running test 5.01')
257+
264258
with open(target_logs, 'r', encoding='utf-8') as f:
265259
logs = f.read()
266260
logs_to_return = shorten_logs_if_needed(logs)
267-
logger.info('Running test 5')
268261
logger.info("run-tests.sh for project '%s': {%s}", project_name,
269262
logs_to_return)
270-
logger.info('Running test 6')
271263
return logs_to_return
272264

273265

@@ -353,7 +345,6 @@ async def list_files(path: str = "") -> str:
353345
_internal_delay()
354346
target_dir = os.path.normpath(path)
355347
if not target_dir.startswith(oss_fuzz_mcp_config.BASE_DIR):
356-
# Security check to prevent directory traversal
357348
return FILE_ACCESS_ERROR
358349

359350
logger.info("Listing files in directory: %s", target_dir)
@@ -388,7 +379,6 @@ async def get_file_size(file_path) -> str:
388379
_internal_delay()
389380
target_file = os.path.normpath(file_path)
390381
if not target_file.startswith(oss_fuzz_mcp_config.BASE_DIR):
391-
# Security check to prevent directory traversal
392382
return FILE_ACCESS_ERROR
393383

394384
logger.info("Getting file size: %s", target_file)
@@ -420,7 +410,6 @@ async def read_file(file_path: str, start_idx: int, end_idx: int) -> str:
420410
_internal_delay()
421411
target_file = os.path.normpath(file_path)
422412
if not target_file.startswith(oss_fuzz_mcp_config.BASE_DIR):
423-
# Security check to prevent directory traversal
424413
return FILE_ACCESS_ERROR
425414

426415
logger.info("Reading file: %s", target_file)
@@ -462,7 +451,6 @@ async def write_file(file_path: str, content: str) -> str:
462451
logger.info("Writing to file: %s", file_path)
463452
target_file = os.path.normpath(file_path)
464453

465-
# Security check to prevent directory traversal
466454
if not target_file.startswith(oss_fuzz_mcp_config.BASE_DIR):
467455
return FILE_ACCESS_ERROR
468456

@@ -489,7 +477,6 @@ async def delete_file(file_path: str) -> str:
489477
logger.info("Deleting file: %s", file_path)
490478
target_file = os.path.normpath(file_path)
491479

492-
# Security check to prevent directory traversal
493480
if not target_file.startswith(oss_fuzz_mcp_config.BASE_DIR):
494481
return FILE_ACCESS_ERROR
495482

@@ -664,19 +651,17 @@ async def get_coverage_of_oss_fuzz_project(project_name):
664651
logger.info('Refined coverage dict: %s', json.dumps(refined_cov_dict,
665652
indent=2))
666653

667-
# Split up the coverage
668654
return json.dumps(refined_cov_dict, indent=2)
669655

670656

671657
def start_mcp_server():
672658
"""Starts the MCP server."""
673659
try:
674-
logger.info("Starting MCP server on port 8000...")
675-
# Close the HTTP client when the server shuts down
676-
mcp.run(transport="sse")
660+
logger.info("Starting MCP server.")
661+
mcp.run(transport="stdio")
677662
except KeyboardInterrupt:
678-
logger.info("Server shutting down...")
679-
asyncio.run(http_client.aclose())
663+
logger.info("Caught KeyboardInterrupt.")
664+
logger.info('Server shut down.')
680665

681666

682667
if __name__ == "__main__":

0 commit comments

Comments
 (0)