@@ -13,6 +13,7 @@
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast

 import anyio
+from anyio.streams.memory import MemoryObjectSendStream
 from opentelemetry.trace import Tracer
 from typing_extensions import TypeVar, assert_never

@@ -599,7 +600,7 @@ async def _run(): # noqa: C901
                 if text:
                     try:
                         self._next_node = await self._handle_text_response(
-                            ctx, text, text_processor
+                            ctx, text, text_processor, send_stream
                         )
                         return
                     except ToolRetryError:
@@ -668,8 +669,9 @@ async def _run(): # noqa: C901

             if text_processor := output_schema.text_processor:
                 if text:
-                    # TODO (DouweM): This could call an output function that yields custom events, but we're not in an event stream here?
-                    self._next_node = await self._handle_text_response(
+                        ctx, text, text_processor, send_stream
+                    )
                     return
                 alternatives.insert(0, 'return text')

@@ -737,8 +739,10 @@ async def _handle_text_response(
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
         text: str,
         text_processor: _output.BaseOutputProcessor[NodeRunEndT],
+        event_stream: MemoryObjectSendStream[_messages.CustomEvent],
     ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
         run_context = build_run_context(ctx)
+        run_context = replace(run_context, event_stream=event_stream)

         result_data = await text_processor.process(text, run_context)

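For context on what this change enables: the removed TODO noted that `_handle_text_response` may invoke an output function that wants to yield custom events, but it previously had no event stream to write to. The diff threads the send half of an anyio memory object stream through to the run context so such code can emit events. Below is a minimal runnable sketch of that pattern; `CustomEvent`, the `RunContext` stand-in, and `output_function` are illustrative assumptions rather than pydantic-ai's actual types, and only the anyio calls are real API.

```python
from __future__ import annotations

from dataclasses import dataclass, replace

import anyio
from anyio.streams.memory import MemoryObjectSendStream


@dataclass
class CustomEvent:  # stand-in for _messages.CustomEvent
    name: str


@dataclass
class RunContext:  # stand-in for the real run context
    event_stream: MemoryObjectSendStream[CustomEvent] | None = None


async def output_function(ctx: RunContext, text: str) -> str:
    # With the stream attached to the context, an output function can
    # surface custom events even though it runs outside the model's
    # streaming loop -- the situation the removed TODO pointed at.
    if ctx.event_stream is not None:
        await ctx.event_stream.send(CustomEvent(name='processing_text'))
    return text.upper()


async def main() -> None:
    send_stream, receive_stream = anyio.create_memory_object_stream(10)
    # Mirrors `replace(run_context, event_stream=event_stream)` in the diff.
    ctx = replace(RunContext(), event_stream=send_stream)
    async with send_stream, receive_stream:
        result = await output_function(ctx, 'hello')
        print(result, receive_stream.receive_nowait())


anyio.run(main)
```

In the actual change the stream is typed `MemoryObjectSendStream[_messages.CustomEvent]` and the streaming machinery's existing `send_stream` is what gets passed through, so events emitted during text-response handling land in the same event stream as the rest of the run.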