Skip to content

Commit 6832394

Browse files
refactor: consolidate header validation into single method
- Add _validate_request_headers method that combines session and protocol validation - Replace repeated calls to _validate_session and _validate_protocol_version - Improves code maintainability and extensibility for future header validations - No functional changes, all tests passing This refactoring makes it easier to add new header validations in the future by having a single entry point for all non-initialization request validations.
1 parent 370b993 commit 6832394

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

src/mcp/server/streamable_http.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -355,9 +355,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
355355
)
356356
await response(scope, receive, send)
357357
return
358-
elif not await self._validate_session(request, send):
359-
return
360-
elif not await self._validate_protocol_version(request, send):
358+
elif not await self._validate_request_headers(request, send):
361359
return
362360

363361
# For notifications and responses only, return 202 Accepted
@@ -516,9 +514,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
516514
await response(request.scope, request.receive, send)
517515
return
518516

519-
if not await self._validate_session(request, send):
520-
return
521-
if not await self._validate_protocol_version(request, send):
517+
if not await self._validate_request_headers(request, send):
522518
return
523519

524520
# Handle resumability: check for Last-Event-ID header
@@ -599,9 +595,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
599595
await response(request.scope, request.receive, send)
600596
return
601597

602-
if not await self._validate_session(request, send):
603-
return
604-
if not await self._validate_protocol_version(request, send):
598+
if not await self._validate_request_headers(request, send):
605599
return
606600

607601
await self._terminate_session()
@@ -661,6 +655,13 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non
661655
)
662656
await response(request.scope, request.receive, send)
663657

658+
async def _validate_request_headers(self, request: Request, send: Send) -> bool:
659+
if not await self._validate_session(request, send):
660+
return False
661+
if not await self._validate_protocol_version(request, send):
662+
return False
663+
return True
664+
664665
async def _validate_session(self, request: Request, send: Send) -> bool:
665666
"""Validate the session ID in the request."""
666667
if not self.mcp_session_id:

0 commit comments

Comments
 (0)