Skip to content

Commit 60f8a54

Browse files
committed
refactor: Make types.py strictly typechecked.
We are now making the types module typecheck in strict mode. This is mostly achieved by passing the correct types to the generic type variables. We ran into one specific issue with `JSONRPCRequest` and `JSONRPCNotification`. Both are generic classes that take a `dict[str, Any]` as params and just a plain string as method. However, since the TypeVar for `RequestT` and `NotificationT` are bound to `RequestParams` and `NotificationParams` respectively we get into a type issue. There are two ways of solving this: 1. Widen the bound by allowing explicitly for `dict[str, Any]` 2. Make JSONRPCRequest and JSONRPCNotificaiton not part of the type hierarchy with Request and Notification roots. It felt most naturally to keep JSONRPCRequest/JSONRPCNotification part of the type hierarchy and allow for general passing of dict[str, Any]. This now typechecks.
1 parent ae77772 commit 60f8a54

File tree

4 files changed

+66
-34
lines changed

4 files changed

+66
-34
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ include = ["src/mcp", "tests"]
7777
venvPath = "."
7878
venv = ".venv"
7979
strict = ["src/mcp/**/*.py"]
80-
exclude = ["src/mcp/types.py"]
8180

8281
[tool.ruff.lint]
8382
select = ["E", "F", "I", "UP"]

src/mcp/types.py

Lines changed: 64 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,10 @@ class Meta(BaseModel):
6464
"""
6565

6666

67-
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams)
68-
NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams)
67+
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
68+
NotificationParamsT = TypeVar(
69+
"NotificationParamsT", bound=NotificationParams | dict[str, Any] | None
70+
)
6971
MethodT = TypeVar("MethodT", bound=str)
7072

7173

@@ -113,15 +115,16 @@ class PaginatedResult(Result):
113115
"""
114116

115117

116-
class JSONRPCRequest(Request):
118+
class JSONRPCRequest(Request[dict[str, Any] | None, str]):
117119
"""A request that expects a response."""
118120

119121
jsonrpc: Literal["2.0"]
120122
id: RequestId
123+
method: str
121124
params: dict[str, Any] | None = None
122125

123126

124-
class JSONRPCNotification(Notification):
127+
class JSONRPCNotification(Notification[dict[str, Any] | None, str]):
125128
"""A notification which does not expect a response."""
126129

127130
jsonrpc: Literal["2.0"]
@@ -277,7 +280,7 @@ class InitializeRequestParams(RequestParams):
277280
model_config = ConfigDict(extra="allow")
278281

279282

280-
class InitializeRequest(Request):
283+
class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]):
281284
"""
282285
This request is sent from the client to the server when it first connects, asking it
283286
to begin initialization.
@@ -298,7 +301,9 @@ class InitializeResult(Result):
298301
"""Instructions describing how to use the server and its features."""
299302

300303

301-
class InitializedNotification(Notification):
304+
class InitializedNotification(
305+
Notification[NotificationParams | None, Literal["notifications/initialized"]]
306+
):
302307
"""
303308
This notification is sent from the client to the server after initialization has
304309
finished.
@@ -308,7 +313,7 @@ class InitializedNotification(Notification):
308313
params: NotificationParams | None = None
309314

310315

311-
class PingRequest(Request):
316+
class PingRequest(Request[RequestParams | None, Literal["ping"]]):
312317
"""
313318
A ping, issued by either the server or the client, to check that the other party is
314319
still alive.
@@ -336,7 +341,9 @@ class ProgressNotificationParams(NotificationParams):
336341
model_config = ConfigDict(extra="allow")
337342

338343

339-
class ProgressNotification(Notification):
344+
class ProgressNotification(
345+
Notification[ProgressNotificationParams, Literal["notifications/progress"]]
346+
):
340347
"""
341348
An out-of-band notification used to inform the receiver of a progress update for a
342349
long-running request.
@@ -346,7 +353,9 @@ class ProgressNotification(Notification):
346353
params: ProgressNotificationParams
347354

348355

349-
class ListResourcesRequest(PaginatedRequest):
356+
class ListResourcesRequest(
357+
PaginatedRequest[RequestParams | None, Literal["resources/list"]]
358+
):
350359
"""Sent from the client to request a list of resources the server has."""
351360

352361
method: Literal["resources/list"]
@@ -408,7 +417,9 @@ class ListResourcesResult(PaginatedResult):
408417
resources: list[Resource]
409418

410419

411-
class ListResourceTemplatesRequest(PaginatedRequest):
420+
class ListResourceTemplatesRequest(
421+
PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]
422+
):
412423
"""Sent from the client to request a list of resource templates the server has."""
413424

414425
method: Literal["resources/templates/list"]
@@ -432,7 +443,9 @@ class ReadResourceRequestParams(RequestParams):
432443
model_config = ConfigDict(extra="allow")
433444

434445

435-
class ReadResourceRequest(Request):
446+
class ReadResourceRequest(
447+
Request[ReadResourceRequestParams, Literal["resources/read"]]
448+
):
436449
"""Sent from the client to the server, to read a specific resource URI."""
437450

438451
method: Literal["resources/read"]
@@ -472,7 +485,11 @@ class ReadResourceResult(Result):
472485
contents: list[TextResourceContents | BlobResourceContents]
473486

474487

475-
class ResourceListChangedNotification(Notification):
488+
class ResourceListChangedNotification(
489+
Notification[
490+
NotificationParams | None, Literal["notifications/resources/list_changed"]
491+
]
492+
):
476493
"""
477494
An optional notification from the server to the client, informing it that the list
478495
of resources it can read from has changed.
@@ -493,7 +510,7 @@ class SubscribeRequestParams(RequestParams):
493510
model_config = ConfigDict(extra="allow")
494511

495512

496-
class SubscribeRequest(Request):
513+
class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscribe"]]):
497514
"""
498515
Sent from the client to request resources/updated notifications from the server
499516
whenever a particular resource changes.
@@ -511,7 +528,9 @@ class UnsubscribeRequestParams(RequestParams):
511528
model_config = ConfigDict(extra="allow")
512529

513530

514-
class UnsubscribeRequest(Request):
531+
class UnsubscribeRequest(
532+
Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]
533+
):
515534
"""
516535
Sent from the client to request cancellation of resources/updated notifications from
517536
the server.
@@ -532,7 +551,11 @@ class ResourceUpdatedNotificationParams(NotificationParams):
532551
model_config = ConfigDict(extra="allow")
533552

534553

535-
class ResourceUpdatedNotification(Notification):
554+
class ResourceUpdatedNotification(
555+
Notification[
556+
ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]
557+
]
558+
):
536559
"""
537560
A notification from the server to the client, informing it that a resource has
538561
changed and may need to be read again.
@@ -542,7 +565,9 @@ class ResourceUpdatedNotification(Notification):
542565
params: ResourceUpdatedNotificationParams
543566

544567

545-
class ListPromptsRequest(PaginatedRequest):
568+
class ListPromptsRequest(
569+
PaginatedRequest[RequestParams | None, Literal["prompts/list"]]
570+
):
546571
"""Sent from the client to request a list of prompts and prompt templates."""
547572

548573
method: Literal["prompts/list"]
@@ -589,7 +614,7 @@ class GetPromptRequestParams(RequestParams):
589614
model_config = ConfigDict(extra="allow")
590615

591616

592-
class GetPromptRequest(Request):
617+
class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]):
593618
"""Used by the client to get a prompt provided by the server."""
594619

595620
method: Literal["prompts/get"]
@@ -659,7 +684,11 @@ class GetPromptResult(Result):
659684
messages: list[PromptMessage]
660685

661686

662-
class PromptListChangedNotification(Notification):
687+
class PromptListChangedNotification(
688+
Notification[
689+
NotificationParams | None, Literal["notifications/prompts/list_changed"]
690+
]
691+
):
663692
"""
664693
An optional notification from the server to the client, informing it that the list
665694
of prompts it offers has changed.
@@ -669,7 +698,7 @@ class PromptListChangedNotification(Notification):
669698
params: NotificationParams | None = None
670699

671700

672-
class ListToolsRequest(PaginatedRequest):
701+
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
673702
"""Sent from the client to request a list of tools the server has."""
674703

675704
method: Literal["tools/list"]
@@ -702,7 +731,7 @@ class CallToolRequestParams(RequestParams):
702731
model_config = ConfigDict(extra="allow")
703732

704733

705-
class CallToolRequest(Request):
734+
class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
706735
"""Used by the client to invoke a tool provided by the server."""
707736

708737
method: Literal["tools/call"]
@@ -716,7 +745,9 @@ class CallToolResult(Result):
716745
isError: bool = False
717746

718747

719-
class ToolListChangedNotification(Notification):
748+
class ToolListChangedNotification(
749+
Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]
750+
):
720751
"""
721752
An optional notification from the server to the client, informing it that the list
722753
of tools it offers has changed.
@@ -739,7 +770,7 @@ class SetLevelRequestParams(RequestParams):
739770
model_config = ConfigDict(extra="allow")
740771

741772

742-
class SetLevelRequest(Request):
773+
class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]):
743774
"""A request from the client to the server, to enable or adjust logging."""
744775

745776
method: Literal["logging/setLevel"]
@@ -761,7 +792,9 @@ class LoggingMessageNotificationParams(NotificationParams):
761792
model_config = ConfigDict(extra="allow")
762793

763794

764-
class LoggingMessageNotification(Notification):
795+
class LoggingMessageNotification(
796+
Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]
797+
):
765798
"""Notification of a log message passed from server to client."""
766799

767800
method: Literal["notifications/message"]
@@ -856,7 +889,9 @@ class CreateMessageRequestParams(RequestParams):
856889
model_config = ConfigDict(extra="allow")
857890

858891

859-
class CreateMessageRequest(Request):
892+
class CreateMessageRequest(
893+
Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]
894+
):
860895
"""A request from the server to sample an LLM via the client."""
861896

862897
method: Literal["sampling/createMessage"]
@@ -913,7 +948,7 @@ class CompleteRequestParams(RequestParams):
913948
model_config = ConfigDict(extra="allow")
914949

915950

916-
class CompleteRequest(Request):
951+
class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]):
917952
"""A request from the client to the server, to ask for completion options."""
918953

919954
method: Literal["completion/complete"]
@@ -944,7 +979,7 @@ class CompleteResult(Result):
944979
completion: Completion
945980

946981

947-
class ListRootsRequest(Request):
982+
class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]):
948983
"""
949984
Sent from the server to request a list of root URIs from the client. Roots allow
950985
servers to ask for specific directories or files to operate on. A common example
@@ -987,7 +1022,9 @@ class ListRootsResult(Result):
9871022
roots: list[Root]
9881023

9891024

990-
class RootsListChangedNotification(Notification):
1025+
class RootsListChangedNotification(
1026+
Notification[NotificationParams | None, Literal["notifications/roots/list_changed"]]
1027+
):
9911028
"""
9921029
A notification from the client to the server, informing it that the list of
9931030
roots has changed.

tests/shared/test_sse.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,7 @@ def server(server_port: int) -> Generator[None, None, None]:
138138
time.sleep(0.1)
139139
attempt += 1
140140
else:
141-
raise RuntimeError(
142-
f"Server failed to start after {max_attempts} attempts"
143-
)
141+
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
144142

145143
yield
146144

tests/shared/test_ws.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,7 @@ def server(server_port: int) -> Generator[None, None, None]:
134134
time.sleep(0.1)
135135
attempt += 1
136136
else:
137-
raise RuntimeError(
138-
f"Server failed to start after {max_attempts} attempts"
139-
)
137+
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
140138

141139
yield
142140

0 commit comments

Comments
 (0)