Skip to content

Commit 1c5c08e

Browse files
committed
feat(middleware): enhance middleware discovery and registration
1 parent ebaaa55 commit 1c5c08e

File tree

3 files changed

+197
-60
lines changed

3 files changed

+197
-60
lines changed

src/golf/core/builder.py

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,15 +1227,22 @@ def _generate_server(self) -> None:
12271227
component_registrations.append("")
12281228

12291229
# Check for custom middleware.py file and register middleware classes
1230-
middleware_classes = self._discover_middleware_classes(self.project_path)
1231-
if middleware_classes:
1230+
discovered_middleware = self._discover_middleware_classes(self.project_path)
1231+
fastmcp_middleware = discovered_middleware.get("fastmcp", [])
1232+
starlette_middleware = discovered_middleware.get("starlette", [])
1233+
1234+
# Import all middleware classes
1235+
all_middleware = fastmcp_middleware + starlette_middleware
1236+
if all_middleware:
12321237
imports.append("# Import custom middleware")
1233-
imports.append("from middleware import " + ", ".join(middleware_classes))
1238+
imports.append("from middleware import " + ", ".join(all_middleware))
12341239
imports.append("")
12351240

1236-
for cls_name in middleware_classes:
1237-
# Add each discovered middleware class to FastMCP
1238-
component_registrations.append(f"# Register custom middleware: {cls_name}")
1241+
# Register only FastMCP middleware via mcp.add_middleware()
1242+
# Starlette HTTP middleware will be added to mcp.run(middleware=[...]) later
1243+
if fastmcp_middleware:
1244+
for cls_name in fastmcp_middleware:
1245+
component_registrations.append(f"# Register custom FastMCP middleware: {cls_name}")
12391246
component_registrations.append(f"mcp.add_middleware({cls_name}())")
12401247
component_registrations.append("")
12411248

@@ -1343,7 +1350,14 @@ def _generate_server(self) -> None:
13431350
middleware_setup.append(" from golf.telemetry.instrumentation import SessionTracingMiddleware")
13441351
middleware_list.append("Middleware(SessionTracingMiddleware)")
13451352

1346-
if middleware_setup:
1353+
# Add custom Starlette HTTP middleware (e.g., CacheControlMiddleware)
1354+
# These are wrapped in Middleware() and passed to mcp.run(), not mcp.add_middleware()
1355+
if starlette_middleware:
1356+
middleware_setup.append(" from starlette.middleware import Middleware")
1357+
for cls_name in starlette_middleware:
1358+
middleware_list.append(f"Middleware({cls_name})")
1359+
1360+
if middleware_setup or starlette_middleware:
13471361
main_code.extend(middleware_setup)
13481362
main_code.append(f" middleware = [{', '.join(middleware_list)}]")
13491363
main_code.append("")
@@ -1405,7 +1419,14 @@ def _generate_server(self) -> None:
14051419
middleware_setup.append(" from golf.telemetry.instrumentation import SessionTracingMiddleware")
14061420
middleware_list.append("Middleware(SessionTracingMiddleware)")
14071421

1408-
if middleware_setup:
1422+
# Add custom Starlette HTTP middleware (e.g., CacheControlMiddleware)
1423+
# These are wrapped in Middleware() and passed to mcp.run(), not mcp.add_middleware()
1424+
if starlette_middleware:
1425+
middleware_setup.append(" from starlette.middleware import Middleware")
1426+
for cls_name in starlette_middleware:
1427+
middleware_list.append(f"Middleware({cls_name})")
1428+
1429+
if middleware_setup or starlette_middleware:
14091430
main_code.extend(middleware_setup)
14101431
main_code.append(f" middleware = [{', '.join(middleware_list)}]")
14111432
main_code.append("")
@@ -1493,11 +1514,16 @@ def _generate_server(self) -> None:
14931514
with open(server_file, "w") as f:
14941515
f.write(code)
14951516

1496-
def _discover_middleware_classes(self, project_path: Path) -> list[str]:
1497-
"""Discover middleware classes from middleware.py file."""
1517+
def _discover_middleware_classes(self, project_path: Path) -> dict[str, list[str]]:
1518+
"""Discover middleware classes from middleware.py file.
1519+
1520+
Returns a dict with two keys:
1521+
- 'fastmcp': List of FastMCP middleware class names (use mcp.add_middleware())
1522+
- 'starlette': List of Starlette HTTP middleware class names (use middleware=[])
1523+
"""
14981524
middleware_path = project_path / "middleware.py"
14991525
if not middleware_path.exists():
1500-
return []
1526+
return {"fastmcp": [], "starlette": []}
15011527

15021528
try:
15031529
# Save current directory and path
@@ -1513,25 +1539,39 @@ def _discover_middleware_classes(self, project_path: Path) -> list[str]:
15131539

15141540
spec = importlib.util.spec_from_file_location("middleware", middleware_path)
15151541
if spec is None or spec.loader is None:
1516-
return []
1542+
return {"fastmcp": [], "starlette": []}
15171543
middleware_module = importlib.util.module_from_spec(spec)
15181544
spec.loader.exec_module(middleware_module)
15191545

1520-
# Auto-discover middleware classes using duck typing
1521-
middleware_classes = []
1546+
# Auto-discover middleware classes, distinguishing between FastMCP and Starlette
1547+
fastmcp_middleware = []
1548+
starlette_middleware = []
1549+
1550+
# FastMCP middleware methods (MCP protocol level)
1551+
fastmcp_methods = ["on_message", "on_request", "on_call_tool", "on_read_resource", "on_get_prompt", "on_initialize"]
1552+
# Starlette/ASGI middleware method (HTTP level)
1553+
starlette_method = "dispatch"
1554+
15221555
for name, obj in inspect.getmembers(middleware_module, inspect.isclass):
15231556
# Skip classes that are not defined in this module (imported classes)
15241557
if obj.__module__ != middleware_module.__name__:
15251558
continue
15261559

1527-
# Check if class actually implements middleware methods (not just inherits them)
1528-
middleware_methods = ["on_message", "on_request", "on_call_tool", "dispatch"]
1529-
has_implemented_method = any(method in obj.__dict__ for method in middleware_methods)
1530-
if has_implemented_method:
1531-
middleware_classes.append(name)
1532-
console.print(f"[green]Discovered middleware: {name}[/green]")
1560+
# Check if class implements FastMCP middleware methods
1561+
has_fastmcp_method = any(method in obj.__dict__ for method in fastmcp_methods)
1562+
# Check if class implements Starlette dispatch method
1563+
has_dispatch_method = starlette_method in obj.__dict__
1564+
1565+
if has_fastmcp_method and not has_dispatch_method:
1566+
# Pure FastMCP middleware
1567+
fastmcp_middleware.append(name)
1568+
console.print(f"[green]Discovered FastMCP middleware: {name}[/green]")
1569+
elif has_dispatch_method:
1570+
# Starlette/ASGI HTTP middleware (dispatch method indicates HTTP-level)
1571+
starlette_middleware.append(name)
1572+
console.print(f"[green]Discovered Starlette HTTP middleware: {name}[/green]")
15331573

1534-
return middleware_classes
1574+
return {"fastmcp": fastmcp_middleware, "starlette": starlette_middleware}
15351575

15361576
except Exception as e:
15371577
console.print(f"[yellow]Warning: Could not load middleware.py: {e}[/yellow]")
@@ -1555,7 +1595,7 @@ def _discover_middleware_classes(self, project_path: Path) -> list[str]:
15551595
except Exception:
15561596
pass
15571597

1558-
return []
1598+
return {"fastmcp": [], "starlette": []}
15591599

15601600
finally:
15611601
# Always restore original directory and path

tests/core/test_middleware.py

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ class NotMiddleware:
2424

2525
settings = load_settings(sample_project)
2626
generator = CodeGenerator(sample_project, settings, temp_dir)
27-
middleware_classes = generator._discover_middleware_classes(sample_project)
28-
29-
assert middleware_classes == ["LoggingMiddleware"]
27+
discovered = generator._discover_middleware_classes(sample_project)
28+
29+
# FastMCP middleware (on_call_tool method)
30+
assert discovered["fastmcp"] == ["LoggingMiddleware"]
31+
assert discovered["starlette"] == []
3032

3133
def test_discovers_multiple_middleware_methods(self, sample_project: Path, temp_dir: Path):
3234
"""Test discovery of middleware with different methods."""
@@ -49,18 +51,20 @@ def dispatch(self, context):
4951

5052
settings = load_settings(sample_project)
5153
generator = CodeGenerator(sample_project, settings, temp_dir)
52-
middleware_classes = generator._discover_middleware_classes(sample_project)
53-
54-
# Should discover all three middleware classes
55-
assert set(middleware_classes) == {"MessageMiddleware", "RequestMiddleware", "DispatchMiddleware"}
54+
discovered = generator._discover_middleware_classes(sample_project)
55+
56+
# FastMCP middleware (on_message, on_request methods)
57+
assert set(discovered["fastmcp"]) == {"MessageMiddleware", "RequestMiddleware"}
58+
# Starlette HTTP middleware (dispatch method)
59+
assert discovered["starlette"] == ["DispatchMiddleware"]
5660

5761
def test_no_middleware_when_file_missing(self, sample_project: Path, temp_dir: Path):
5862
"""Test graceful handling when middleware.py doesn't exist."""
5963
settings = load_settings(sample_project)
6064
generator = CodeGenerator(sample_project, settings, temp_dir)
61-
middleware_classes = generator._discover_middleware_classes(sample_project)
62-
63-
assert middleware_classes == []
65+
discovered = generator._discover_middleware_classes(sample_project)
66+
67+
assert discovered == {"fastmcp": [], "starlette": []}
6468

6569
def test_ignores_classes_without_middleware_methods(self, sample_project: Path, temp_dir: Path):
6670
"""Test that classes without middleware methods are ignored."""
@@ -83,9 +87,10 @@ def regular_method(self):
8387

8488
settings = load_settings(sample_project)
8589
generator = CodeGenerator(sample_project, settings, temp_dir)
86-
middleware_classes = generator._discover_middleware_classes(sample_project)
87-
88-
assert middleware_classes == ["ValidMiddleware"]
90+
discovered = generator._discover_middleware_classes(sample_project)
91+
92+
assert discovered["fastmcp"] == ["ValidMiddleware"]
93+
assert discovered["starlette"] == []
8994

9095

9196
class TestMiddlewareCodeGeneration:
@@ -168,9 +173,10 @@ async def on_message(self, context, call_next):
168173

169174
settings = load_settings(sample_project)
170175
generator = CodeGenerator(sample_project, settings, temp_dir)
171-
middleware_classes = generator._discover_middleware_classes(sample_project)
172-
173-
assert middleware_classes == [] # Should return empty list on error
176+
discovered = generator._discover_middleware_classes(sample_project)
177+
178+
# Should return empty dict on error
179+
assert discovered == {"fastmcp": [], "starlette": []}
174180

175181
def test_handles_import_error(self, sample_project: Path, temp_dir: Path):
176182
"""Test graceful handling of import errors."""
@@ -185,9 +191,10 @@ async def on_message(self, context, call_next):
185191

186192
settings = load_settings(sample_project)
187193
generator = CodeGenerator(sample_project, settings, temp_dir)
188-
middleware_classes = generator._discover_middleware_classes(sample_project)
189-
190-
assert middleware_classes == [] # Should return empty list on error
194+
discovered = generator._discover_middleware_classes(sample_project)
195+
196+
# Should return empty dict on error
197+
assert discovered == {"fastmcp": [], "starlette": []}
191198

192199
def test_build_succeeds_with_broken_middleware(self, sample_project: Path, temp_dir: Path):
193200
"""Test that broken middleware doesn't break the build process."""
@@ -226,9 +233,10 @@ async def on_message(self, context, call_next):
226233

227234
settings = load_settings(sample_project)
228235
generator = CodeGenerator(sample_project, settings, temp_dir)
229-
middleware_classes = generator._discover_middleware_classes(sample_project)
230-
231-
assert middleware_classes == [] # Should return empty list on error
236+
discovered = generator._discover_middleware_classes(sample_project)
237+
238+
# Should return empty dict on error
239+
assert discovered == {"fastmcp": [], "starlette": []}
232240

233241
def test_handles_empty_middleware_file(self, sample_project: Path, temp_dir: Path):
234242
"""Test handling of empty middleware.py file."""
@@ -237,9 +245,9 @@ def test_handles_empty_middleware_file(self, sample_project: Path, temp_dir: Pat
237245

238246
settings = load_settings(sample_project)
239247
generator = CodeGenerator(sample_project, settings, temp_dir)
240-
middleware_classes = generator._discover_middleware_classes(sample_project)
241-
242-
assert middleware_classes == []
248+
discovered = generator._discover_middleware_classes(sample_project)
249+
250+
assert discovered == {"fastmcp": [], "starlette": []}
243251

244252

245253
class TestMiddlewareDuckTyping:
@@ -265,12 +273,15 @@ def some_other_method(self):
265273

266274
settings = load_settings(sample_project)
267275
generator = CodeGenerator(sample_project, settings, temp_dir)
268-
middleware_classes = generator._discover_middleware_classes(sample_project)
269-
270-
assert len(middleware_classes) == 2
271-
assert "DuckTypedMiddleware" in middleware_classes
272-
assert "AlsoMiddleware" in middleware_classes
273-
assert "NotMiddleware" not in middleware_classes
276+
discovered = generator._discover_middleware_classes(sample_project)
277+
278+
# FastMCP middleware (on_call_tool method)
279+
assert discovered["fastmcp"] == ["DuckTypedMiddleware"]
280+
# Starlette HTTP middleware (dispatch method)
281+
assert discovered["starlette"] == ["AlsoMiddleware"]
282+
# NotMiddleware should not appear in either list
283+
all_middleware = discovered["fastmcp"] + discovered["starlette"]
284+
assert "NotMiddleware" not in all_middleware
274285

275286
def test_discovers_mixed_middleware_types(self, sample_project: Path, temp_dir: Path):
276287
"""Test discovery of middleware with and without base class."""
@@ -293,12 +304,14 @@ def regular_method(self):
293304

294305
settings = load_settings(sample_project)
295306
generator = CodeGenerator(sample_project, settings, temp_dir)
296-
middleware_classes = generator._discover_middleware_classes(sample_project)
297-
298-
assert len(middleware_classes) == 2
299-
assert "InheritedMiddleware" in middleware_classes
300-
assert "DuckTypedMiddleware" in middleware_classes
301-
assert "NoMiddlewareMethods" not in middleware_classes
307+
discovered = generator._discover_middleware_classes(sample_project)
308+
309+
# Both are FastMCP middleware (on_message, on_call_tool methods)
310+
assert set(discovered["fastmcp"]) == {"InheritedMiddleware", "DuckTypedMiddleware"}
311+
assert discovered["starlette"] == []
312+
# NoMiddlewareMethods should not appear
313+
all_middleware = discovered["fastmcp"] + discovered["starlette"]
314+
assert "NoMiddlewareMethods" not in all_middleware
302315

303316

304317
class TestMiddlewareRegistrationOrder:

tests/core/test_middleware_build.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,4 +302,88 @@ def some_method(self):
302302
assert "mcp.add_middleware(DuckTypedLogging())" in server_content
303303
assert "mcp.add_middleware(DuckTypedAuth())" in server_content
304304
# Non-middleware class should not be included
305-
assert "NotMiddleware" not in server_content
305+
assert "NotMiddleware" not in server_content
306+
307+
def test_starlette_http_middleware_build(self, sample_project: Path, temp_dir: Path):
308+
"""Test Starlette HTTP middleware (e.g., CacheControlMiddleware) is registered correctly.
309+
310+
Starlette middleware uses the dispatch() method and must be passed to mcp.run(middleware=[])
311+
instead of mcp.add_middleware() which is for FastMCP protocol-level middleware.
312+
"""
313+
middleware_file = sample_project / "middleware.py"
314+
middleware_file.write_text('''
315+
from starlette.middleware.base import BaseHTTPMiddleware
316+
from starlette.requests import Request
317+
from starlette.responses import Response
318+
from typing import Callable, Any
319+
320+
321+
class CacheControlMiddleware(BaseHTTPMiddleware):
322+
"""Middleware to add Cache-Control headers to all responses."""
323+
324+
async def dispatch(self, request: Request, call_next: Callable[..., Any]) -> Response:
325+
response = await call_next(request)
326+
response.headers["Cache-Control"] = "no-store"
327+
return response
328+
''')
329+
330+
settings = load_settings(sample_project)
331+
output_dir = temp_dir / "build"
332+
build_project(sample_project, settings, output_dir, build_env="dev", copy_env=False)
333+
334+
server_content = (output_dir / "server.py").read_text()
335+
336+
# Starlette middleware should be imported
337+
assert "from middleware import CacheControlMiddleware" in server_content
338+
339+
# Should NOT be registered via mcp.add_middleware() - that would fail!
340+
assert "mcp.add_middleware(CacheControlMiddleware())" not in server_content
341+
342+
# Should be added to the middleware list for mcp.run()
343+
assert "Middleware(CacheControlMiddleware)" in server_content
344+
345+
def test_mixed_fastmcp_and_starlette_middleware_build(self, sample_project: Path, temp_dir: Path):
346+
"""Test that both FastMCP and Starlette middleware can be used together."""
347+
middleware_file = sample_project / "middleware.py"
348+
middleware_file.write_text('''
349+
from starlette.middleware.base import BaseHTTPMiddleware
350+
from golf.middleware import Middleware as FastMCPMiddleware
351+
from typing import Callable, Any
352+
353+
354+
class LoggingMiddleware(FastMCPMiddleware):
355+
"""FastMCP middleware for logging MCP operations."""
356+
357+
async def on_call_tool(self, context, call_next):
358+
print(f"Calling tool: {context.message.params.name}")
359+
return await call_next(context)
360+
361+
362+
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
363+
"""Starlette middleware for adding security headers."""
364+
365+
async def dispatch(self, request, call_next: Callable[..., Any]):
366+
response = await call_next(request)
367+
response.headers["Cache-Control"] = "no-store"
368+
response.headers["X-Content-Type-Options"] = "nosniff"
369+
return response
370+
''')
371+
372+
settings = load_settings(sample_project)
373+
output_dir = temp_dir / "build"
374+
build_project(sample_project, settings, output_dir, build_env="dev", copy_env=False)
375+
376+
server_content = (output_dir / "server.py").read_text()
377+
378+
# Both middleware should be imported
379+
assert "LoggingMiddleware" in server_content
380+
assert "SecurityHeadersMiddleware" in server_content
381+
382+
# FastMCP middleware should use mcp.add_middleware()
383+
assert "mcp.add_middleware(LoggingMiddleware())" in server_content
384+
385+
# Starlette middleware should NOT use mcp.add_middleware()
386+
assert "mcp.add_middleware(SecurityHeadersMiddleware())" not in server_content
387+
388+
# Starlette middleware should be in the middleware list for mcp.run()
389+
assert "Middleware(SecurityHeadersMiddleware)" in server_content

0 commit comments

Comments
 (0)