Skip to content

Commit c0cf801

Browse files
committed
Add in_trio_run and in_trio_task
1 parent d7cb2fc commit c0cf801

File tree

7 files changed

+140
-1
lines changed

7 files changed

+140
-1
lines changed

docs/source/reference-lowlevel.rst

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,33 @@ Global statistics
5656
.. autoclass:: RunStatistics()
5757

5858

59+
The current Trio context
60+
------------------------
61+
62+
There are two different types of contexts in :mod:`trio`. Here are the
63+
semantics presented as a handy table. Choose the right function for
64+
your needs.
65+
66+
+---------------------------------+-----------------------------------+------------------------------------+
67+
| situation | :func:`trio.lowlevel.in_trio_run` | :func:`trio.lowlevel.in_trio_task` |
68+
+=================================+===================================+====================================+
69+
| inside a running async function | `True` | `True` |
70+
+---------------------------------+-----------------------------------+------------------------------------+
71+
| without a running Trio loop | `False` | `False` |
72+
+---------------------------------+-----------------------------------+------------------------------------+
73+
| in a guest run's host loop | `True` | `False` |
74+
+---------------------------------+-----------------------------------+------------------------------------+
75+
| inside an instrument call | depends | depends |
76+
+---------------------------------+-----------------------------------+------------------------------------+
77+
| :func:`trio.to_thread.run_sync` | `False` | `False` |
78+
+---------------------------------+-----------------------------------+------------------------------------+
79+
| inside an abort function | `True` | `True` |
80+
+---------------------------------+-----------------------------------+------------------------------------+
81+
82+
.. function:: in_trio_run
83+
84+
.. function:: in_trio_task
85+
5986
The current clock
6087
-----------------
6188

src/trio/_core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
current_task,
4646
current_time,
4747
current_trio_token,
48+
in_trio_run,
49+
in_trio_task,
4850
notify_closing,
4951
open_nursery,
5052
remove_instrument,

src/trio/_core/_run.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2283,7 +2283,7 @@ def setup_runner(
22832283
# It wouldn't be *hard* to support nested calls to run(), but I can't
22842284
# think of a single good reason for it, so let's be conservative for
22852285
# now:
2286-
if hasattr(GLOBAL_RUN_CONTEXT, "runner"):
2286+
if in_trio_run():
22872287
raise RuntimeError("Attempted to call run() from inside a run()")
22882288

22892289
if clock is None:
@@ -2952,6 +2952,14 @@ async def checkpoint_if_cancelled() -> None:
29522952
task._cancel_points += 1
29532953

29542954

2955+
def in_trio_run() -> bool:
2956+
return hasattr(GLOBAL_RUN_CONTEXT, "runner")
2957+
2958+
2959+
def in_trio_task() -> bool:
2960+
return hasattr(GLOBAL_RUN_CONTEXT, "task")
2961+
2962+
29552963
if sys.platform == "win32":
29562964
from ._generated_io_windows import *
29572965
from ._io_windows import (

src/trio/_core/_tests/test_guest_mode.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,26 @@ async def synchronize() -> None:
264264
sniffio_library.name = None
265265

266266

267+
def test_guest_mode_trio_context_detection() -> None:
268+
def check(thing: bool) -> None:
269+
assert thing
270+
271+
assert not trio.lowlevel.in_trio_run()
272+
assert not trio.lowlevel.in_trio_task()
273+
274+
async def trio_main(in_host: InHost) -> None:
275+
for _ in range(2):
276+
assert trio.lowlevel.in_trio_run()
277+
assert trio.lowlevel.in_trio_task()
278+
279+
in_host(lambda: check(trio.lowlevel.in_trio_run()))
280+
in_host(lambda: check(not trio.lowlevel.in_trio_task()))
281+
282+
trivial_guest_run(trio_main)
283+
assert not trio.lowlevel.in_trio_run()
284+
assert not trio.lowlevel.in_trio_task()
285+
286+
267287
def test_warn_set_wakeup_fd_overwrite() -> None:
268288
assert signal.set_wakeup_fd(-1) == -1
269289

src/trio/_core/_tests/test_instrumentation.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,52 @@ async def main() -> None:
266266
assert "task_exited" not in runner.instruments
267267

268268
_core.run(main)
269+
270+
271+
def test_instrument_call_trio_context() -> None:
272+
called = set()
273+
274+
class Instrument(_abc.Instrument):
275+
pass
276+
277+
hooks = {
278+
# category 1
279+
"after_io_wait": (True, False),
280+
"before_io_wait": (True, False),
281+
"before_run": (True, False),
282+
# category 2
283+
"after_run": (False, False),
284+
# category 3
285+
"before_task_step": (True, True),
286+
"after_task_step": (True, True),
287+
"task_exited": (True, True),
288+
# category 4
289+
"task_scheduled": (True, None),
290+
"task_spawned": (True, None),
291+
}
292+
for hook, val in hooks.items():
293+
294+
def h(
295+
self: Instrument,
296+
*args: object,
297+
hook: str = hook,
298+
val: tuple[bool | None, bool | None] = val,
299+
) -> None:
300+
fail_str = f"failed in {hook}"
301+
302+
if val[0] is not None:
303+
assert _core.in_trio_run() == val[0], fail_str
304+
if val[1] is not None:
305+
assert _core.in_trio_task() == val[1], fail_str
306+
called.add(hook)
307+
308+
setattr(Instrument, hook, h)
309+
310+
async def main() -> None:
311+
await _core.checkpoint()
312+
313+
async with _core.open_nursery() as nursery:
314+
nursery.start_soon(_core.checkpoint)
315+
316+
_core.run(main, instruments=[Instrument()])
317+
assert called == set(hooks)

src/trio/_core/_tests/test_run.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2855,3 +2855,34 @@ def run(self, fn: Callable[[], object]) -> object:
28552855

28562856
with mock.patch("trio._core._run.copy_context", return_value=Context()):
28572857
assert _count_context_run_tb_frames() == 1
2858+
2859+
2860+
@restore_unraisablehook()
2861+
def test_trio_context_detection() -> None:
2862+
assert not _core.in_trio_run()
2863+
assert not _core.in_trio_task()
2864+
2865+
def inner() -> None:
2866+
assert _core.in_trio_run()
2867+
assert _core.in_trio_task()
2868+
2869+
def sync_inner() -> None:
2870+
assert not _core.in_trio_run()
2871+
assert not _core.in_trio_task()
2872+
2873+
def inner_abort(_: object) -> _core.Abort:
2874+
assert _core.in_trio_run()
2875+
assert _core.in_trio_task()
2876+
return _core.Abort.SUCCEEDED
2877+
2878+
async def main() -> None:
2879+
assert _core.in_trio_run()
2880+
assert _core.in_trio_task()
2881+
2882+
inner()
2883+
2884+
await to_thread_run_sync(sync_inner)
2885+
with _core.CancelScope(deadline=_core.current_time() - 1):
2886+
await _core.wait_task_rescheduled(inner_abort)
2887+
2888+
_core.run(main)

src/trio/lowlevel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
currently_ki_protected as currently_ki_protected,
3838
disable_ki_protection as disable_ki_protection,
3939
enable_ki_protection as enable_ki_protection,
40+
in_trio_run as in_trio_run,
41+
in_trio_task as in_trio_task,
4042
notify_closing as notify_closing,
4143
permanently_detach_coroutine_object as permanently_detach_coroutine_object,
4244
reattach_detached_coroutine_object as reattach_detached_coroutine_object,

0 commit comments

Comments
 (0)