|
11 | 11 | import types |
12 | 12 | import uuid |
13 | 13 | from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence |
14 | | -from contextlib import AbstractContextManager, ExitStack, asynccontextmanager |
| 14 | +from contextlib import AbstractContextManager, ExitStack, asynccontextmanager, contextmanager |
15 | 15 | from dataclasses import dataclass, field |
16 | 16 | from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload |
17 | 17 |
|
@@ -518,7 +518,7 @@ def task_group(self) -> TaskGroup: |
518 | 518 | async def iter_graph( # noqa C901 |
519 | 519 | self, first_task: GraphTask |
520 | 520 | ) -> AsyncGenerator[EndMarker[OutputT] | Sequence[GraphTask], EndMarker[OutputT] | Sequence[GraphTask]]: |
521 | | - try: |
| 521 | + with _unwrap_exception_groups(): |
522 | 522 | async with self.iter_stream_sender, create_task_group() as self._task_group: |
523 | 523 | try: |
524 | 524 | # Fire off the first task |
@@ -649,14 +649,9 @@ async def iter_graph( # noqa C901 |
649 | 649 | except GeneratorExit: |
650 | 650 | return |
651 | 651 |
|
652 | | - except ExceptionGroup as e: # pyright: ignore[reportUnknownVariableType] |
653 | | - # TODO: Handle this better in some way? |
654 | | - raise e.exceptions[0] # pyright: ignore[reportUnknownMemberType] |
655 | | - |
656 | | - if not self.task_group.cancel_scope.cancel_called: |
657 | | - raise RuntimeError( # pragma: no cover |
658 | | - 'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.' |
659 | | - ) |
| 652 | + raise RuntimeError( # pragma: no cover |
| 653 | + 'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.' |
| 654 | + ) |
660 | 655 |
|
661 | 656 | async def _finish_task(self, task_id: TaskID, keep_cancel_scope: bool = False) -> None: |
662 | 657 | if not keep_cancel_scope: |
@@ -889,3 +884,17 @@ def _is_any_iterable(x: Any) -> TypeGuard[Iterable[Any]]: |
889 | 884 |
|
890 | 885 | def _is_any_async_iterable(x: Any) -> TypeGuard[AsyncIterable[Any]]: |
891 | 886 | return isinstance(x, AsyncIterable) |
| 887 | + |
| 888 | + |
| 889 | +@contextmanager |
| 890 | +def _unwrap_exception_groups(): |
| 891 | + # I need to use a helper function for this because I can't figure out a way to get pyright |
| 892 | + # to type-check the ExceptionGroup catching in both 3.13 and 3.10 without emitting type errors in one; |
| 893 | + # if I try to ignore them in one, I get unnecessary-type-ignore errors in the other |
| 894 | + if TYPE_CHECKING: |
| 895 | + yield |
| 896 | + else: |
| 897 | + try: |
| 898 | + yield |
| 899 | + except ExceptionGroup as e: |
| 900 | + raise e.exceptions[0] |
0 commit comments