Skip to content

Commit d61b050

Browse files
authored
Merge pull request #2477 from harahu/run-types
Add some low-effort type annotations
2 parents af3d7d8 + fc4ed29 commit d61b050

File tree

1 file changed

+86
-57
lines changed

1 file changed

+86
-57
lines changed

trio/_core/_run.py

Lines changed: 86 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,61 @@
1+
from __future__ import annotations
2+
3+
import enum
14
import functools
5+
import gc
26
import itertools
37
import random
48
import select
59
import sys
610
import threading
7-
import gc
11+
import warnings
812
from collections import deque
13+
from collections.abc import Callable
914
from contextlib import contextmanager
10-
import warnings
11-
import enum
12-
1315
from contextvars import copy_context
16+
from heapq import heapify, heappop, heappush
1417
from math import inf
1518
from time import perf_counter
16-
from typing import Callable, TYPE_CHECKING
17-
18-
from sniffio import current_async_library_cvar
19+
from typing import TYPE_CHECKING, Any, NoReturn, TypeVar
1920

2021
import attr
21-
from heapq import heapify, heappop, heappush
22-
from sortedcontainers import SortedDict
2322
from outcome import Error, Outcome, Value, capture
23+
from sniffio import current_async_library_cvar
24+
from sortedcontainers import SortedDict
2425

26+
# An unfortunate name collision here with trio._util.Final
27+
from typing_extensions import Final as FinalT
28+
29+
from .. import _core
30+
from .._util import Final, NoPublicConstructor, coroutine_or_error
31+
from ._asyncgens import AsyncGenerators
2532
from ._entry_queue import EntryQueue, TrioToken
26-
from ._exceptions import TrioInternalError, RunFinishedError, Cancelled
27-
from ._ki import (
28-
LOCALS_KEY_KI_PROTECTION_ENABLED,
29-
KIManager,
30-
enable_ki_protection,
31-
)
33+
from ._exceptions import Cancelled, RunFinishedError, TrioInternalError
34+
from ._instrumentation import Instruments
35+
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED, KIManager, enable_ki_protection
3236
from ._multierror import MultiError, concat_tb
37+
from ._thread_cache import start_thread_soon
3338
from ._traps import (
3439
Abort,
35-
wait_task_rescheduled,
36-
cancel_shielded_checkpoint,
3740
CancelShieldedCheckpoint,
3841
PermanentlyDetachCoroutineObject,
3942
WaitTaskRescheduled,
43+
cancel_shielded_checkpoint,
44+
wait_task_rescheduled,
4045
)
41-
from ._asyncgens import AsyncGenerators
42-
from ._thread_cache import start_thread_soon
43-
from ._instrumentation import Instruments
44-
from .. import _core
45-
from .._util import Final, NoPublicConstructor, coroutine_or_error
4646

4747
if sys.version_info < (3, 11):
4848
from exceptiongroup import BaseExceptionGroup
4949

50-
DEADLINE_HEAP_MIN_PRUNE_THRESHOLD = 1000
50+
DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000
5151

52-
_NO_SEND = object()
52+
_NO_SEND: FinalT = object()
5353

54+
FnT = TypeVar("FnT", bound="Callable[..., Any]")
5455

5556
# Decorator to mark methods public. This does nothing by itself, but
5657
# trio/_tools/gen_exports.py looks for it.
57-
def _public(fn):
58+
def _public(fn: FnT) -> FnT:
5859
return fn
5960

6061

@@ -63,50 +64,71 @@ def _public(fn):
6364
# variable to True, and registers the Random instance _r for Hypothesis
6465
# to manage for each test case, which together should make Trio's task
6566
# scheduling loop deterministic. We have a test for that, of course.
66-
_ALLOW_DETERMINISTIC_SCHEDULING = False
67+
_ALLOW_DETERMINISTIC_SCHEDULING: FinalT = False
6768
_r = random.Random()
6869

6970

70-
# On CPython, Context.run() is implemented in C and doesn't show up in
71-
# tracebacks. On PyPy, it is implemented in Python and adds 1 frame to tracebacks.
72-
def _count_context_run_tb_frames():
73-
def function_with_unique_name_xyzzy():
74-
1 / 0
71+
def _count_context_run_tb_frames() -> int:
72+
"""Count implementation dependent traceback frames from Context.run()
73+
74+
On CPython, Context.run() is implemented in C and doesn't show up in
75+
tracebacks. On PyPy, it is implemented in Python and adds 1 frame to
76+
tracebacks.
77+
78+
Returns:
79+
int: Traceback frame count
80+
81+
"""
82+
83+
def function_with_unique_name_xyzzy() -> NoReturn:
84+
try:
85+
1 / 0
86+
except ZeroDivisionError:
87+
raise
88+
else: # pragma: no cover
89+
raise TrioInternalError(
90+
"A ZeroDivisionError should have been raised, but it wasn't."
91+
)
7592

7693
ctx = copy_context()
7794
try:
7895
ctx.run(function_with_unique_name_xyzzy)
7996
except ZeroDivisionError as exc:
8097
tb = exc.__traceback__
8198
# Skip the frame where we caught it
82-
tb = tb.tb_next
99+
tb = tb.tb_next # type: ignore[union-attr]
83100
count = 0
84-
while tb.tb_frame.f_code.co_name != "function_with_unique_name_xyzzy":
85-
tb = tb.tb_next
101+
while tb.tb_frame.f_code.co_name != "function_with_unique_name_xyzzy": # type: ignore[union-attr]
102+
tb = tb.tb_next # type: ignore[union-attr]
86103
count += 1
87104
return count
105+
else: # pragma: no cover
106+
raise TrioInternalError(
107+
f"The purpose of {function_with_unique_name_xyzzy.__name__} is "
108+
"to raise a ZeroDivisionError, but it didn't."
109+
)
88110

89111

90-
CONTEXT_RUN_TB_FRAMES = _count_context_run_tb_frames()
112+
CONTEXT_RUN_TB_FRAMES: FinalT = _count_context_run_tb_frames()
91113

92114

93115
@attr.s(frozen=True, slots=True)
94116
class SystemClock:
95117
# Add a large random offset to our clock to ensure that if people
96118
# accidentally call time.perf_counter() directly or start comparing clocks
97119
# between different runs, then they'll notice the bug quickly:
98-
offset = attr.ib(factory=lambda: _r.uniform(10000, 200000))
120+
offset: float = attr.ib(factory=lambda: _r.uniform(10000, 200000))
99121

100-
def start_clock(self):
122+
def start_clock(self) -> None:
101123
pass
102124

103125
# In cPython 3, on every platform except Windows, perf_counter is
104126
# exactly the same as time.monotonic; and on Windows, it uses
105127
# QueryPerformanceCounter instead of GetTickCount64.
106-
def current_time(self):
128+
def current_time(self) -> float:
107129
return self.offset + perf_counter()
108130

109-
def deadline_to_sleep_time(self, deadline):
131+
def deadline_to_sleep_time(self, deadline: float) -> float:
110132
return deadline - self.current_time()
111133

112134

@@ -1119,7 +1141,7 @@ class Task(metaclass=NoPublicConstructor):
11191141
name = attr.ib()
11201142
# PEP 567 contextvars context
11211143
context = attr.ib()
1122-
_counter = attr.ib(init=False, factory=itertools.count().__next__)
1144+
_counter: int = attr.ib(init=False, factory=itertools.count().__next__)
11231145

11241146
# Invariant:
11251147
# - for unscheduled tasks, _next_send_fn and _next_send are both None
@@ -1293,7 +1315,7 @@ class RunContext(threading.local):
12931315
task: Task
12941316

12951317

1296-
GLOBAL_RUN_CONTEXT = RunContext()
1318+
GLOBAL_RUN_CONTEXT: FinalT = RunContext()
12971319

12981320

12991321
@attr.s(frozen=True)
@@ -1380,7 +1402,7 @@ class Runner:
13801402
# Run-local values, see _local.py
13811403
_locals = attr.ib(factory=dict)
13821404

1383-
runq = attr.ib(factory=deque)
1405+
runq: deque[Task] = attr.ib(factory=deque)
13841406
tasks = attr.ib(factory=set)
13851407

13861408
deadlines = attr.ib(factory=Deadlines)
@@ -1957,8 +1979,8 @@ def run(
19571979
*args,
19581980
clock=None,
19591981
instruments=(),
1960-
restrict_keyboard_interrupt_to_checkpoints=False,
1961-
strict_exception_groups=False,
1982+
restrict_keyboard_interrupt_to_checkpoints: bool = False,
1983+
strict_exception_groups: bool = False,
19621984
):
19631985
"""Run a Trio-flavored async function, and return the result.
19641986
@@ -2063,11 +2085,11 @@ def start_guest_run(
20632085
run_sync_soon_threadsafe,
20642086
done_callback,
20652087
run_sync_soon_not_threadsafe=None,
2066-
host_uses_signal_set_wakeup_fd=False,
2088+
host_uses_signal_set_wakeup_fd: bool = False,
20672089
clock=None,
20682090
instruments=(),
2069-
restrict_keyboard_interrupt_to_checkpoints=False,
2070-
strict_exception_groups=False,
2091+
restrict_keyboard_interrupt_to_checkpoints: bool = False,
2092+
strict_exception_groups: bool = False,
20712093
):
20722094
"""Start a "guest" run of Trio on top of some other "host" event loop.
20732095
@@ -2147,14 +2169,19 @@ def my_done_callback(run_outcome):
21472169

21482170
# 24 hours is arbitrary, but it avoids issues like people setting timeouts of
21492171
# 10**20 and then getting integer overflows in the underlying system calls.
2150-
_MAX_TIMEOUT = 24 * 60 * 60
2172+
_MAX_TIMEOUT: FinalT = 24 * 60 * 60
21512173

21522174

21532175
# Weird quirk: this is written as a generator in order to support "guest
21542176
# mode", where our core event loop gets unrolled into a series of callbacks on
21552177
# the host loop. If you're doing a regular trio.run then this gets run
21562178
# straight through.
2157-
def unrolled_run(runner, async_fn, args, host_uses_signal_set_wakeup_fd=False):
2179+
def unrolled_run(
2180+
runner: Runner,
2181+
async_fn,
2182+
args,
2183+
host_uses_signal_set_wakeup_fd: bool = False,
2184+
):
21582185
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
21592186
__tracebackhide__ = True
21602187

@@ -2173,7 +2200,7 @@ def unrolled_run(runner, async_fn, args, host_uses_signal_set_wakeup_fd=False):
21732200
# here is our event loop:
21742201
while runner.tasks:
21752202
if runner.runq:
2176-
timeout = 0
2203+
timeout: float = 0
21772204
else:
21782205
deadline = runner.deadlines.next_deadline()
21792206
timeout = runner.clock.deadline_to_sleep_time(deadline)
@@ -2301,8 +2328,10 @@ def unrolled_run(runner, async_fn, args, host_uses_signal_set_wakeup_fd=False):
23012328
# frame we always remove, because it's this function
23022329
# catching it, and then in addition we remove however many
23032330
# more Context.run adds.
2304-
tb = task_exc.__traceback__.tb_next
2305-
for _ in range(CONTEXT_RUN_TB_FRAMES):
2331+
tb = task_exc.__traceback__
2332+
for _ in range(1 + CONTEXT_RUN_TB_FRAMES):
2333+
if tb is None:
2334+
break
23062335
tb = tb.tb_next
23072336
final_outcome = Error(task_exc.with_traceback(tb))
23082337
# Remove local refs so that e.g. cancelled coroutine locals
@@ -2397,7 +2426,7 @@ def started(self, value=None):
23972426
pass
23982427

23992428

2400-
TASK_STATUS_IGNORED = _TaskStatusIgnored()
2429+
TASK_STATUS_IGNORED: FinalT = _TaskStatusIgnored()
24012430

24022431

24032432
def current_task():
@@ -2493,16 +2522,16 @@ async def checkpoint_if_cancelled():
24932522

24942523

24952524
if sys.platform == "win32":
2496-
from ._io_windows import WindowsIOManager as TheIOManager
24972525
from ._generated_io_windows import *
2526+
from ._io_windows import WindowsIOManager as TheIOManager
24982527
elif sys.platform == "linux" or (not TYPE_CHECKING and hasattr(select, "epoll")):
2499-
from ._io_epoll import EpollIOManager as TheIOManager
25002528
from ._generated_io_epoll import *
2529+
from ._io_epoll import EpollIOManager as TheIOManager
25012530
elif TYPE_CHECKING or hasattr(select, "kqueue"):
2502-
from ._io_kqueue import KqueueIOManager as TheIOManager
25032531
from ._generated_io_kqueue import *
2532+
from ._io_kqueue import KqueueIOManager as TheIOManager
25042533
else: # pragma: no cover
25052534
raise NotImplementedError("unsupported platform")
25062535

2507-
from ._generated_run import *
25082536
from ._generated_instrumentation import *
2537+
from ._generated_run import *

0 commit comments

Comments
 (0)