Skip to content

Commit 0c80de3

Browse files
committed
add tests for passing on inspect flags
1 parent 97a9eb2 commit 0c80de3

File tree

1 file changed

+68
-1
lines changed

1 file changed

+68
-1
lines changed

src/trio/_core/_tests/test_ki.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@
2929
from ...testing import wait_all_tasks_blocked
3030

3131
if TYPE_CHECKING:
32-
from collections.abc import AsyncIterator, Callable, Iterator
32+
from collections.abc import (
33+
AsyncGenerator,
34+
AsyncIterator,
35+
Callable,
36+
Generator,
37+
Iterator,
38+
)
3339

3440
from ..._core import Abort, RaiseCancelT
3541

@@ -631,3 +637,64 @@ def __eq__(self, other: object) -> bool:
631637
del a
632638
gc_collect_harder()
633639
assert data_copy
640+
641+
642+
@_core.enable_ki_protection
643+
async def _protected_async_gen_fn() -> AsyncGenerator[None, None]:
644+
return
645+
yield
646+
647+
648+
@_core.enable_ki_protection
649+
async def _protected_async_fn() -> None:
650+
pass
651+
652+
653+
@_core.enable_ki_protection
654+
def _protected_gen_fn() -> Generator[None, None, None]:
655+
return
656+
yield
657+
658+
659+
@_core.disable_ki_protection
660+
async def _unprotected_async_gen_fn() -> AsyncGenerator[None, None]:
661+
return
662+
yield
663+
664+
665+
@_core.disable_ki_protection
666+
async def _unprotected_async_fn() -> None:
667+
pass
668+
669+
670+
@_core.disable_ki_protection
671+
def _unprotected_gen_fn() -> Generator[None, None, None]:
672+
return
673+
yield
674+
675+
676+
def _consume_function_for_coverage(fn: Callable[..., object]) -> None:
677+
result = fn()
678+
if inspect.isasyncgen(result):
679+
with pytest.raises(StopAsyncIteration):
680+
result.asend(None).send(None)
681+
return
682+
683+
assert inspect.isgenerator(result) or inspect.iscoroutine(result)
684+
with pytest.raises(StopIteration):
685+
result.send(None)
686+
687+
688+
def test_enable_disable_ki_protection_passes_on_inspect_flags() -> None:
689+
assert inspect.isasyncgenfunction(_protected_async_gen_fn)
690+
_consume_function_for_coverage(_protected_async_gen_fn)
691+
assert inspect.iscoroutinefunction(_protected_async_fn)
692+
_consume_function_for_coverage(_protected_async_fn)
693+
assert inspect.isgeneratorfunction(_protected_gen_fn)
694+
_consume_function_for_coverage(_protected_gen_fn)
695+
assert inspect.isasyncgenfunction(_unprotected_async_gen_fn)
696+
_consume_function_for_coverage(_unprotected_async_gen_fn)
697+
assert inspect.iscoroutinefunction(_unprotected_async_fn)
698+
_consume_function_for_coverage(_unprotected_async_fn)
699+
assert inspect.isgeneratorfunction(_unprotected_gen_fn)
700+
_consume_function_for_coverage(_unprotected_gen_fn)

0 commit comments

Comments
 (0)