Skip to content

Commit 7760bc7

Browse files
committed
Draft of improved H3 for hypercorn.
1 parent 3fbd5f2 commit 7760bc7

File tree

5 files changed

+133
-17
lines changed

5 files changed

+133
-17
lines changed

src/hypercorn/asyncio/task_group.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, Awaitable, Callable, Optional
77

88
from ..config import Config
9-
from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope
9+
from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope, Timer
1010

1111
try:
1212
from asyncio import TaskGroup as AsyncioTaskGroup
@@ -33,6 +33,44 @@ async def _handle(
3333
await send(None)
3434

3535

36+
LONG_SLEEP = 86400.0
37+
38+
class AsyncioTimer(Timer):
39+
def __init__(self, action: Callable) -> None:
40+
self._action = action
41+
self._done = False
42+
self._wake_up = asyncio.Condition()
43+
self._when: Optional[float] = None
44+
45+
async def schedule(self, when: Optional[float]) -> None:
46+
self._when = when
47+
async with self._wake_up:
48+
self._wake_up.notify()
49+
50+
async def stop(self) -> None:
51+
self._done = True
52+
async with self._wake_up:
53+
self._wake_up.notify()
54+
55+
async def _wait_for_wake_up(self) -> None:
56+
async with self._wake_up:
57+
await self._wake_up.wait()
58+
59+
async def run(self) -> None:
60+
while not self._done:
61+
if self._when is not None and asyncio.get_event_loop().time() >= self._when:
62+
self._when = None
63+
await self._action()
64+
if self._when is not None:
65+
timeout = max(self._when - asyncio.get_event_loop().time(), 0.0)
66+
else:
67+
timeout = LONG_SLEEP
68+
if not self._done:
69+
try:
70+
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
71+
except TimeoutError:
72+
pass
73+
3674
class TaskGroup:
3775
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
3876
self._loop = loop
@@ -66,6 +104,11 @@ def _call_soon(func: Callable, *args: Any) -> Any:
66104
def spawn(self, func: Callable, *args: Any) -> None:
67105
self._task_group.create_task(func(*args))
68106

107+
def create_timer(self, action: Callable) -> Timer:
108+
timer = AsyncioTimer(action)
109+
self._task_group.create_task(timer.run())
110+
return timer
111+
69112
async def __aenter__(self) -> "TaskGroup":
70113
await self._task_group.__aenter__()
71114
return self

src/hypercorn/protocol/quic.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .h3 import H3Protocol
2323
from ..config import Config
2424
from ..events import Closed, Event, RawData
25-
from ..typing import AppWrapper, TaskGroup, WorkerContext
25+
from ..typing import AppWrapper, TaskGroup, WorkerContext, Timer
2626

2727

2828
class QuicProtocol:
@@ -40,6 +40,7 @@ def __init__(
4040
self.context = context
4141
self.connections: Dict[bytes, QuicConnection] = {}
4242
self.http_connections: Dict[QuicConnection, H3Protocol] = {}
43+
self.timers: Dict[QuicConnection, Timer] = {}
4344
self.send = send
4445
self.server = server
4546
self.task_group = task_group
@@ -82,10 +83,12 @@ async def handle(self, event: Event) -> None:
8283
)
8384
self.connections[header.destination_cid] = connection
8485
self.connections[connection.host_cid] = connection
86+
# This partial() needs python >= 3.8
87+
self.timers[connection] = self.task_group.create_timer(partial(self._timeout, connection))
8588

8689
if connection is not None:
8790
connection.receive_datagram(event.data, event.address, now=self.context.time())
88-
await self._handle_events(connection, event.address)
91+
await self._wake_up_timer(connection)
8992
elif isinstance(event, Closed):
9093
pass
9194

@@ -99,7 +102,16 @@ async def _handle_events(
99102
event = connection.next_event()
100103
while event is not None:
101104
if isinstance(event, ConnectionTerminated):
102-
pass
105+
await self.timers[connection].stop()
106+
del self.timers[connection]
107+
# XXXRTH This is not the speediest! Better would be tracking
108+
# assigned ids in a set.
109+
prune = []
110+
for tcid, tconn in self.connections.items():
111+
if tconn == connection:
112+
prune.append(tcid)
113+
for tcid in prune:
114+
del self.connections[tcid]
103115
elif isinstance(event, ProtocolNegotiated):
104116
self.http_connections[connection] = H3Protocol(
105117
self.app,
@@ -109,7 +121,7 @@ async def _handle_events(
109121
client,
110122
self.server,
111123
connection,
112-
partial(self.send_all, connection),
124+
partial(self._wake_up_timer, connection),
113125
)
114126
elif isinstance(event, ConnectionIdIssued):
115127
self.connections[event.connection_id] = connection
@@ -121,15 +133,20 @@ async def _handle_events(
121133

122134
event = connection.next_event()
123135

136+
async def _wake_up_timer(self, connection: QuicConnection):
137+
# When new output is send, or new input is received, we
138+
# fire the timer right away so we update our state.
139+
timer = self.timers.get(connection)
140+
if timer is not None:
141+
await timer.schedule(0.0)
142+
143+
async def _timeout(self, connection: QuicConnection):
144+
now = self.context.time()
145+
when = connection.get_timer()
146+
if when is not None and now > when:
147+
connection.handle_timer(now)
148+
await self._handle_events(connection, None)
124149
await self.send_all(connection)
125-
126-
timer = connection.get_timer()
150+
timer = self.timers.get(connection)
127151
if timer is not None:
128-
self.task_group.spawn(self._handle_timer, timer, connection)
129-
130-
async def _handle_timer(self, timer: float, connection: QuicConnection) -> None:
131-
wait = max(0, timer - self.context.time())
132-
await self.context.sleep(wait)
133-
if connection._close_at is not None:
134-
connection.handle_timer(now=self.context.time())
135-
await self._handle_events(connection, None)
152+
await timer.schedule(connection.get_timer())

src/hypercorn/trio/task_group.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import trio
88

99
from ..config import Config
10-
from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope
10+
from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope, Timer
1111

1212
if sys.version_info < (3, 11):
1313
from exceptiongroup import BaseExceptionGroup
@@ -39,6 +39,40 @@ async def _handle(
3939
await send(None)
4040

4141

42+
LONG_SLEEP = 86400.0
43+
44+
class TrioTimer(Timer):
45+
def __init__(self, action: Callable) -> None:
46+
self._action = action
47+
self._done = False
48+
self._wake_up = trio.Condition()
49+
self._when: Optional[float] = None
50+
51+
async def schedule(self, when: Optional[float]) -> None:
52+
self._when = when
53+
async with self._wake_up:
54+
self._wake_up.notify()
55+
56+
async def stop(self) -> None:
57+
self._done = True
58+
async with self._wake_up:
59+
self._wake_up.notify()
60+
61+
async def run(self) -> None:
62+
while not self._done:
63+
if self._when is not None and trio.current_time() >= self._when:
64+
self._when = None
65+
await self._action()
66+
if self._when is not None:
67+
timeout = max(self._when - trio.current_time(), 0.0)
68+
else:
69+
timeout = LONG_SLEEP
70+
if not self._done:
71+
with trio.move_on_after(timeout):
72+
async with self._wake_up:
73+
await self._wake_up.wait()
74+
75+
4276
class TaskGroup:
4377
def __init__(self) -> None:
4478
self._nursery: Optional[trio._core._run.Nursery] = None
@@ -67,6 +101,11 @@ async def spawn_app(
67101
def spawn(self, func: Callable, *args: Any) -> None:
68102
self._nursery.start_soon(func, *args)
69103

104+
def create_timer(self, action: Callable) -> Timer:
105+
timer = TrioTimer(action)
106+
self._nursery.start_soon(timer.run)
107+
return timer
108+
70109
async def __aenter__(self) -> TaskGroup:
71110
self._nursery_manager = trio.open_nursery()
72111
self._nursery = await self._nursery_manager.__aenter__()

src/hypercorn/trio/worker_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Optional, Type, Union
3+
from typing import Awaitable, Optional, Type, Union
44

55
import trio
66

src/hypercorn/typing.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,20 @@ def is_set(self) -> bool:
288288
...
289289

290290

291+
class Timer:
292+
def __init__(self, action: Callable) -> None:
293+
...
294+
295+
async def schedule(self, when: float) -> None:
296+
...
297+
298+
async def stop(self) -> None:
299+
...
300+
301+
async def run(self) -> None:
302+
...
303+
304+
291305
class WorkerContext(Protocol):
292306
event_class: Type[Event]
293307
terminate: Event
@@ -318,6 +332,9 @@ async def spawn_app(
318332
def spawn(self, func: Callable, *args: Any) -> None:
319333
...
320334

335+
def create_timer(self, action: Callable) -> Timer:
336+
...
337+
321338
async def __aenter__(self) -> TaskGroup:
322339
...
323340

0 commit comments

Comments
 (0)