Skip to content

Commit b53f627

Browse files
committed
Fix type checking
1 parent f710a0b commit b53f627

File tree

1 file changed

+19
-10
lines changed
  • pydantic_graph/pydantic_graph/beta

1 file changed

+19
-10
lines changed

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import types
1212
import uuid
1313
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Sequence
14-
from contextlib import AbstractContextManager, ExitStack, asynccontextmanager
14+
from contextlib import AbstractContextManager, ExitStack, asynccontextmanager, contextmanager
1515
from dataclasses import dataclass, field
1616
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast, get_args, get_origin, overload
1717

@@ -518,7 +518,7 @@ def task_group(self) -> TaskGroup:
518518
async def iter_graph( # noqa C901
519519
self, first_task: GraphTask
520520
) -> AsyncGenerator[EndMarker[OutputT] | Sequence[GraphTask], EndMarker[OutputT] | Sequence[GraphTask]]:
521-
try:
521+
with _unwrap_exception_groups():
522522
async with self.iter_stream_sender, create_task_group() as self._task_group:
523523
try:
524524
# Fire off the first task
@@ -649,14 +649,9 @@ async def iter_graph( # noqa C901
649649
except GeneratorExit:
650650
return
651651

652-
except ExceptionGroup as e: # pyright: ignore[reportUnknownVariableType]
653-
# TODO: Handle this better in some way?
654-
raise e.exceptions[0] # pyright: ignore[reportUnknownMemberType]
655-
656-
if not self.task_group.cancel_scope.cancel_called:
657-
raise RuntimeError( # pragma: no cover
658-
'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
659-
)
652+
raise RuntimeError( # pragma: no cover
653+
'Graph run completed, but no result was produced. This is either a bug in the graph or a bug in the graph runner.'
654+
)
660655

661656
async def _finish_task(self, task_id: TaskID, keep_cancel_scope: bool = False) -> None:
662657
if not keep_cancel_scope:
@@ -889,3 +884,17 @@ def _is_any_iterable(x: Any) -> TypeGuard[Iterable[Any]]:
889884

890885
def _is_any_async_iterable(x: Any) -> TypeGuard[AsyncIterable[Any]]:
891886
return isinstance(x, AsyncIterable)
887+
888+
889+
@contextmanager
890+
def _unwrap_exception_groups():
891+
# I need to use a helper function for this because I can't figure out a way to get pyright
892+
# to type-check the ExceptionGroup catching in both 3.13 and 3.10 without emitting type errors in one;
893+
# if I try to ignore them in one, I get unnecessary-type-ignore errors in the other
894+
if TYPE_CHECKING:
895+
yield
896+
else:
897+
try:
898+
yield
899+
except ExceptionGroup as e:
900+
raise e.exceptions[0]

0 commit comments

Comments
 (0)