
Commit 1140d93

add correct response output
1 parent f013b48 commit 1140d93

File tree

1 file changed

+110 -23 lines changed


sentry_sdk/integrations/langgraph.py

Lines changed: 110 additions & 23 deletions
@@ -51,6 +51,21 @@ def _get_graph_name(graph_obj):
     return None
 
 
+def _normalize_langgraph_message(message):
+    # type: (Any) -> Any
+    if not hasattr(message, "content"):
+        return None
+    parsed = {"role": getattr(message, "type", None), "content": message.content}
+
+    for attr in ["name", "tool_calls", "function_call", "tool_call_id"]:
+        if hasattr(message, attr):
+            value = getattr(message, attr)
+            if value is not None:
+                parsed[attr] = value
+
+    return parsed
+
+
 def _parse_langgraph_messages(state):
     # type: (Any) -> Optional[List[Any]]
     if not state:
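
For illustration only (not part of this commit): a dependency-free sketch of the dict shape the new _normalize_langgraph_message helper builds from a LangChain-style message. FakeAIMessage and normalize are hypothetical stand-ins used here so the example runs without langchain_core or sentry_sdk; the real helper is the one added in the hunk above.

from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FakeAIMessage:
    # Stand-in for a langchain_core AIMessage: exposes .content, .type and .tool_calls
    content: str
    type: str = "ai"
    tool_calls: List[Dict[str, Any]] = field(default_factory=list)


def normalize(message):
    # Mirrors the role/content/extra-attribute shape built by _normalize_langgraph_message
    if not hasattr(message, "content"):
        return None
    parsed = {"role": getattr(message, "type", None), "content": message.content}
    for attr in ["name", "tool_calls", "function_call", "tool_call_id"]:
        value = getattr(message, attr, None)
        if value is not None:
            parsed[attr] = value
    return parsed


msg = FakeAIMessage(
    content="",
    tool_calls=[{"name": "get_weather", "args": {"city": "Paris"}, "id": "call_1"}],
)
print(normalize(msg))
# -> {'role': 'ai', 'content': '', 'tool_calls': [{'name': 'get_weather', ...}]}
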
@@ -74,15 +89,9 @@ def _parse_langgraph_messages(state):
     normalized_messages = []
     for message in messages:
         try:
-            if hasattr(message, "content"):
-                parsed = {"content": message.content}
-                for attr in ["name", "tool_calls", "function_call"]:
-                    if hasattr(message, attr):
-                        value = getattr(message, attr)
-                        if value is not None:
-                            parsed[attr] = value
-                normalized_messages.append(parsed)
-
+            normalized = _normalize_langgraph_message(message)
+            if normalized:
+                normalized_messages.append(normalized)
         except Exception:
             continue
 
@@ -150,21 +159,23 @@ def new_invoke(self, *args, **kwargs):
             span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
             span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, False)
 
+            # Store input messages to later compare with output
+            input_messages = None
             if (
                 len(args) > 0
                 and should_send_default_pii()
                 and integration.include_prompts
             ):
-                parsed_messages = _parse_langgraph_messages(args[0])
-                if parsed_messages:
-                    span.set_data(
+                input_messages = _parse_langgraph_messages(args[0])
+                if input_messages:
+                    set_data_normalized(
+                        span,
                         SPANDATA.GEN_AI_REQUEST_MESSAGES,
-                        safe_serialize(parsed_messages),
+                        safe_serialize(input_messages),
                     )
 
             result = f(self, *args, **kwargs)
-            if should_send_default_pii() and integration.include_prompts:
-                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, result)
+            _set_response_attributes(span, input_messages, result, integration)
             return result
 
     return new_invoke
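
As the guard in the hunk above shows, request and response data are only attached to the span when both should_send_default_pii() and integration.include_prompts hold. A minimal setup sketch (illustrative; the DSN is a placeholder, and passing include_prompts to the constructor is an assumption based on the flag checked above and on Sentry's other AI integrations):

import sentry_sdk
from sentry_sdk.integrations.langgraph import LanggraphIntegration

sentry_sdk.init(
    dsn="https://examplePublicKey@o0.ingest.sentry.io/0",  # placeholder DSN
    traces_sample_rate=1.0,       # record the invoke_agent spans
    send_default_pii=True,        # satisfies should_send_default_pii()
    integrations=[LanggraphIntegration(include_prompts=True)],
)
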
@@ -194,21 +205,97 @@ async def new_ainvoke(self, *args, **kwargs):
             span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
             span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, False)
 
+            input_messages = None
             if (
                 len(args) > 0
                 and should_send_default_pii()
                 and integration.include_prompts
             ):
-                parsed_messages = _parse_langgraph_messages(args[0])
-                if parsed_messages:
-                    span.set_data(
+                input_messages = _parse_langgraph_messages(args[0])
+                if input_messages:
+                    set_data_normalized(
+                        span,
                         SPANDATA.GEN_AI_REQUEST_MESSAGES,
-                        safe_serialize(parsed_messages),
+                        safe_serialize(input_messages),
                     )
-            result = await f(self, *args, **kwargs)
-            if should_send_default_pii() and integration.include_prompts:
-                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, result)
-            return result
+
+            result = await f(self, *args, **kwargs)
+            _set_response_attributes(span, input_messages, result, integration)
+            return result
 
     new_ainvoke.__wrapped__ = True
     return new_ainvoke
+
+
+def _get_new_messages(input_messages, output_messages):
+    # type: (Optional[List[Any]], Optional[List[Any]]) -> Optional[List[Any]]
+    """Extract only the new messages added during this invocation."""
+    if not output_messages:
+        return None
+
+    if not input_messages:
+        return output_messages
+
+    # only return the new messages, aka the output messages that are not in the input messages
+    input_count = len(input_messages)
+    new_messages = (
+        output_messages[input_count:] if len(output_messages) > input_count else []
+    )
+
+    return new_messages if new_messages else None
+
+
+def _extract_llm_response_text(messages):
+    # type: (Optional[List[Any]]) -> Optional[str]
+    if not messages:
+        return None
+
+    for message in reversed(messages):
+        if isinstance(message, dict):
+            role = message.get("role")
+            if role in ["assistant", "ai"]:
+                content = message.get("content")
+                if content and isinstance(content, str):
+                    return content
+
+    return None
+
+
+def _extract_tool_calls(messages):
+    # type: (Optional[List[Any]]) -> Optional[List[Any]]
+    if not messages:
+        return None
+
+    tool_calls = []
+    for message in messages:
+        if isinstance(message, dict):
+            msg_tool_calls = message.get("tool_calls")
+            if msg_tool_calls and isinstance(msg_tool_calls, list):
+                tool_calls.extend(msg_tool_calls)
+
+    return tool_calls if tool_calls else None
+
+
+def _set_response_attributes(span, input_messages, result, integration):
+    # type: (Any, Optional[List[Any]], Any, LanggraphIntegration) -> None
+    if not (should_send_default_pii() and integration.include_prompts):
+        return
+
+    parsed_response_messages = _parse_langgraph_messages(result)
+    new_messages = _get_new_messages(input_messages, parsed_response_messages)
+
+    llm_response_text = _extract_llm_response_text(new_messages)
+    if llm_response_text:
+        set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, llm_response_text)
+    elif new_messages:
+        set_data_normalized(
+            span, SPANDATA.GEN_AI_RESPONSE_TEXT, safe_serialize(new_messages)
+        )
+    else:
+        set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, result)
+
+    tool_calls = _extract_tool_calls(new_messages)
+    if tool_calls:
+        set_data_normalized(
+            span, SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS, safe_serialize(tool_calls)
+        )
