
Commit a34799c

refactor streaming
1 parent 9e03a64 commit a34799c

File tree

2 files changed: +58, -80 lines


sentry_sdk/integrations/google_genai/streaming.py

Lines changed: 50 additions & 72 deletions
@@ -2,6 +2,8 @@
     TYPE_CHECKING,
     Any,
     List,
+    TypedDict,
+    Optional,
 )
 
 from sentry_sdk.ai.utils import set_data_normalized
@@ -11,27 +13,34 @@
     safe_serialize,
 )
 from .utils import (
-    get_model_name,
-    wrapped_config_with_tools,
     extract_tool_calls,
     extract_finish_reasons,
     extract_contents_text,
     extract_usage_data,
+    UsageData,
 )
 
 if TYPE_CHECKING:
     from sentry_sdk.tracing import Span
     from google.genai.types import GenerateContentResponse
 
 
+class AccumulatedResponse(TypedDict):
+    id: Optional[str]
+    model: Optional[str]
+    text: str
+    finish_reasons: List[str]
+    tool_calls: List[str]
+    usage_metadata: UsageData
+
+
 def accumulate_streaming_response(chunks):
     # type: (List[GenerateContentResponse]) -> dict[str, Any]
     """Accumulate streaming chunks into a single response-like object."""
     accumulated_text = []
     finish_reasons = []
     tool_calls = []
-    total_prompt_tokens = 0
-    total_tool_use_prompt_tokens = 0
+    total_input_tokens = 0
     total_output_tokens = 0
     total_tokens = 0
     total_cached_tokens = 0
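
Note on the new AccumulatedResponse type: `id` and `model` are only filled in later, when a chunk actually provides them (see the `if response_id:` assignment in the next hunk). That works because a TypedDict is an ordinary dict at runtime. The sketch below is not part of the commit; it only illustrates the pattern, using `NotRequired` from typing_extensions as one assumed way to mark the two keys as optional for static checkers:

    from typing import List
    from typing_extensions import NotRequired, TypedDict


    class AccumulatedResponseSketch(TypedDict):
        id: NotRequired[str]      # present only when a chunk carried a response id
        model: NotRequired[str]   # present only when a chunk carried a model version
        text: str
        finish_reasons: List[str]
        tool_calls: List[str]


    resp = AccumulatedResponseSketch(text="hi", finish_reasons=[], tool_calls=[])
    resp["id"] = "response-id-stream-123"  # keys can be added later; it is a plain dict at runtime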
@@ -59,63 +68,26 @@ def accumulate_streaming_response(chunks):
         tool_calls.extend(extracted_tool_calls)
 
         # Accumulate token usage
-        if getattr(chunk, "usage_metadata", None):
-            usage = chunk.usage_metadata
-            if getattr(usage, "prompt_token_count", None):
-                total_prompt_tokens = max(total_prompt_tokens, usage.prompt_token_count)
-            if getattr(usage, "tool_use_prompt_token_count", None):
-                total_tool_use_prompt_tokens = max(
-                    total_tool_use_prompt_tokens, usage.tool_use_prompt_token_count
-                )
-            if getattr(usage, "candidates_token_count", None):
-                total_output_tokens += usage.candidates_token_count
-            if getattr(usage, "cached_content_token_count", None):
-                total_cached_tokens = max(
-                    total_cached_tokens, usage.cached_content_token_count
-                )
-            if getattr(usage, "thoughts_token_count", None):
-                total_reasoning_tokens += usage.thoughts_token_count
-            if getattr(usage, "total_token_count", None):
-                # Only use the final total_token_count from the last chunk
-                total_tokens = usage.total_token_count
+        extracted_usage_data = extract_usage_data(chunk)
+        total_input_tokens += extracted_usage_data["input_tokens"]
+        total_output_tokens += extracted_usage_data["output_tokens"]
+        total_cached_tokens += extracted_usage_data["input_tokens_cached"]
+        total_reasoning_tokens += extracted_usage_data["output_tokens_reasoning"]
+        total_tokens += extracted_usage_data["total_tokens"]
 
     # Create a synthetic response object with accumulated data
-    accumulated_response = {
-        "text": "".join(accumulated_text),
-        "finish_reasons": finish_reasons,
-        "tool_calls": tool_calls,
-        "usage_metadata": {
-            "prompt_token_count": total_prompt_tokens,
-            "candidates_token_count": total_output_tokens,  # Keep original output tokens
-            "cached_content_token_count": total_cached_tokens,
-            "thoughts_token_count": total_reasoning_tokens,
-            "total_token_count": (
-                total_tokens
-                if total_tokens > 0
-                else (
-                    total_prompt_tokens
-                    + total_tool_use_prompt_tokens
-                    + total_output_tokens
-                    + total_reasoning_tokens
-                    + total_cached_tokens
-                )
-            ),
-        },
-    }
-
-    # Add optional token counts if present
-    if total_tool_use_prompt_tokens > 0:
-        accumulated_response["usage_metadata"][
-            "tool_use_prompt_token_count"
-        ] = total_tool_use_prompt_tokens
-    if total_cached_tokens > 0:
-        accumulated_response["usage_metadata"][
-            "cached_content_token_count"
-        ] = total_cached_tokens
-    if total_reasoning_tokens > 0:
-        accumulated_response["usage_metadata"][
-            "thoughts_token_count"
-        ] = total_reasoning_tokens
+    accumulated_response = AccumulatedResponse(
+        text="".join(accumulated_text),
+        finish_reasons=finish_reasons,
+        tool_calls=tool_calls,
+        usage_metadata=UsageData(
+            input_tokens=total_input_tokens,
+            output_tokens=total_output_tokens,
+            input_tokens_cached=total_cached_tokens,
+            output_tokens_reasoning=total_reasoning_tokens,
+            total_tokens=total_tokens,
+        ),
+    )
 
     if response_id:
         accumulated_response["id"] = response_id
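
The rewritten loop above assumes that `extract_usage_data` (imported from `.utils`, not shown in this diff) returns a `UsageData` mapping with exactly the five counters summed here. Below is a minimal sketch of that assumed shape and of the summing behavior; everything except the five key names is hypothetical:

    from typing import List, TypedDict


    class UsageDataSketch(TypedDict):
        input_tokens: int
        output_tokens: int             # assumed to already include reasoning tokens
        input_tokens_cached: int
        output_tokens_reasoning: int
        total_tokens: int


    def sum_usage(per_chunk_usage):
        # type: (List[UsageDataSketch]) -> UsageDataSketch
        """Mirrors the accumulation loop above: every counter is summed across chunks."""
        totals = UsageDataSketch(
            input_tokens=0,
            output_tokens=0,
            input_tokens_cached=0,
            output_tokens_reasoning=0,
            total_tokens=0,
        )
        for usage in per_chunk_usage:
            for key in totals:
                totals[key] += usage[key]
        return totals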
@@ -160,28 +132,34 @@ def set_span_data_for_streaming_response(span, integration, accumulated_response
     if accumulated_response.get("model"):
         span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, accumulated_response["model"])
 
-    # Set token usage
-    usage_data = extract_usage_data(accumulated_response)
-
-    if usage_data["input_tokens"]:
-        span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage_data["input_tokens"])
+    if accumulated_response["usage_metadata"]["input_tokens"]:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_INPUT_TOKENS,
+            accumulated_response["usage_metadata"]["input_tokens"],
+        )
 
-    if usage_data["input_tokens_cached"]:
+    if accumulated_response["usage_metadata"]["input_tokens_cached"]:
         span.set_data(
             SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
-            usage_data["input_tokens_cached"],
+            accumulated_response["usage_metadata"]["input_tokens_cached"],
         )
 
     # Output tokens already include reasoning tokens from extract_usage_data
-    if usage_data["output_tokens"]:
-        span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage_data["output_tokens"])
+    if accumulated_response["usage_metadata"]["output_tokens"]:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS,
+            accumulated_response["usage_metadata"]["output_tokens"],
+        )
 
-    if usage_data["output_tokens_reasoning"]:
+    if accumulated_response["usage_metadata"]["output_tokens_reasoning"]:
         span.set_data(
             SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
-            usage_data["output_tokens_reasoning"],
+            accumulated_response["usage_metadata"]["output_tokens_reasoning"],
         )
 
     # Set total token count if available
-    if usage_data["total_tokens"]:
-        span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, usage_data["total_tokens"])
+    if accumulated_response["usage_metadata"]["total_tokens"]:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS,
+            accumulated_response["usage_metadata"]["total_tokens"],
+        )
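
The overall flow of this module is unchanged by the refactor: collect the streamed chunks, fold them into one synthetic response, then write that response onto the span. A rough usage sketch follows; the wrapper name and the way the stream is obtained are assumptions, and only the two functions from this file are real:

    def stream_and_record(span, integration, stream):
        # Hypothetical wrapper sketch: `span`, `integration`, and `stream` come from
        # the surrounding integration code and the google-genai client, respectively.
        chunks = []
        for chunk in stream:   # GenerateContentResponse chunks
            chunks.append(chunk)
            yield chunk        # pass each chunk through to the caller unchanged

        accumulated = accumulate_streaming_response(chunks)
        set_span_data_for_streaming_response(span, integration, accumulated)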

tests/integrations/google_genai/test_google_genai.py

Lines changed: 8 additions & 8 deletions
@@ -410,7 +410,7 @@ def test_streaming_generate_content(sentry_init, capture_events, mock_genai_clie
         "usageMetadata": {
             "promptTokenCount": 10,
             "candidatesTokenCount": 2,
-            "totalTokenCount": 0,  # Not set in intermediate chunks
+            "totalTokenCount": 12,  # Not set in intermediate chunks
         },
         "responseId": "response-id-stream-123",
         "modelVersion": "gemini-1.5-flash",
@@ -429,7 +429,7 @@ def test_streaming_generate_content(sentry_init, capture_events, mock_genai_clie
         "usageMetadata": {
             "promptTokenCount": 10,
             "candidatesTokenCount": 3,
-            "totalTokenCount": 0,
+            "totalTokenCount": 13,
         },
     }
 
@@ -446,8 +446,8 @@ def test_streaming_generate_content(sentry_init, capture_events, mock_genai_clie
         ],
         "usageMetadata": {
             "promptTokenCount": 10,
-            "candidatesTokenCount": 7,  # Total output tokens across all chunks
-            "totalTokenCount": 22,  # Final total from last chunk
+            "candidatesTokenCount": 7,
+            "totalTokenCount": 25,
             "cachedContentTokenCount": 5,
             "thoughtsTokenCount": 3,
         },
@@ -505,17 +505,17 @@ def test_streaming_generate_content(sentry_init, capture_events, mock_genai_clie
 
     # Verify token counts - should reflect accumulated values
     # Input tokens: max of all chunks = 10
-    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
-    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 10
+    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 30
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 30
 
     # Output tokens: candidates (2 + 3 + 7 = 12) + reasoning (3) = 15
     # Note: output_tokens includes both candidates and reasoning tokens
     assert chat_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 15
     assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 15
 
     # Total tokens: from the last chunk
-    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 22
-    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 22
+    assert chat_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 50
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 50
 
     # Cached tokens: max of all chunks = 5
     assert chat_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED] == 5
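
The updated expectations follow from the new behavior of summing every chunk's usage instead of taking a max or only the last chunk's value. With the three mock chunks defined earlier in the test, the arithmetic works out as in this sketch (values copied from the fixtures):

    # Worked example of the new per-chunk summation.
    prompt_tokens = [10, 10, 10]        # promptTokenCount in each chunk
    candidate_tokens = [2, 3, 7]        # candidatesTokenCount in each chunk
    reasoning_tokens = [0, 0, 3]        # thoughtsTokenCount (last chunk only)
    chunk_totals = [12, 13, 25]         # totalTokenCount in each chunk

    assert sum(prompt_tokens) == 30                             # GEN_AI_USAGE_INPUT_TOKENS
    assert sum(candidate_tokens) + sum(reasoning_tokens) == 15  # GEN_AI_USAGE_OUTPUT_TOKENS
    assert sum(chunk_totals) == 50                              # GEN_AI_USAGE_TOTAL_TOKENS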
