Skip to content

Commit 341ad92

Browse files
committed
fix utc datetime import, circular dependency, and add type hints
1 parent cd744e5 commit 341ad92

File tree

5 files changed

+42
-22
lines changed

5 files changed

+42
-22
lines changed

src/mcp/server/lowlevel/experimental.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def decorator(
7070
logger.debug("Registering handler for ListTasksRequest")
7171
wrapper = create_call_wrapper(func, ListTasksRequest)
7272

73-
async def handler(req: ListTasksRequest):
73+
async def handler(req: ListTasksRequest) -> ServerResult:
7474
result = await wrapper(req)
7575
return ServerResult(result)
7676

@@ -79,17 +79,23 @@ async def handler(req: ListTasksRequest):
7979

8080
return decorator
8181

82-
def get_task(self):
82+
def get_task(
83+
self,
84+
) -> Callable[
85+
[Callable[[GetTaskRequest], Awaitable[GetTaskResult]]], Callable[[GetTaskRequest], Awaitable[GetTaskResult]]
86+
]:
8387
"""Register a handler for getting task status.
8488
8589
WARNING: This API is experimental and may change without notice.
8690
"""
8791

88-
def decorator(func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]]):
92+
def decorator(
93+
func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]],
94+
) -> Callable[[GetTaskRequest], Awaitable[GetTaskResult]]:
8995
logger.debug("Registering handler for GetTaskRequest")
9096
wrapper = create_call_wrapper(func, GetTaskRequest)
9197

92-
async def handler(req: GetTaskRequest):
98+
async def handler(req: GetTaskRequest) -> ServerResult:
9399
result = await wrapper(req)
94100
return ServerResult(result)
95101

@@ -98,17 +104,24 @@ async def handler(req: GetTaskRequest):
98104

99105
return decorator
100106

101-
def get_task_result(self):
107+
def get_task_result(
108+
self,
109+
) -> Callable[
110+
[Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]],
111+
Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]],
112+
]:
102113
"""Register a handler for getting task results/payload.
103114
104115
WARNING: This API is experimental and may change without notice.
105116
"""
106117

107-
def decorator(func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]):
118+
def decorator(
119+
func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]],
120+
) -> Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]:
108121
logger.debug("Registering handler for GetTaskPayloadRequest")
109122
wrapper = create_call_wrapper(func, GetTaskPayloadRequest)
110123

111-
async def handler(req: GetTaskPayloadRequest):
124+
async def handler(req: GetTaskPayloadRequest) -> ServerResult:
112125
result = await wrapper(req)
113126
return ServerResult(result)
114127

@@ -117,17 +130,24 @@ async def handler(req: GetTaskPayloadRequest):
117130

118131
return decorator
119132

120-
def cancel_task(self):
133+
def cancel_task(
134+
self,
135+
) -> Callable[
136+
[Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]],
137+
Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]],
138+
]:
121139
"""Register a handler for cancelling tasks.
122140
123141
WARNING: This API is experimental and may change without notice.
124142
"""
125143

126-
def decorator(func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]):
144+
def decorator(
145+
func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]],
146+
) -> Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]:
127147
logger.debug("Registering handler for CancelTaskRequest")
128148
wrapper = create_call_wrapper(func, CancelTaskRequest)
129149

130-
async def handler(req: CancelTaskRequest):
150+
async def handler(req: CancelTaskRequest) -> ServerResult:
131151
result = await wrapper(req)
132152
return ServerResult(result)
133153

src/mcp/shared/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from typing_extensions import TypeVar
55

6-
from mcp import McpError
6+
from mcp.shared.exceptions import McpError
77
from mcp.shared.session import BaseSession
88
from mcp.types import (
99
METHOD_NOT_FOUND,

src/mcp/shared/experimental/tasks/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from collections.abc import AsyncIterator, Awaitable, Callable
66
from contextlib import asynccontextmanager
7-
from datetime import UTC, datetime
7+
from datetime import datetime, timezone
88
from typing import TYPE_CHECKING
99
from uuid import uuid4
1010

@@ -57,7 +57,7 @@ def create_task_state(
5757
return Task(
5858
taskId=task_id or generate_task_id(),
5959
status="working",
60-
createdAt=datetime.now(UTC),
60+
createdAt=datetime.now(timezone.utc),
6161
ttl=metadata.ttl,
6262
pollInterval=500, # Default 500ms poll interval
6363
)

src/mcp/shared/experimental/tasks/in_memory_task_store.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"""
1010

1111
from dataclasses import dataclass, field
12-
from datetime import UTC, datetime, timedelta
12+
from datetime import datetime, timedelta, timezone
1313

1414
from mcp.shared.experimental.tasks.helpers import create_task_state, is_terminal
1515
from mcp.shared.experimental.tasks.store import TaskStore
@@ -51,13 +51,13 @@ def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None:
5151
"""Calculate expiry time from TTL in milliseconds."""
5252
if ttl_ms is None:
5353
return None
54-
return datetime.now(UTC) + timedelta(milliseconds=ttl_ms)
54+
return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms)
5555

5656
def _is_expired(self, stored: StoredTask) -> bool:
5757
"""Check if a task has expired."""
5858
if stored.expires_at is None:
5959
return False
60-
return datetime.now(UTC) >= stored.expires_at
60+
return datetime.now(timezone.utc) >= stored.expires_at
6161

6262
def _cleanup_expired(self) -> None:
6363
"""Remove all expired tasks. Called lazily during access operations."""

tests/experimental/tasks/server/test_server.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tests for server-side task support (handlers, capabilities, integration)."""
22

3-
from datetime import UTC, datetime
3+
from datetime import datetime, timezone
44
from typing import Any
55

66
import anyio
@@ -54,14 +54,14 @@ async def test_list_tasks_handler() -> None:
5454
Task(
5555
taskId="task-1",
5656
status="working",
57-
createdAt=datetime.now(UTC),
57+
createdAt=datetime.now(timezone.utc),
5858
ttl=60000,
5959
pollInterval=1000,
6060
),
6161
Task(
6262
taskId="task-2",
6363
status="completed",
64-
createdAt=datetime.now(UTC),
64+
createdAt=datetime.now(timezone.utc),
6565
ttl=60000,
6666
pollInterval=1000,
6767
),
@@ -92,7 +92,7 @@ async def handle_get_task(request: GetTaskRequest) -> GetTaskResult:
9292
return GetTaskResult(
9393
taskId=request.params.taskId,
9494
status="working",
95-
createdAt=datetime.now(UTC),
95+
createdAt=datetime.now(timezone.utc),
9696
ttl=60000,
9797
pollInterval=1000,
9898
)
@@ -140,7 +140,7 @@ async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult:
140140
return CancelTaskResult(
141141
taskId=request.params.taskId,
142142
status="cancelled",
143-
createdAt=datetime.now(UTC),
143+
createdAt=datetime.now(timezone.utc),
144144
ttl=60000,
145145
)
146146

@@ -174,7 +174,7 @@ async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult:
174174
return CancelTaskResult(
175175
taskId=request.params.taskId,
176176
status="cancelled",
177-
createdAt=datetime.now(UTC),
177+
createdAt=datetime.now(timezone.utc),
178178
ttl=None,
179179
)
180180

0 commit comments

Comments
 (0)