Skip to content

Commit b324601

Browse files
michgurDouweM
andauthored
Fix pydantic_graph.beta.GraphRun GeneratorExit handling (#3525)
Co-authored-by: Douwe Maan <[email protected]>
1 parent f14411f commit b324601

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

pydantic_graph/pydantic_graph/beta/graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@
4343
from pydantic_graph.nodes import BaseNode, End
4444

4545
if sys.version_info < (3, 11):
46-
from exceptiongroup import ExceptionGroup as ExceptionGroup # pragma: lax no cover
46+
from exceptiongroup import BaseExceptionGroup as BaseExceptionGroup # pragma: lax no cover
4747
else:
48-
ExceptionGroup = ExceptionGroup # pragma: lax no cover
48+
BaseExceptionGroup = BaseExceptionGroup # pragma: lax no cover
4949

5050
if TYPE_CHECKING:
5151
from pydantic_graph.beta.mermaid import StateDiagramDirection
@@ -970,7 +970,7 @@ def _unwrap_exception_groups():
970970
else:
971971
try:
972972
yield
973-
except ExceptionGroup as e:
973+
except BaseExceptionGroup as e:
974974
exception = e.exceptions[0]
975975
if exception.__cause__ is None:
976976
# bizarrely, this prevents recursion errors when formatting the exception for logfire

tests/graph/beta/test_graph_execution.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,20 @@ async def add_three(ctx: StepContext[None, None, int]) -> list[int]:
380380
graph = g.build()
381381
result = await graph.run()
382382
assert sorted(result) == [11, 12, 13]
383+
384+
385+
async def test_early_termination_from_nested_generator():
386+
"""Test that a generator wrapping an iteration can be terminated early."""
387+
g = GraphBuilder()
388+
g.add_edge(g.start_node, g.end_node)
389+
graph = g.build()
390+
391+
async def stream_graph():
392+
async with graph.iter() as run:
393+
async for node in run: # pragma: no branch
394+
yield node
395+
396+
gen = stream_graph()
397+
async for _ in gen: # pragma: no branch
398+
break
399+
await gen.aclose()

0 commit comments

Comments
 (0)