Skip to content

Commit 5ba9bd8

Browse files
committed
if the current_task().coro.cr_frame is in the stack ki_protection_enabled is current_task()._ki_protected
1 parent 408d1ae commit 5ba9bd8

File tree

3 files changed

+32
-35
lines changed

3 files changed

+32
-35
lines changed

src/trio/_core/_ki.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import attrs
1010

1111
from .._util import is_main_thread
12+
from ._run_context import GLOBAL_RUN_CONTEXT
1213

1314
if TYPE_CHECKING:
1415
import types
@@ -170,6 +171,16 @@ def legacy_isasyncgenfunction(
170171
# NB: according to the signal.signal docs, 'frame' can be None on entry to
171172
# this function:
172173
def ki_protection_enabled(frame: types.FrameType | None) -> bool:
174+
try:
175+
task = GLOBAL_RUN_CONTEXT.task
176+
except AttributeError:
177+
task_ki_protected = False
178+
task_frame = None
179+
else:
180+
task_ki_protected = task._ki_protected
181+
task_frame = task.coro.cr_frame
182+
del task
183+
173184
while frame is not None:
174185
try:
175186
v = _CODE_KI_PROTECTION_STATUS_WMAP[frame.f_code]
@@ -179,6 +190,8 @@ def ki_protection_enabled(frame: types.FrameType | None) -> bool:
179190
return bool(v)
180191
if frame.f_code.co_name == "__del__":
181192
return True
193+
if frame is task_frame:
194+
return task_ki_protected
182195
frame = frame.f_back
183196
return True
184197

src/trio/_core/_run.py

Lines changed: 4 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import random
88
import select
99
import sys
10-
import threading
1110
import warnings
1211
from collections import deque
1312
from contextlib import AbstractAsyncContextManager, contextmanager, suppress
@@ -39,8 +38,9 @@
3938
from ._entry_queue import EntryQueue, TrioToken
4039
from ._exceptions import Cancelled, RunFinishedError, TrioInternalError
4140
from ._instrumentation import Instruments
42-
from ._ki import KIManager, disable_ki_protection, enable_ki_protection
41+
from ._ki import KIManager, enable_ki_protection
4342
from ._parking_lot import GLOBAL_PARKING_LOT_BREAKER
43+
from ._run_context import GLOBAL_RUN_CONTEXT as GLOBAL_RUN_CONTEXT
4444
from ._thread_cache import start_thread_soon
4545
from ._traps import (
4646
Abort,
@@ -83,7 +83,6 @@
8383
StatusT_contra = TypeVar("StatusT_contra", contravariant=True)
8484

8585
FnT = TypeVar("FnT", bound="Callable[..., Any]")
86-
T = TypeVar("T")
8786
RetT = TypeVar("RetT")
8887

8988

@@ -1559,14 +1558,6 @@ def raise_cancel() -> NoReturn:
15591558
################################################################
15601559

15611560

1562-
class RunContext(threading.local):
1563-
runner: Runner
1564-
task: Task
1565-
1566-
1567-
GLOBAL_RUN_CONTEXT: Final = RunContext()
1568-
1569-
15701561
@attrs.frozen
15711562
class RunStatistics:
15721563
"""An object containing run-loop-level debugging information.
@@ -1670,22 +1661,6 @@ def in_main_thread() -> None:
16701661
start_thread_soon(get_events, deliver)
16711662

16721663

1673-
@enable_ki_protection
1674-
def run_with_ki_protection_enabled(f: Callable[[T], RetT], v: T) -> RetT:
1675-
try:
1676-
return f(v)
1677-
finally:
1678-
del v # for the case where f is coro.throw() and v is a (Base)Exception
1679-
1680-
1681-
@disable_ki_protection
1682-
def run_with_ki_protection_disabled(f: Callable[[T], RetT], v: T) -> RetT:
1683-
try:
1684-
return f(v)
1685-
finally:
1686-
del v # for the case where f is coro.throw() and v is a (Base)Exception
1687-
1688-
16891664
@attrs.define(eq=False)
16901665
class Runner:
16911666
clock: Clock
@@ -2730,11 +2705,6 @@ def unrolled_run(
27302705

27312706
next_send_fn = task._next_send_fn
27322707
next_send = task._next_send
2733-
run_with = (
2734-
run_with_ki_protection_enabled
2735-
if task._ki_protected
2736-
else run_with_ki_protection_disabled
2737-
)
27382708
task._next_send_fn = task._next_send = None
27392709
final_outcome: Outcome[Any] | None = None
27402710
try:
@@ -2747,17 +2717,16 @@ def unrolled_run(
27472717
# https://github.com/python/cpython/issues/108668
27482718
# So now we send in the Outcome object and unwrap it on the
27492719
# other side.
2750-
msg = task.context.run(run_with, next_send_fn, next_send)
2720+
msg = task.context.run(next_send_fn, next_send)
27512721
except StopIteration as stop_iteration:
27522722
final_outcome = Value(stop_iteration.value)
27532723
except BaseException as task_exc:
27542724
# Store for later, removing uninteresting top frames: 1
27552725
# frame we always remove, because it's this function
2756-
# another is the run_with
27572726
# catching it, and then in addition we remove however many
27582727
# more Context.run adds.
27592728
tb = task_exc.__traceback__
2760-
for _ in range(2 + CONTEXT_RUN_TB_FRAMES):
2729+
for _ in range(1 + CONTEXT_RUN_TB_FRAMES):
27612730
if tb is not None: # pragma: no branch
27622731
tb = tb.tb_next
27632732
final_outcome = Error(task_exc.with_traceback(tb))

src/trio/_core/_run_context.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from __future__ import annotations
2+
3+
import threading
4+
from typing import TYPE_CHECKING, Final
5+
6+
if TYPE_CHECKING:
7+
from ._run import Runner, Task
8+
9+
10+
class RunContext(threading.local):
11+
runner: Runner
12+
task: Task
13+
14+
15+
GLOBAL_RUN_CONTEXT: Final = RunContext()

0 commit comments

Comments
 (0)