Skip to content

Commit 4773700

Browse files
authored
feat: Add method to extract session ID from headers with hyphen and underscore support (#601)
1 parent f44349b commit 4773700

File tree

4 files changed

+174
-50
lines changed

4 files changed

+174
-50
lines changed

templates/python-mcp-empty/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33

44
apify < 4.0.0
55
apify-client < 3.0.0
6-
fastmcp>=0.2.0
7-
mcp>=1.16.0
6+
fastmcp>=2.14.0
7+
mcp>=1.25.0
Lines changed: 149 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,34 @@
11
"""Main entry point for the MCP Server Actor."""
22

3+
import asyncio
34
import os
5+
import time
6+
from collections.abc import Mapping, MutableMapping
7+
from typing import Any
48

9+
import uvicorn
510
from apify import Actor
611
from fastmcp import FastMCP
7-
8-
# Initialize the Apify Actor environment
9-
# This call configures the Actor for its environment and should be called at startup
12+
from starlette.requests import Request
13+
from starlette.types import Receive, Scope, Send
1014

1115

1216
def get_server() -> FastMCP:
13-
"""Create an MCP server with implementation details."""
17+
"""Create an MCP server with tools and resources."""
1418
server = FastMCP('python-mcp-empty', '1.0.0')
1519

1620
@server.tool() # type: ignore[misc]
1721
def add(a: float, b: float) -> dict:
18-
"""Add two numbers together and return the sum with structured output.
19-
20-
Args:
21-
a: First number to add
22-
b: Second number to add
23-
24-
Returns:
25-
Dictionary with the sum result and structured output
26-
"""
27-
# Note: We can't await here in sync context, so charging happens in async wrapper
28-
sum_result = a + b
29-
structured_content = {
30-
'result': sum_result,
31-
'operands': {'a': a, 'b': b},
32-
'operation': 'addition',
33-
}
34-
22+
"""Add two numbers together and return the sum."""
23+
result = a + b
3524
return {
3625
'type': 'text',
37-
'text': f'The sum of {a} and {b} is {sum_result}',
38-
'structuredContent': structured_content,
26+
'text': f'The sum of {a} and {b} is {result}',
27+
'structuredContent': {
28+
'result': result,
29+
'operands': {'a': a, 'b': b},
30+
'operation': 'addition',
31+
},
3932
}
4033

4134
@server.resource(uri='https://example.com/calculator', name='calculator-info') # type: ignore[misc]
@@ -46,39 +39,150 @@ def calculator_info() -> str:
4639
return server
4740

4841

42+
def get_session_id(headers: Mapping[str, str]) -> str | None:
43+
"""Extract session ID from request headers."""
44+
for key in ('mcp-session-id', 'mcp_session_id'):
45+
if value := headers.get(key):
46+
return value
47+
return None
48+
49+
50+
class SessionTrackingMiddleware:
51+
"""ASGI middleware that tracks MCP sessions and closes idle ones."""
52+
53+
def __init__(self, app: Any, port: int, timeout_secs: int) -> None:
54+
self.app = app
55+
self.port = port
56+
self.timeout_secs = timeout_secs
57+
58+
# Session tracking state
59+
self._last_activity: dict[str, float] = {}
60+
self._timers: dict[str, asyncio.Task[None]] = {}
61+
62+
def _session_cleanup(self, sid: str) -> None:
63+
self._last_activity.pop(sid, None)
64+
if (timer := self._timers.pop(sid, None)) and not timer.done():
65+
timer.cancel()
66+
67+
def _touch(self, sid: str) -> None:
68+
self._last_activity[sid] = time.time()
69+
70+
# Cancel existing timer
71+
if (timer := self._timers.get(sid)) and not timer.done():
72+
timer.cancel()
73+
74+
async def close_if_idle() -> None:
75+
try:
76+
await asyncio.sleep(self.timeout_secs)
77+
78+
# Check if activity occurred during sleep
79+
elapsed = time.time() - self._last_activity.get(sid, 0)
80+
if elapsed < self.timeout_secs * 0.9:
81+
return
82+
83+
Actor.log.info(f'Closing idle session: {sid}')
84+
85+
# Send internal DELETE request to close session
86+
scope: Scope = {
87+
'type': 'http',
88+
'http_version': '1.1',
89+
'method': 'DELETE',
90+
'scheme': 'http',
91+
'path': '/mcp',
92+
'raw_path': b'/mcp',
93+
'query_string': b'',
94+
'headers': [(b'mcp-session-id', sid.encode())],
95+
'server': ('127.0.0.1', self.port),
96+
'client': ('127.0.0.1', 0),
97+
'_idle_close': True,
98+
}
99+
100+
async def noop_receive() -> MutableMapping[str, Any]:
101+
return {'type': 'http.request', 'body': b'', 'more_body': False}
102+
103+
async def noop_send(_: MutableMapping[str, Any]) -> None:
104+
pass
105+
106+
# Re-enter middleware with an internal DELETE; _idle_close will skip tracking
107+
await self(scope, noop_receive, noop_send)
108+
self._session_cleanup(sid)
109+
110+
except asyncio.CancelledError:
111+
pass
112+
except Exception as e:
113+
Actor.log.exception(f'Failed to close idle session {sid}: {e}')
114+
115+
self._timers[sid] = asyncio.create_task(close_if_idle())
116+
117+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
118+
"""ASGI entry point that wraps the underlying app."""
119+
# Pass through non-MCP requests
120+
path = scope.get('path', '')
121+
if scope.get('type') != 'http' or path not in ('/mcp', '/mcp/'):
122+
await self.app(scope, receive, send)
123+
return
124+
125+
# Skip tracking for internal idle-close requests
126+
if scope.get('_idle_close'):
127+
await self.app(scope, receive, send)
128+
return
129+
130+
request = Request(scope, receive)
131+
sid = get_session_id(request.headers)
132+
is_delete = scope.get('method') == 'DELETE'
133+
134+
# Track activity for existing sessions (skip DELETE)
135+
if sid and not is_delete:
136+
self._touch(sid)
137+
138+
# Capture new session ID from response headers
139+
new_sid: str | None = None
140+
141+
async def capture_send(msg: MutableMapping[str, Any]) -> None:
142+
nonlocal new_sid
143+
if msg.get('type') == 'http.response.start':
144+
for k, v in msg.get('headers', []):
145+
if k.decode().lower() == 'mcp-session-id':
146+
new_sid = v.decode()
147+
break
148+
await send(msg)
149+
150+
await self.app(scope, receive, capture_send)
151+
152+
# Track a newly created session
153+
if not sid and new_sid:
154+
Actor.log.info(f'New session: {new_sid}')
155+
self._touch(new_sid)
156+
157+
# Cleanup on explicit DELETE
158+
if is_delete and sid:
159+
Actor.log.info(f'Session closed: {sid}')
160+
self._session_cleanup(sid)
161+
162+
49163
async def main() -> None:
50-
"""Run the MCP Server Actor.
51-
52-
This function:
53-
1. Initializes the Actor
54-
2. Creates and configures the MCP server
55-
3. Starts the HTTP server with Streamable HTTP transport
56-
4. Handles MCP requests
57-
"""
164+
"""Run the MCP Server Actor with session timeout handling."""
58165
await Actor.init()
59166

60-
# Get port from environment or default to 3000
61167
port = int(os.environ.get('APIFY_CONTAINER_PORT', '3000'))
168+
timeout_secs = int(os.environ.get('SESSION_TIMEOUT_SECS', '300'))
62169

63170
server = get_server()
171+
app = server.http_app(transport='streamable-http')
172+
173+
# Wrap the app with session tracking middleware to handle idle timeouts
174+
app = SessionTrackingMiddleware(app=app, port=port, timeout_secs=timeout_secs)
64175

65176
try:
66-
Actor.log.info('Starting MCP server with FastMCP')
67-
68-
# Start the FastMCP server with HTTP transport
69-
# This starts the server on the specified port and handles MCP protocol messages
70-
await server.run_http_async(
71-
host='0.0.0.0', # noqa: S104 - Required for container networking
72-
port=port,
73-
)
177+
Actor.log.info(f'Starting MCP server on port {port} (session timeout: {timeout_secs}s)')
178+
config = uvicorn.Config(app, host='0.0.0.0', port=port, log_level='info') # noqa: S104
179+
await uvicorn.Server(config).serve()
74180
except KeyboardInterrupt:
75-
Actor.log.info('Shutting down server...')
76-
except Exception as error:
77-
Actor.log.error(f'Server failed to start: {error}')
181+
Actor.log.info('Shutting down...')
182+
except Exception as e:
183+
Actor.log.error(f'Server failed: {e}')
78184
raise
79185

80186

81187
if __name__ == '__main__':
82-
import asyncio
83-
84188
asyncio.run(main())

templates/python-mcp-proxy/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ apify-client < 3.0.0
66
arxiv-mcp-server==0.3.1
77
fastapi==0.119.1
88
httpx>=0.24.0
9-
mcp==1.23.0
9+
mcp>=1.25.0
1010
pydantic>=2.0.0
1111
uv>=0.7.8
1212
uvicorn>=0.27.0

templates/python-mcp-proxy/src/server.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,26 @@ async def capturing_send(message: dict[str, Any]) -> None:
263263

264264
return capturing_send
265265

266+
@staticmethod
267+
def _get_session_id_from_headers(headers: Any) -> str | None:
268+
"""Extract session ID from headers, trying both hyphen and underscore variants.
269+
270+
HTTP headers are case-insensitive per spec, and Starlette's Request.headers
271+
handles this automatically. We only need to check for hyphen vs underscore variants.
272+
273+
Args:
274+
headers: Either a Starlette Request.headers object or a dict
275+
276+
Returns:
277+
Session ID string if found, None otherwise
278+
"""
279+
# Try both hyphen and underscore variants
280+
# Case is handled automatically by Starlette's case-insensitive headers
281+
for key in ('mcp-session-id', 'mcp_session_id'):
282+
if value := headers.get(key):
283+
return value # type: ignore[no-any-return]
284+
return None
285+
266286
async def create_starlette_app(self, mcp_server: Server) -> Starlette:
267287
"""Create a Starlette app that exposes /mcp endpoint for Streamable HTTP transport."""
268288
event_store = InMemoryEventStore()
@@ -341,7 +361,7 @@ async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) ->
341361

342362
if scope['method'] == 'DELETE':
343363
await session_manager.handle_request(scope, receive, send)
344-
if req_sid := request.headers.get('mcp-session-id'):
364+
if req_sid := self._get_session_id_from_headers(request.headers):
345365
self._cleanup_session_last_activity(req_sid)
346366
self._cleanup_session_timer(req_sid)
347367
return
@@ -352,7 +372,7 @@ async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) ->
352372
capturing_send = self._create_capturing_send(send, session_id_from_resp)
353373

354374
# Log and touch existing session if present on request
355-
if req_sid := request.headers.get('mcp-session-id'):
375+
if req_sid := self._get_session_id_from_headers(request.headers):
356376
self._touch_session(req_sid, session_manager)
357377

358378
await session_manager.handle_request(scope, receive, capturing_send) # type: ignore[arg-type]

0 commit comments

Comments
 (0)