Skip to content

Commit 10e00e7

Browse files
committed
Typecheck
1 parent e42dbf5 commit 10e00e7

File tree

13 files changed

+55
-42
lines changed

13 files changed

+55
-42
lines changed

src/mcp/client/sse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from contextlib import asynccontextmanager
3-
from typing import Any, Union
3+
from typing import Any
44
from urllib.parse import urljoin, urlparse
55

66
import anyio
@@ -26,7 +26,7 @@ async def sse_client(
2626
headers: dict[str, Any] | None = None,
2727
timeout: float = 5,
2828
sse_read_timeout: float = 60 * 5,
29-
auth: Union[AuthSession, OAuthClient, None] = None,
29+
auth: AuthSession | OAuthClient | None = None,
3030
):
3131
"""
3232
Client transport for SSE.

src/mcp/server/auth/handlers/authorize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from dataclasses import dataclass
3-
from typing import Literal
3+
from typing import Any, Literal
44
from urllib.parse import urlencode, urlparse, urlunparse
55

66
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
@@ -70,13 +70,13 @@ def best_effort_extract_string(
7070
return None
7171

7272

73-
class AnyHttpUrlModel(RootModel):
73+
class AnyHttpUrlModel(RootModel[AnyHttpUrl]):
7474
root: AnyHttpUrl
7575

7676

7777
@dataclass
7878
class AuthorizationHandler:
79-
provider: OAuthServerProvider
79+
provider: OAuthServerProvider[Any, Any, Any]
8080

8181
async def handle(self, request: Request) -> Response:
8282
# implements authorization requests for grant_type=code;

src/mcp/server/auth/handlers/register.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import secrets
22
import time
33
from dataclasses import dataclass
4+
from typing import Any
45
from uuid import uuid4
56

67
from pydantic import BaseModel, RootModel, ValidationError
@@ -18,7 +19,7 @@
1819
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
1920

2021

21-
class RegistrationRequest(RootModel):
22+
class RegistrationRequest(RootModel[OAuthClientMetadata]):
2223
# this wrapper is a no-op; it's just to separate out the types exposed to the
2324
# provider from what we use in the HTTP handler
2425
root: OAuthClientMetadata
@@ -31,7 +32,7 @@ class RegistrationErrorResponse(BaseModel):
3132

3233
@dataclass
3334
class RegistrationHandler:
34-
provider: OAuthServerProvider
35+
provider: OAuthServerProvider[Any, Any, Any]
3536
options: ClientRegistrationOptions
3637

3738
async def handle(self, request: Request) -> Response:

src/mcp/server/auth/handlers/revoke.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass
22
from functools import partial
3-
from typing import Literal
3+
from typing import Any, Literal
44

55
from pydantic import BaseModel, ValidationError
66
from starlette.requests import Request
@@ -35,7 +35,7 @@ class RevocationErrorResponse(BaseModel):
3535

3636
@dataclass
3737
class RevocationHandler:
38-
provider: OAuthServerProvider
38+
provider: OAuthServerProvider[Any, Any, Any]
3939
client_authenticator: ClientAuthenticator
4040

4141
async def handle(self, request: Request) -> Response:

src/mcp/server/auth/handlers/token.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import hashlib
33
import time
44
from dataclasses import dataclass
5-
from typing import Annotated, Literal
5+
from typing import Annotated, Any, Literal
66

77
from pydantic import AnyHttpUrl, BaseModel, Field, RootModel, ValidationError
88
from starlette.requests import Request
@@ -44,7 +44,14 @@ class RefreshTokenRequest(BaseModel):
4444
client_secret: str | None = None
4545

4646

47-
class TokenRequest(RootModel):
47+
class TokenRequest(
48+
RootModel[
49+
Annotated[
50+
AuthorizationCodeRequest | RefreshTokenRequest,
51+
Field(discriminator="grant_type"),
52+
]
53+
]
54+
):
4855
root: Annotated[
4956
AuthorizationCodeRequest | RefreshTokenRequest,
5057
Field(discriminator="grant_type"),
@@ -61,7 +68,7 @@ class TokenErrorResponse(BaseModel):
6168
error_uri: AnyHttpUrl | None = None
6269

6370

64-
class TokenSuccessResponse(RootModel):
71+
class TokenSuccessResponse(RootModel[OAuthToken]):
6572
# this is just a wrapper over OAuthToken; the only reason we do this
6673
# is to have some separation between the HTTP response type, and the
6774
# type returned by the provider
@@ -70,7 +77,7 @@ class TokenSuccessResponse(RootModel):
7077

7178
@dataclass
7279
class TokenHandler:
73-
provider: OAuthServerProvider
80+
provider: OAuthServerProvider[Any, Any, Any]
7481
client_authenticator: ClientAuthenticator
7582

7683
def response(self, obj: TokenSuccessResponse | TokenErrorResponse | ErrorResponse):

src/mcp/server/auth/middleware/bearer_auth.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import time
2-
from typing import Any, Callable
2+
from typing import Any
33

44
from starlette.authentication import (
55
AuthCredentials,
@@ -8,7 +8,7 @@
88
)
99
from starlette.exceptions import HTTPException
1010
from starlette.requests import HTTPConnection
11-
from starlette.types import Scope
11+
from starlette.types import Receive, Scope, Send
1212

1313
from mcp.server.auth.provider import AccessToken, OAuthServerProvider
1414

@@ -29,7 +29,7 @@ class BearerAuthBackend(AuthenticationBackend):
2929

3030
def __init__(
3131
self,
32-
provider: OAuthServerProvider,
32+
provider: OAuthServerProvider[Any, Any, Any],
3333
):
3434
self.provider = provider
3535

@@ -72,7 +72,7 @@ def __init__(self, app: Any, required_scopes: list[str]):
7272
self.app = app
7373
self.required_scopes = required_scopes
7474

75-
async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None:
75+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
7676
auth_credentials = scope.get("auth")
7777

7878
for required_scope in self.required_scopes:

src/mcp/server/auth/middleware/client_auth.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import time
2+
from typing import Any
23

34
from mcp.server.auth.provider import OAuthServerProvider
45
from mcp.shared.auth import OAuthClientInformationFull
@@ -20,7 +21,7 @@ class ClientAuthenticator:
2021
logic is skipped.
2122
"""
2223

23-
def __init__(self, provider: OAuthServerProvider):
24+
def __init__(self, provider: OAuthServerProvider[Any, Any, Any]):
2425
"""
2526
Initialize the dependency.
2627

src/mcp/server/auth/routes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Callable
1+
from collections.abc import Callable
2+
from typing import Any
23

34
from pydantic import AnyHttpUrl
45
from starlette.routing import Route

src/mcp/server/fastmcp/server.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import inspect
66
import json
77
import re
8-
from collections.abc import AsyncIterator, Callable, Iterable, Sequence
8+
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
99
from contextlib import (
1010
AbstractAsyncContextManager,
1111
asynccontextmanager,
1212
)
1313
from itertools import chain
14-
from typing import Any, Awaitable, Generic, Literal
14+
from typing import Any, Generic, Literal
1515

1616
import anyio
1717
import pydantic_core
@@ -22,10 +22,10 @@
2222
from sse_starlette import EventSourceResponse
2323
from starlette.applications import Starlette
2424
from starlette.authentication import requires
25+
from starlette.middleware import Middleware
2526
from starlette.middleware.authentication import AuthenticationMiddleware
2627
from starlette.requests import Request
2728
from starlette.responses import Response
28-
from starlette.middleware import Middleware
2929
from starlette.routing import Mount, Route
3030

3131
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
@@ -491,7 +491,6 @@ def custom_route(
491491
name: str | None = None,
492492
include_in_schema: bool = True,
493493
):
494-
495494
def decorator(
496495
func: Callable[[Request], Awaitable[Response]],
497496
) -> Callable[[Request], Awaitable[Response]]:
@@ -541,7 +540,7 @@ async def handle_sse(request: Request) -> EventSourceResponse:
541540
async with sse.connect_sse(
542541
request.scope,
543542
request.receive,
544-
request._send # type: ignore[reportPrivateUsage]
543+
request._send, # type: ignore[reportPrivateUsage]
545544
) as streams:
546545
await self._mcp_server.run(
547546
streams[0],
@@ -586,7 +585,9 @@ async def handle_sse(request: Request) -> EventSourceResponse:
586585

587586
routes.append(
588587
Route(
589-
self.settings.sse_path, endpoint=requires(required_scopes)(handle_sse), methods=["GET"]
588+
self.settings.sse_path,
589+
endpoint=requires(required_scopes)(handle_sse),
590+
methods=["GET"],
590591
)
591592
)
592593
routes.append(
@@ -754,9 +755,9 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent
754755
Returns:
755756
The resource content as either text or bytes
756757
"""
757-
assert self._fastmcp is not None, (
758-
"Context is not available outside of a request"
759-
)
758+
assert (
759+
self._fastmcp is not None
760+
), "Context is not available outside of a request"
760761
return await self._fastmcp.read_resource(uri)
761762

762763
async def log(

src/mcp/server/streaming_asgi_transport.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@
99
"""
1010

1111
import typing
12-
from typing import Any, Dict, Tuple
12+
from typing import Any, cast
1313

1414
import anyio
1515
import anyio.abc
1616
import anyio.streams.memory
1717
from httpx._models import Request, Response
1818
from httpx._transports.base import AsyncBaseTransport
1919
from httpx._types import AsyncByteStream
20+
from starlette.types import ASGIApp, Receive, Scope, Send
2021

2122

2223
class StreamingASGITransport(AsyncBaseTransport):
@@ -42,11 +43,11 @@ class StreamingASGITransport(AsyncBaseTransport):
4243

4344
def __init__(
4445
self,
45-
app: typing.Callable,
46+
app: ASGIApp,
4647
task_group: anyio.abc.TaskGroup,
4748
raise_app_exceptions: bool = True,
4849
root_path: str = "",
49-
client: Tuple[str, int] = ("127.0.0.1", 123),
50+
client: tuple[str, int] = ("127.0.0.1", 123),
5051
) -> None:
5152
self.app = app
5253
self.raise_app_exceptions = raise_app_exceptions
@@ -88,13 +89,15 @@ async def handle_async_request(
8889
initial_response_ready = anyio.Event()
8990

9091
# Synchronization for streaming response
91-
asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream(100)
92+
asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[
93+
dict[str, Any]
94+
](100)
9295
content_send_channel, content_receive_channel = (
9396
anyio.create_memory_object_stream[bytes](100)
9497
)
9598

9699
# ASGI callables.
97-
async def receive() -> Dict[str, Any]:
100+
async def receive() -> dict[str, Any]:
98101
nonlocal request_complete
99102

100103
if request_complete:
@@ -108,15 +111,18 @@ async def receive() -> Dict[str, Any]:
108111
return {"type": "http.request", "body": b"", "more_body": False}
109112
return {"type": "http.request", "body": body, "more_body": True}
110113

111-
async def send(message: Dict[str, Any]) -> None:
114+
async def send(message: dict[str, Any]) -> None:
112115
nonlocal status_code, response_headers, response_started
113116

114117
await asgi_send_channel.send(message)
115118

116119
# Start the ASGI application in a separate task
117120
async def run_app() -> None:
118121
try:
119-
await self.app(scope, receive, send)
122+
# Cast the receive and send functions to the ASGI types
123+
await self.app(
124+
cast(Scope, scope), cast(Receive, receive), cast(Send, send)
125+
)
120126
except Exception:
121127
if self.raise_app_exceptions:
122128
raise

0 commit comments

Comments
 (0)