|
14 | 14 | from typing import TYPE_CHECKING |
15 | 15 |
|
16 | 16 | if TYPE_CHECKING: |
17 | | - from typing import Any, List, Callable, Dict, Union, Optional |
| 17 | + from typing import ( |
| 18 | + Any, |
| 19 | + List, |
| 20 | + Callable, |
| 21 | + Dict, |
| 22 | + Union, |
| 23 | + Optional, |
| 24 | + AsyncIterator, |
| 25 | + Iterator, |
| 26 | + ) |
18 | 27 | from uuid import UUID |
19 | 28 |
|
| 29 | + |
20 | 30 | try: |
21 | | - from langchain_core.messages import BaseMessage |
| 31 | +     from langchain_core.messages import BaseMessage  # NOTE(review): MessageStreamEvent is not exported by langchain_core.messages (it is an Anthropic SDK type); importing it here raises ImportError and silently disables the integration. It is only referenced in type comments, which are never evaluated.
22 | 32 | from langchain_core.outputs import LLMResult |
23 | 33 | from langchain_core.callbacks import ( |
24 | 34 | manager, |
@@ -735,20 +745,44 @@ def new_stream(self, *args, **kwargs): |
735 | 745 | return f(self, *args, **kwargs) |
736 | 746 |
|
737 | 747 | # Create a span that will act as the parent for all callback-generated spans |
738 | | - with sentry_sdk.start_span( |
| 748 | + span = sentry_sdk.start_span( |
739 | 749 | op=OP.GEN_AI_INVOKE_AGENT, |
740 | 750 | name="AgentExecutor.stream", |
741 | 751 | origin=LangchainIntegration.origin, |
742 | | - ) as span: |
743 | | - span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent") |
| 752 | + ) |
| 753 | + span.__enter__() |
744 | 754 |
|
745 | | - if hasattr(self, "agent") and hasattr(self.agent, "llm"): |
746 | | - model_name = getattr(self.agent.llm, "model_name", None) or getattr( |
747 | | - self.agent.llm, "model", None |
748 | | - ) |
749 | | - if model_name: |
750 | | - span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name) |
| 755 | + span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent") |
751 | 756 |
|
752 | | - return f(self, *args, **kwargs) |
| 757 | + if hasattr(self, "agent") and hasattr(self.agent, "llm"): |
| 758 | + model_name = getattr(self.agent.llm, "model_name", None) or getattr( |
| 759 | + self.agent.llm, "model", None |
| 760 | + ) |
| 761 | + if model_name: |
| 762 | + span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name) |
| 763 | + |
| 764 | + result = f(self, *args, **kwargs) |
| 765 | + old_iterator = result |
| 766 | + |
def new_iterator():
    # type: () -> Iterator[Any]
    """Re-yield every event from the wrapped stream, guaranteeing the
    manually-entered parent span is closed when the stream ends.

    The ``try``/``finally`` ensures ``span.__exit__`` runs not only on
    clean exhaustion but also when the consumer abandons the generator
    early (``GeneratorExit`` on close/GC) or the underlying iterator
    raises — otherwise the span entered via ``span.__enter__()`` in
    ``new_stream`` would leak.
    """
    try:
        for event in old_iterator:
            yield event
    finally:
        # Close the span exactly once, on every exit path.
        span.__exit__(None, None, None)
| 773 | + |
async def new_iterator_async():
    # type: () -> AsyncIterator[Any]
    """Async twin of ``new_iterator``: re-yield events from the wrapped
    async stream and close the parent span on every exit path.

    Without the ``try``/``finally``, ``span.__exit__`` would be skipped
    whenever the consumer stops iterating early (``GeneratorExit``) or
    the underlying async iterator raises, leaking the span entered via
    ``span.__enter__()`` in ``new_stream``.
    """
    try:
        async for event in old_iterator:
            yield event
    finally:
        # Close the span exactly once, on every exit path.
        span.__exit__(None, None, None)
| 780 | + |
| 781 | +         if hasattr(result, "__anext__"):  # async generators have __anext__; sync ones do not (avoids fragile str(type(...)) comparison)
| 782 | + result = new_iterator_async() |
| 783 | + else: |
| 784 | + result = new_iterator() |
| 785 | + |
| 786 | + return result |
753 | 787 |
|
754 | 788 | return new_stream |
0 commit comments