|
29 | 29 | from ...testing import wait_all_tasks_blocked
|
30 | 30 |
|
31 | 31 | if TYPE_CHECKING:
|
32 |
| - from collections.abc import AsyncIterator, Callable, Iterator |
| 32 | + from collections.abc import ( |
| 33 | + AsyncGenerator, |
| 34 | + AsyncIterator, |
| 35 | + Callable, |
| 36 | + Generator, |
| 37 | + Iterator, |
| 38 | + ) |
33 | 39 |
|
34 | 40 | from ..._core import Abort, RaiseCancelT
|
35 | 41 |
|
@@ -631,3 +637,64 @@ def __eq__(self, other: object) -> bool:
|
631 | 637 | del a
|
632 | 638 | gc_collect_harder()
|
633 | 639 | assert data_copy
|
| 640 | + |
| 641 | + |
| 642 | +@_core.enable_ki_protection |
| 643 | +async def _protected_async_gen_fn() -> AsyncGenerator[None, None]: |
| 644 | + return |
| 645 | + yield |
| 646 | + |
| 647 | + |
| 648 | +@_core.enable_ki_protection |
| 649 | +async def _protected_async_fn() -> None: |
| 650 | + pass |
| 651 | + |
| 652 | + |
| 653 | +@_core.enable_ki_protection |
| 654 | +def _protected_gen_fn() -> Generator[None, None, None]: |
| 655 | + return |
| 656 | + yield |
| 657 | + |
| 658 | + |
| 659 | +@_core.disable_ki_protection |
| 660 | +async def _unprotected_async_gen_fn() -> AsyncGenerator[None, None]: |
| 661 | + return |
| 662 | + yield |
| 663 | + |
| 664 | + |
| 665 | +@_core.disable_ki_protection |
| 666 | +async def _unprotected_async_fn() -> None: |
| 667 | + pass |
| 668 | + |
| 669 | + |
| 670 | +@_core.disable_ki_protection |
| 671 | +def _unprotected_gen_fn() -> Generator[None, None, None]: |
| 672 | + return |
| 673 | + yield |
| 674 | + |
| 675 | + |
| 676 | +def _consume_function_for_coverage(fn: Callable[..., object]) -> None: |
| 677 | + result = fn() |
| 678 | + if inspect.isasyncgen(result): |
| 679 | + with pytest.raises(StopAsyncIteration): |
| 680 | + result.asend(None).send(None) |
| 681 | + return |
| 682 | + |
| 683 | + assert inspect.isgenerator(result) or inspect.iscoroutine(result) |
| 684 | + with pytest.raises(StopIteration): |
| 685 | + result.send(None) |
| 686 | + |
| 687 | + |
| 688 | +def test_enable_disable_ki_protection_passes_on_inspect_flags() -> None: |
| 689 | + assert inspect.isasyncgenfunction(_protected_async_gen_fn) |
| 690 | + _consume_function_for_coverage(_protected_async_gen_fn) |
| 691 | + assert inspect.iscoroutinefunction(_protected_async_fn) |
| 692 | + _consume_function_for_coverage(_protected_async_fn) |
| 693 | + assert inspect.isgeneratorfunction(_protected_gen_fn) |
| 694 | + _consume_function_for_coverage(_protected_gen_fn) |
| 695 | + assert inspect.isasyncgenfunction(_unprotected_async_gen_fn) |
| 696 | + _consume_function_for_coverage(_unprotected_async_gen_fn) |
| 697 | + assert inspect.iscoroutinefunction(_unprotected_async_fn) |
| 698 | + _consume_function_for_coverage(_unprotected_async_fn) |
| 699 | + assert inspect.isgeneratorfunction(_unprotected_gen_fn) |
| 700 | + _consume_function_for_coverage(_unprotected_gen_fn) |
0 commit comments