Skip to content

Commit eb777a0

Browse files
authored
Merge pull request #132 from golf-mcp/aschlean/starlette-middleware
Aschlean/starlette middleware
2 parents 6b7a81d + 888ea5a commit eb777a0

File tree

5 files changed

+207
-63
lines changed

5 files changed

+207
-63
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "golf-mcp"
7-
version = "0.3.0rc3"
7+
version = "0.3.0rc5"
88
description = "Framework for building MCP servers"
99
authors = [
1010
{name = "Antoni Gmitruk", email = "antoni@golf.dev"}
@@ -66,7 +66,7 @@ golf = ["examples/**/*"]
6666

6767
[tool.poetry]
6868
name = "golf-mcp"
69-
version = "0.3.0rc3"
69+
version = "0.3.0rc5"
7070
description = "Framework for building MCP servers with zero boilerplate"
7171
authors = ["Antoni Gmitruk <antoni@golf.dev>"]
7272
license = "Apache-2.0"

src/golf/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.3.0rc3"
1+
__version__ = "0.3.0rc5"
22

33
from golf.decorators import prompt, resource, tool
44

src/golf/core/builder.py

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,15 +1227,22 @@ def _generate_server(self) -> None:
12271227
component_registrations.append("")
12281228

12291229
# Check for custom middleware.py file and register middleware classes
1230-
middleware_classes = self._discover_middleware_classes(self.project_path)
1231-
if middleware_classes:
1230+
discovered_middleware = self._discover_middleware_classes(self.project_path)
1231+
fastmcp_middleware = discovered_middleware.get("fastmcp", [])
1232+
starlette_middleware = discovered_middleware.get("starlette", [])
1233+
1234+
# Import all middleware classes
1235+
all_middleware = fastmcp_middleware + starlette_middleware
1236+
if all_middleware:
12321237
imports.append("# Import custom middleware")
1233-
imports.append("from middleware import " + ", ".join(middleware_classes))
1238+
imports.append("from middleware import " + ", ".join(all_middleware))
12341239
imports.append("")
12351240

1236-
for cls_name in middleware_classes:
1237-
# Add each discovered middleware class to FastMCP
1238-
component_registrations.append(f"# Register custom middleware: {cls_name}")
1241+
# Register only FastMCP middleware via mcp.add_middleware()
1242+
# Starlette HTTP middleware will be added to mcp.run(middleware=[...]) later
1243+
if fastmcp_middleware:
1244+
for cls_name in fastmcp_middleware:
1245+
component_registrations.append(f"# Register custom FastMCP middleware: {cls_name}")
12391246
component_registrations.append(f"mcp.add_middleware({cls_name}())")
12401247
component_registrations.append("")
12411248

@@ -1343,7 +1350,14 @@ def _generate_server(self) -> None:
13431350
middleware_setup.append(" from golf.telemetry.instrumentation import SessionTracingMiddleware")
13441351
middleware_list.append("Middleware(SessionTracingMiddleware)")
13451352

1346-
if middleware_setup:
1353+
# Add custom Starlette HTTP middleware (e.g., CacheControlMiddleware)
1354+
# These are wrapped in Middleware() and passed to mcp.run(), not mcp.add_middleware()
1355+
if starlette_middleware:
1356+
middleware_setup.append(" from starlette.middleware import Middleware")
1357+
for cls_name in starlette_middleware:
1358+
middleware_list.append(f"Middleware({cls_name})")
1359+
1360+
if middleware_setup or starlette_middleware:
13471361
main_code.extend(middleware_setup)
13481362
main_code.append(f" middleware = [{', '.join(middleware_list)}]")
13491363
main_code.append("")
@@ -1405,7 +1419,14 @@ def _generate_server(self) -> None:
14051419
middleware_setup.append(" from golf.telemetry.instrumentation import SessionTracingMiddleware")
14061420
middleware_list.append("Middleware(SessionTracingMiddleware)")
14071421

1408-
if middleware_setup:
1422+
# Add custom Starlette HTTP middleware (e.g., CacheControlMiddleware)
1423+
# These are wrapped in Middleware() and passed to mcp.run(), not mcp.add_middleware()
1424+
if starlette_middleware:
1425+
middleware_setup.append(" from starlette.middleware import Middleware")
1426+
for cls_name in starlette_middleware:
1427+
middleware_list.append(f"Middleware({cls_name})")
1428+
1429+
if middleware_setup or starlette_middleware:
14091430
main_code.extend(middleware_setup)
14101431
main_code.append(f" middleware = [{', '.join(middleware_list)}]")
14111432
main_code.append("")
@@ -1493,11 +1514,16 @@ def _generate_server(self) -> None:
14931514
with open(server_file, "w") as f:
14941515
f.write(code)
14951516

1496-
def _discover_middleware_classes(self, project_path: Path) -> list[str]:
1497-
"""Discover middleware classes from middleware.py file."""
1517+
def _discover_middleware_classes(self, project_path: Path) -> dict[str, list[str]]:
1518+
"""Discover middleware classes from middleware.py file.
1519+
1520+
Returns a dict with two keys:
1521+
- 'fastmcp': List of FastMCP middleware class names (use mcp.add_middleware())
1522+
- 'starlette': List of Starlette HTTP middleware class names (use middleware=[])
1523+
"""
14981524
middleware_path = project_path / "middleware.py"
14991525
if not middleware_path.exists():
1500-
return []
1526+
return {"fastmcp": [], "starlette": []}
15011527

15021528
try:
15031529
# Save current directory and path
@@ -1513,25 +1539,46 @@ def _discover_middleware_classes(self, project_path: Path) -> list[str]:
15131539

15141540
spec = importlib.util.spec_from_file_location("middleware", middleware_path)
15151541
if spec is None or spec.loader is None:
1516-
return []
1542+
return {"fastmcp": [], "starlette": []}
15171543
middleware_module = importlib.util.module_from_spec(spec)
15181544
spec.loader.exec_module(middleware_module)
15191545

1520-
# Auto-discover middleware classes using duck typing
1521-
middleware_classes = []
1546+
# Auto-discover middleware classes, distinguishing between FastMCP and Starlette
1547+
fastmcp_middleware = []
1548+
starlette_middleware = []
1549+
1550+
# FastMCP middleware methods (MCP protocol level)
1551+
fastmcp_methods = [
1552+
"on_message",
1553+
"on_request",
1554+
"on_call_tool",
1555+
"on_read_resource",
1556+
"on_get_prompt",
1557+
"on_initialize",
1558+
]
1559+
# Starlette/ASGI middleware method (HTTP level)
1560+
starlette_method = "dispatch"
1561+
15221562
for name, obj in inspect.getmembers(middleware_module, inspect.isclass):
15231563
# Skip classes that are not defined in this module (imported classes)
15241564
if obj.__module__ != middleware_module.__name__:
15251565
continue
15261566

1527-
# Check if class actually implements middleware methods (not just inherits them)
1528-
middleware_methods = ["on_message", "on_request", "on_call_tool", "dispatch"]
1529-
has_implemented_method = any(method in obj.__dict__ for method in middleware_methods)
1530-
if has_implemented_method:
1531-
middleware_classes.append(name)
1532-
console.print(f"[green]Discovered middleware: {name}[/green]")
1567+
# Check if class implements FastMCP middleware methods
1568+
has_fastmcp_method = any(method in obj.__dict__ for method in fastmcp_methods)
1569+
# Check if class implements Starlette dispatch method
1570+
has_dispatch_method = starlette_method in obj.__dict__
1571+
1572+
if has_fastmcp_method and not has_dispatch_method:
1573+
# Pure FastMCP middleware
1574+
fastmcp_middleware.append(name)
1575+
console.print(f"[green]Discovered FastMCP middleware: {name}[/green]")
1576+
elif has_dispatch_method:
1577+
# Starlette/ASGI HTTP middleware (dispatch method indicates HTTP-level)
1578+
starlette_middleware.append(name)
1579+
console.print(f"[green]Discovered Starlette HTTP middleware: {name}[/green]")
15331580

1534-
return middleware_classes
1581+
return {"fastmcp": fastmcp_middleware, "starlette": starlette_middleware}
15351582

15361583
except Exception as e:
15371584
console.print(f"[yellow]Warning: Could not load middleware.py: {e}[/yellow]")
@@ -1555,7 +1602,7 @@ def _discover_middleware_classes(self, project_path: Path) -> list[str]:
15551602
except Exception:
15561603
pass
15571604

1558-
return []
1605+
return {"fastmcp": [], "starlette": []}
15591606

15601607
finally:
15611608
# Always restore original directory and path

tests/core/test_middleware.py

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ class NotMiddleware:
2424

2525
settings = load_settings(sample_project)
2626
generator = CodeGenerator(sample_project, settings, temp_dir)
27-
middleware_classes = generator._discover_middleware_classes(sample_project)
28-
29-
assert middleware_classes == ["LoggingMiddleware"]
27+
discovered = generator._discover_middleware_classes(sample_project)
28+
29+
# FastMCP middleware (on_call_tool method)
30+
assert discovered["fastmcp"] == ["LoggingMiddleware"]
31+
assert discovered["starlette"] == []
3032

3133
def test_discovers_multiple_middleware_methods(self, sample_project: Path, temp_dir: Path):
3234
"""Test discovery of middleware with different methods."""
@@ -49,18 +51,20 @@ def dispatch(self, context):
4951

5052
settings = load_settings(sample_project)
5153
generator = CodeGenerator(sample_project, settings, temp_dir)
52-
middleware_classes = generator._discover_middleware_classes(sample_project)
53-
54-
# Should discover all three middleware classes
55-
assert set(middleware_classes) == {"MessageMiddleware", "RequestMiddleware", "DispatchMiddleware"}
54+
discovered = generator._discover_middleware_classes(sample_project)
55+
56+
# FastMCP middleware (on_message, on_request methods)
57+
assert set(discovered["fastmcp"]) == {"MessageMiddleware", "RequestMiddleware"}
58+
# Starlette HTTP middleware (dispatch method)
59+
assert discovered["starlette"] == ["DispatchMiddleware"]
5660

5761
def test_no_middleware_when_file_missing(self, sample_project: Path, temp_dir: Path):
5862
"""Test graceful handling when middleware.py doesn't exist."""
5963
settings = load_settings(sample_project)
6064
generator = CodeGenerator(sample_project, settings, temp_dir)
61-
middleware_classes = generator._discover_middleware_classes(sample_project)
62-
63-
assert middleware_classes == []
65+
discovered = generator._discover_middleware_classes(sample_project)
66+
67+
assert discovered == {"fastmcp": [], "starlette": []}
6468

6569
def test_ignores_classes_without_middleware_methods(self, sample_project: Path, temp_dir: Path):
6670
"""Test that classes without middleware methods are ignored."""
@@ -83,9 +87,10 @@ def regular_method(self):
8387

8488
settings = load_settings(sample_project)
8589
generator = CodeGenerator(sample_project, settings, temp_dir)
86-
middleware_classes = generator._discover_middleware_classes(sample_project)
87-
88-
assert middleware_classes == ["ValidMiddleware"]
90+
discovered = generator._discover_middleware_classes(sample_project)
91+
92+
assert discovered["fastmcp"] == ["ValidMiddleware"]
93+
assert discovered["starlette"] == []
8994

9095

9196
class TestMiddlewareCodeGeneration:
@@ -168,9 +173,10 @@ async def on_message(self, context, call_next):
168173

169174
settings = load_settings(sample_project)
170175
generator = CodeGenerator(sample_project, settings, temp_dir)
171-
middleware_classes = generator._discover_middleware_classes(sample_project)
172-
173-
assert middleware_classes == [] # Should return empty list on error
176+
discovered = generator._discover_middleware_classes(sample_project)
177+
178+
# Should return empty dict on error
179+
assert discovered == {"fastmcp": [], "starlette": []}
174180

175181
def test_handles_import_error(self, sample_project: Path, temp_dir: Path):
176182
"""Test graceful handling of import errors."""
@@ -185,9 +191,10 @@ async def on_message(self, context, call_next):
185191

186192
settings = load_settings(sample_project)
187193
generator = CodeGenerator(sample_project, settings, temp_dir)
188-
middleware_classes = generator._discover_middleware_classes(sample_project)
189-
190-
assert middleware_classes == [] # Should return empty list on error
194+
discovered = generator._discover_middleware_classes(sample_project)
195+
196+
# Should return empty dict on error
197+
assert discovered == {"fastmcp": [], "starlette": []}
191198

192199
def test_build_succeeds_with_broken_middleware(self, sample_project: Path, temp_dir: Path):
193200
"""Test that broken middleware doesn't break the build process."""
@@ -226,9 +233,10 @@ async def on_message(self, context, call_next):
226233

227234
settings = load_settings(sample_project)
228235
generator = CodeGenerator(sample_project, settings, temp_dir)
229-
middleware_classes = generator._discover_middleware_classes(sample_project)
230-
231-
assert middleware_classes == [] # Should return empty list on error
236+
discovered = generator._discover_middleware_classes(sample_project)
237+
238+
# Should return empty dict on error
239+
assert discovered == {"fastmcp": [], "starlette": []}
232240

233241
def test_handles_empty_middleware_file(self, sample_project: Path, temp_dir: Path):
234242
"""Test handling of empty middleware.py file."""
@@ -237,9 +245,9 @@ def test_handles_empty_middleware_file(self, sample_project: Path, temp_dir: Pat
237245

238246
settings = load_settings(sample_project)
239247
generator = CodeGenerator(sample_project, settings, temp_dir)
240-
middleware_classes = generator._discover_middleware_classes(sample_project)
241-
242-
assert middleware_classes == []
248+
discovered = generator._discover_middleware_classes(sample_project)
249+
250+
assert discovered == {"fastmcp": [], "starlette": []}
243251

244252

245253
class TestMiddlewareDuckTyping:
@@ -265,12 +273,15 @@ def some_other_method(self):
265273

266274
settings = load_settings(sample_project)
267275
generator = CodeGenerator(sample_project, settings, temp_dir)
268-
middleware_classes = generator._discover_middleware_classes(sample_project)
269-
270-
assert len(middleware_classes) == 2
271-
assert "DuckTypedMiddleware" in middleware_classes
272-
assert "AlsoMiddleware" in middleware_classes
273-
assert "NotMiddleware" not in middleware_classes
276+
discovered = generator._discover_middleware_classes(sample_project)
277+
278+
# FastMCP middleware (on_call_tool method)
279+
assert discovered["fastmcp"] == ["DuckTypedMiddleware"]
280+
# Starlette HTTP middleware (dispatch method)
281+
assert discovered["starlette"] == ["AlsoMiddleware"]
282+
# NotMiddleware should not appear in either list
283+
all_middleware = discovered["fastmcp"] + discovered["starlette"]
284+
assert "NotMiddleware" not in all_middleware
274285

275286
def test_discovers_mixed_middleware_types(self, sample_project: Path, temp_dir: Path):
276287
"""Test discovery of middleware with and without base class."""
@@ -293,12 +304,14 @@ def regular_method(self):
293304

294305
settings = load_settings(sample_project)
295306
generator = CodeGenerator(sample_project, settings, temp_dir)
296-
middleware_classes = generator._discover_middleware_classes(sample_project)
297-
298-
assert len(middleware_classes) == 2
299-
assert "InheritedMiddleware" in middleware_classes
300-
assert "DuckTypedMiddleware" in middleware_classes
301-
assert "NoMiddlewareMethods" not in middleware_classes
307+
discovered = generator._discover_middleware_classes(sample_project)
308+
309+
# Both are FastMCP middleware (on_message, on_call_tool methods)
310+
assert set(discovered["fastmcp"]) == {"InheritedMiddleware", "DuckTypedMiddleware"}
311+
assert discovered["starlette"] == []
312+
# NoMiddlewareMethods should not appear
313+
all_middleware = discovered["fastmcp"] + discovered["starlette"]
314+
assert "NoMiddlewareMethods" not in all_middleware
302315

303316

304317
class TestMiddlewareRegistrationOrder:

0 commit comments

Comments
 (0)