Skip to content

Commit 9a68a32

Browse files
committed
nesting streaming agent correctly
1 parent a39307a commit 9a68a32

File tree

1 file changed

+46
-12
lines changed

1 file changed

+46
-12
lines changed

sentry_sdk/integrations/langchain.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,21 @@
1414
from typing import TYPE_CHECKING
1515

1616
if TYPE_CHECKING:
17-
from typing import Any, List, Callable, Dict, Union, Optional
17+
from typing import (
18+
Any,
19+
List,
20+
Callable,
21+
Dict,
22+
Union,
23+
Optional,
24+
AsyncIterator,
25+
Iterator,
26+
)
1827
from uuid import UUID
1928

29+
2030
try:
21-
from langchain_core.messages import BaseMessage
31+
from langchain_core.messages import BaseMessage, MessageStreamEvent
2232
from langchain_core.outputs import LLMResult
2333
from langchain_core.callbacks import (
2434
manager,
@@ -735,20 +745,44 @@ def new_stream(self, *args, **kwargs):
735745
return f(self, *args, **kwargs)
736746

737747
# Create a span that will act as the parent for all callback-generated spans
738-
with sentry_sdk.start_span(
748+
span = sentry_sdk.start_span(
739749
op=OP.GEN_AI_INVOKE_AGENT,
740750
name="AgentExecutor.stream",
741751
origin=LangchainIntegration.origin,
742-
) as span:
743-
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
752+
)
753+
span.__enter__()
744754

745-
if hasattr(self, "agent") and hasattr(self.agent, "llm"):
746-
model_name = getattr(self.agent.llm, "model_name", None) or getattr(
747-
self.agent.llm, "model", None
748-
)
749-
if model_name:
750-
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
755+
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
751756

752-
return f(self, *args, **kwargs)
757+
if hasattr(self, "agent") and hasattr(self.agent, "llm"):
758+
model_name = getattr(self.agent.llm, "model_name", None) or getattr(
759+
self.agent.llm, "model", None
760+
)
761+
if model_name:
762+
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
763+
764+
result = f(self, *args, **kwargs)
765+
old_iterator = result
766+
767+
def new_iterator():
768+
# type: () -> Iterator[MessageStreamEvent]
769+
for event in old_iterator:
770+
# import ipdb; ipdb.set_trace()
771+
yield event
772+
span.__exit__(None, None, None)
773+
774+
async def new_iterator_async():
775+
# type: () -> AsyncIterator[MessageStreamEvent]
776+
async for event in old_iterator:
777+
yield event
778+
779+
span.__exit__(None, None, None)
780+
781+
if str(type(result)) == "<class 'async_generator'>":
782+
result = new_iterator_async()
783+
else:
784+
result = new_iterator()
785+
786+
return result
753787

754788
return new_stream

0 commit comments

Comments
 (0)