Skip to content

Commit 1212872

Browse files
ChrisPC-39Sebastiancrivetimihai
authored
Static typing (#725)
* Mypy admin.py and streamablehttp_transport.py Signed-off-by: Sebastian <[email protected]> * Mypy admin.py Signed-off-by: Sebastian <[email protected]> * Mypy admin.py Signed-off-by: Sebastian <[email protected]> * Cherry-pick commits Signed-off-by: Sebastian <[email protected]> * Mypy mcpgateway/transports/streamablehttp_transport.py Signed-off-by: Sebastian <[email protected]> * Mypy mcpgateway/transports/streamablehttp_transport.py part2 Signed-off-by: Sebastian <[email protected]> * Small fixes to session_registry.py and streamablehttp_transport.py Signed-off-by: Sebastian <[email protected]> * Mypy cache/session_registry.py Signed-off-by: Sebastian <[email protected]> * Small fixes in mcpgateway/cache/session_registry.py Signed-off-by: Sebastian <[email protected]> * Mypy for session_registry Signed-off-by: Sebastian <[email protected]> * Mypy streamablehttp_transport after rebase Signed-off-by: Sebastian <[email protected]> * Mypy admin.py after rebase Signed-off-by: Sebastian <[email protected]> * Run linters Signed-off-by: Sebastian <[email protected]> * Fix unit tests Signed-off-by: Sebastian <[email protected]> * Pytest fixes Signed-off-by: Sebastian <[email protected]> * Flake8 Signed-off-by: Sebastian <[email protected]> * Apply pre-commit formatting fixes Signed-off-by: Mihai Criveti <[email protected]> --------- Signed-off-by: Sebastian <[email protected]> Signed-off-by: Mihai Criveti <[email protected]> Co-authored-by: Sebastian <[email protected]> Co-authored-by: Mihai Criveti <[email protected]>
1 parent ee15ba7 commit 1212872

File tree

9 files changed

+183
-190
lines changed

9 files changed

+183
-190
lines changed

docs/docs/manage/bulk-import.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,4 +235,4 @@ The endpoint provides detailed error information for each failed tool:
235235
- Use the bulk import for initial setup and migrations
236236
- Export existing tools first to understand the schema
237237
- Test with a small subset before importing hundreds of tools
238-
- Keep your import files in version control for reproducibility
238+
- Keep your import files in version control for reproducibility

mcpgateway/admin.py

Lines changed: 144 additions & 153 deletions
Large diffs are not rendered by default.

mcpgateway/cache/session_registry.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,12 @@
6969
from mcpgateway.utils.retry_manager import ResilientHttpClient
7070

7171
# Initialize logging service first
72-
logging_service = LoggingService()
72+
logging_service: LoggingService = LoggingService()
7373
logger = logging_service.get_logger(__name__)
7474

75-
tool_service = ToolService()
76-
resource_service = ResourceService()
77-
prompt_service = PromptService()
75+
tool_service: ToolService = ToolService()
76+
resource_service: ResourceService = ResourceService()
77+
prompt_service: PromptService = PromptService()
7878

7979
try:
8080
# Third-Party
@@ -423,7 +423,7 @@ async def add_session(self, session_id: str, transport: SSETransport) -> None:
423423
# Store session in database
424424
try:
425425

426-
def _db_add():
426+
def _db_add() -> None:
427427
"""Store session record in the database.
428428
429429
Creates a new SessionRecord entry in the database for tracking
@@ -520,7 +520,7 @@ async def get_session(self, session_id: str) -> Any:
520520
elif self._backend == "database":
521521
try:
522522

523-
def _db_check():
523+
def _db_check() -> bool:
524524
"""Check if a session exists in the database.
525525
526526
Queries the SessionRecord table to determine if a session with
@@ -614,7 +614,7 @@ async def remove_session(self, session_id: str) -> None:
614614
elif self._backend == "database":
615615
try:
616616

617-
def _db_remove():
617+
def _db_remove() -> None:
618618
"""Delete session record from the database.
619619
620620
Removes the SessionRecord entry with the specified session_id
@@ -649,7 +649,7 @@ def _db_remove():
649649

650650
logger.info(f"Removed session: {session_id}")
651651

652-
async def broadcast(self, session_id: str, message: dict) -> None:
652+
async def broadcast(self, session_id: str, message: Dict[str, Any]) -> None:
653653
"""Broadcast a message to a session.
654654
655655
Sends a message to the specified session. The behavior depends on the backend:
@@ -691,7 +691,7 @@ async def broadcast(self, session_id: str, message: dict) -> None:
691691
else:
692692
msg_json = json.dumps(str(message))
693693

694-
self._session_message = {"session_id": session_id, "message": msg_json}
694+
self._session_message: Dict[str, Any] = {"session_id": session_id, "message": msg_json}
695695

696696
elif self._backend == "redis":
697697
try:
@@ -710,7 +710,7 @@ async def broadcast(self, session_id: str, message: dict) -> None:
710710
else:
711711
msg_json = json.dumps(str(message))
712712

713-
def _db_add():
713+
def _db_add() -> None:
714714
"""Store message in the database for inter-process communication.
715715
716716
Creates a new SessionMessageRecord entry containing the session_id
@@ -791,7 +791,7 @@ def get_session_sync(self, session_id: str) -> Any:
791791
async def respond(
792792
self,
793793
server_id: Optional[str],
794-
user: json,
794+
user: Dict[str, Any],
795795
session_id: str,
796796
base_url: str,
797797
) -> None:
@@ -830,7 +830,7 @@ async def respond(
830830
# if self._session_message:
831831
transport = self.get_session_sync(session_id)
832832
if transport:
833-
message = json.loads(self._session_message.get("message"))
833+
message = json.loads(str(self._session_message.get("message")))
834834
await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url)
835835

836836
elif self._backend == "redis":
@@ -857,7 +857,7 @@ async def respond(
857857

858858
elif self._backend == "database":
859859

860-
def _db_read_session(session_id):
860+
def _db_read_session(session_id: str) -> SessionRecord:
861861
"""Check if session still exists in the database.
862862
863863
Queries the SessionRecord table to verify that the session
@@ -892,7 +892,7 @@ def _db_read_session(session_id):
892892
finally:
893893
db_session.close()
894894

895-
def _db_read(session_id):
895+
def _db_read(session_id: str) -> SessionMessageRecord:
896896
"""Read pending message for a session from the database.
897897
898898
Retrieves the first (oldest) unprocessed message for the given
@@ -927,7 +927,7 @@ def _db_read(session_id):
927927
finally:
928928
db_session.close()
929929

930-
def _db_remove(session_id, message):
930+
def _db_remove(session_id: str, message: str) -> None:
931931
"""Remove processed message from the database.
932932
933933
Deletes a specific message record after it has been successfully
@@ -960,7 +960,7 @@ def _db_remove(session_id, message):
960960
finally:
961961
db_session.close()
962962

963-
async def message_check_loop(session_id):
963+
async def message_check_loop(session_id: str) -> None:
964964
"""Poll database for messages and deliver to local transport.
965965
966966
Continuously checks the database for new messages directed to
@@ -1042,7 +1042,7 @@ async def _db_cleanup_task(self) -> None:
10421042
while True:
10431043
try:
10441044
# Clean up expired sessions every 5 minutes
1045-
def _db_cleanup():
1045+
def _db_cleanup() -> int:
10461046
"""Remove expired sessions from the database.
10471047
10481048
Deletes all SessionRecord entries that haven't been accessed
@@ -1093,7 +1093,7 @@ def _db_cleanup():
10931093
continue
10941094

10951095
# Refresh session in database
1096-
def _refresh_session(session_id=session_id):
1096+
def _refresh_session(session_id: str = session_id) -> bool:
10971097
"""Update session's last accessed timestamp in the database.
10981098
10991099
Refreshes the last_accessed field for an active session to
@@ -1184,7 +1184,7 @@ async def _memory_cleanup_task(self) -> None:
11841184
await asyncio.sleep(300) # Sleep longer on error
11851185

11861186
# Handle initialize logic
1187-
async def handle_initialize_logic(self, body: dict) -> InitializeResult:
1187+
async def handle_initialize_logic(self, body: Dict[str, Any]) -> InitializeResult:
11881188
"""Process MCP protocol initialization request.
11891189
11901190
Validates the protocol version and returns server capabilities and information.
@@ -1240,14 +1240,13 @@ async def handle_initialize_logic(self, body: dict) -> InitializeResult:
12401240
resources={"subscribe": True, "listChanged": True},
12411241
tools={"listChanged": True},
12421242
logging={},
1243-
roots={"listChanged": True},
1244-
sampling={},
1243+
# roots={"listChanged": True}
12451244
),
12461245
serverInfo=Implementation(name=settings.app_name, version=__version__),
12471246
instructions=("MCP Gateway providing federated tools, resources and prompts. Use /admin interface for configuration."),
12481247
)
12491248

1250-
async def generate_response(self, message: json, transport: SSETransport, server_id: Optional[str], user: dict, base_url: str):
1249+
async def generate_response(self, message: Dict[str, Any], transport: SSETransport, server_id: Optional[str], user: Dict[str, Any], base_url: str) -> None:
12511250
"""Generate and send response for incoming MCP protocol message.
12521251
12531252
Processes MCP protocol messages and generates appropriate responses based on

mcpgateway/db.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
# Standard
2525
from datetime import datetime, timezone
26-
from typing import Any, Dict, List, Optional
26+
from typing import Any, Dict, Generator, List, Optional
2727
import uuid
2828

2929
# Third-Party
@@ -32,7 +32,7 @@
3232
from sqlalchemy.event import listen
3333
from sqlalchemy.exc import SQLAlchemyError
3434
from sqlalchemy.ext.hybrid import hybrid_property
35-
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, sessionmaker
35+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, Session, sessionmaker
3636
from sqlalchemy.orm.attributes import get_history
3737

3838
# First-Party
@@ -1276,7 +1276,7 @@ def validate_prompt_schema(mapper, connection, target):
12761276
listen(Prompt, "before_update", validate_prompt_schema)
12771277

12781278

1279-
def get_db():
1279+
def get_db() -> Generator[Session, Any, None]:
12801280
"""
12811281
Dependency to get database session.
12821282

mcpgateway/services/logging_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class LoggingService:
9494
- Logger name tracking
9595
"""
9696

97-
def __init__(self):
97+
def __init__(self) -> None:
9898
"""Initialize logging service."""
9999
self._level = LogLevel.INFO
100100
self._subscribers: List[asyncio.Queue] = []

mcpgateway/services/resource_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class ResourceService:
109109
- Active/inactive status management
110110
"""
111111

112-
def __init__(self):
112+
def __init__(self) -> None:
113113
"""Initialize the resource service."""
114114
self._event_subscribers: Dict[str, List[asyncio.Queue]] = {}
115115
self._template_cache: Dict[str, ResourceTemplate] = {}

mcpgateway/services/root_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class RootService:
3939
- Root permissions and access control
4040
"""
4141

42-
def __init__(self):
42+
def __init__(self) -> None:
4343
"""Initialize root service."""
4444
self._roots: Dict[str, Root] = {}
4545
self._subscribers: List[asyncio.Queue] = []

mcpgateway/services/tool_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ class ToolService:
151151
- Active/inactive tool management.
152152
"""
153153

154-
def __init__(self):
154+
def __init__(self) -> None:
155155
"""Initialize the tool service.
156156
157157
Examples:

mcpgateway/transports/streamablehttp_transport.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import contextvars
3636
from dataclasses import dataclass
3737
import re
38-
from typing import List, Union
38+
from typing import Any, AsyncGenerator, List, Union
3939
from uuid import uuid4
4040

4141
# Third-Party
@@ -45,6 +45,7 @@
4545
from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId
4646
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
4747
from mcp.types import JSONRPCMessage
48+
from sqlalchemy.orm import Session
4849
from starlette.datastructures import Headers
4950
from starlette.responses import JSONResponse
5051
from starlette.status import HTTP_401_UNAUTHORIZED
@@ -62,11 +63,11 @@
6263
logger = logging_service.get_logger(__name__)
6364

6465
# Initialize ToolService and MCP Server
65-
tool_service = ToolService()
66-
mcp_app = Server("mcp-streamable-http-stateless")
66+
tool_service: ToolService = ToolService()
67+
mcp_app: Server[Any] = Server("mcp-streamable-http-stateless")
6768

68-
server_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("server_id", default=None)
69-
request_headers_var = contextvars.ContextVar("request_headers", default={})
69+
server_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("server_id", default="default_server_id")
70+
request_headers_var: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar("request_headers", default={})
7071

7172
# ------------------------------ Event store ------------------------------
7273

@@ -305,7 +306,7 @@ async def replay_events_after(
305306

306307

307308
@asynccontextmanager
308-
async def get_db():
309+
async def get_db() -> AsyncGenerator[Session, Any]:
309310
"""
310311
Asynchronous context manager for database sessions.
311312
@@ -536,7 +537,7 @@ async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Sen
536537
# ------------------------- Authentication for /mcp routes ------------------------------
537538

538539

539-
async def streamable_http_auth(scope, receive, send):
540+
async def streamable_http_auth(scope: Any, receive: Any, send: Any) -> bool:
540541
"""
541542
Perform authentication check in middleware context (ASGI scope).
542543
@@ -584,6 +585,8 @@ async def streamable_http_auth(scope, receive, send):
584585
if scheme.lower() == "bearer" and credentials:
585586
token = credentials
586587
try:
588+
if token is None:
589+
raise Exception()
587590
await verify_credentials(token)
588591
except Exception:
589592
response = JSONResponse(

0 commit comments

Comments
 (0)