@@ -2,6 +2,8 @@
     TYPE_CHECKING,
     Any,
     List,
+    TypedDict,
+    Optional,
 )
 
 from sentry_sdk.ai.utils import set_data_normalized
@@ -11,27 +13,34 @@
     safe_serialize,
 )
 from .utils import (
-    get_model_name,
-    wrapped_config_with_tools,
     extract_tool_calls,
     extract_finish_reasons,
     extract_contents_text,
     extract_usage_data,
+    UsageData,
 )
 
 if TYPE_CHECKING:
     from sentry_sdk.tracing import Span
     from google.genai.types import GenerateContentResponse
 
 
+class AccumulatedResponse(TypedDict):
+    id: Optional[str]
+    model: Optional[str]
+    text: str
+    finish_reasons: List[str]
+    tool_calls: List[str]
+    usage_metadata: UsageData
+
+
 def accumulate_streaming_response(chunks):
     # type: (List[GenerateContentResponse]) -> dict[str, Any]
     """Accumulate streaming chunks into a single response-like object."""
     accumulated_text = []
     finish_reasons = []
     tool_calls = []
-    total_prompt_tokens = 0
-    total_tool_use_prompt_tokens = 0
+    total_input_tokens = 0
     total_output_tokens = 0
     total_tokens = 0
     total_cached_tokens = 0
@@ -59,63 +68,26 @@ def accumulate_streaming_response(chunks): |
             tool_calls.extend(extracted_tool_calls)
 
         # Accumulate token usage
-        if getattr(chunk, "usage_metadata", None):
-            usage = chunk.usage_metadata
-            if getattr(usage, "prompt_token_count", None):
-                total_prompt_tokens = max(total_prompt_tokens, usage.prompt_token_count)
-            if getattr(usage, "tool_use_prompt_token_count", None):
-                total_tool_use_prompt_tokens = max(
-                    total_tool_use_prompt_tokens, usage.tool_use_prompt_token_count
-                )
-            if getattr(usage, "candidates_token_count", None):
-                total_output_tokens += usage.candidates_token_count
-            if getattr(usage, "cached_content_token_count", None):
-                total_cached_tokens = max(
-                    total_cached_tokens, usage.cached_content_token_count
-                )
-            if getattr(usage, "thoughts_token_count", None):
-                total_reasoning_tokens += usage.thoughts_token_count
-            if getattr(usage, "total_token_count", None):
-                # Only use the final total_token_count from the last chunk
-                total_tokens = usage.total_token_count
+        extracted_usage_data = extract_usage_data(chunk)
+        total_input_tokens += extracted_usage_data["input_tokens"]
+        total_output_tokens += extracted_usage_data["output_tokens"]
+        total_cached_tokens += extracted_usage_data["input_tokens_cached"]
+        total_reasoning_tokens += extracted_usage_data["output_tokens_reasoning"]
+        total_tokens += extracted_usage_data["total_tokens"]
 
     # Create a synthetic response object with accumulated data
-    accumulated_response = {
-        "text": "".join(accumulated_text),
-        "finish_reasons": finish_reasons,
-        "tool_calls": tool_calls,
-        "usage_metadata": {
-            "prompt_token_count": total_prompt_tokens,
-            "candidates_token_count": total_output_tokens,  # Keep original output tokens
-            "cached_content_token_count": total_cached_tokens,
-            "thoughts_token_count": total_reasoning_tokens,
-            "total_token_count": (
-                total_tokens
-                if total_tokens > 0
-                else (
-                    total_prompt_tokens
-                    + total_tool_use_prompt_tokens
-                    + total_output_tokens
-                    + total_reasoning_tokens
-                    + total_cached_tokens
-                )
-            ),
-        },
-    }
-
-    # Add optional token counts if present
-    if total_tool_use_prompt_tokens > 0:
-        accumulated_response["usage_metadata"][
-            "tool_use_prompt_token_count"
-        ] = total_tool_use_prompt_tokens
-    if total_cached_tokens > 0:
-        accumulated_response["usage_metadata"][
-            "cached_content_token_count"
-        ] = total_cached_tokens
-    if total_reasoning_tokens > 0:
-        accumulated_response["usage_metadata"][
-            "thoughts_token_count"
-        ] = total_reasoning_tokens
+    accumulated_response = AccumulatedResponse(
+        text="".join(accumulated_text),
+        finish_reasons=finish_reasons,
+        tool_calls=tool_calls,
+        usage_metadata=UsageData(
+            input_tokens=total_input_tokens,
+            output_tokens=total_output_tokens,
+            input_tokens_cached=total_cached_tokens,
+            output_tokens_reasoning=total_reasoning_tokens,
+            total_tokens=total_tokens,
+        ),
+    )
 
     if response_id:
         accumulated_response["id"] = response_id
@@ -160,28 +132,34 @@ def set_span_data_for_streaming_response(span, integration, accumulated_response |
     if accumulated_response.get("model"):
         span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, accumulated_response["model"])
 
-    # Set token usage
-    usage_data = extract_usage_data(accumulated_response)
-
-    if usage_data["input_tokens"]:
-        span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, usage_data["input_tokens"])
+    if accumulated_response["usage_metadata"]["input_tokens"]:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_INPUT_TOKENS,
+            accumulated_response["usage_metadata"]["input_tokens"],
+        )
 
-    if usage_data["input_tokens_cached"]:
+    if accumulated_response["usage_metadata"]["input_tokens_cached"]:
         span.set_data(
             SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
-            usage_data["input_tokens_cached"],
+            accumulated_response["usage_metadata"]["input_tokens_cached"],
         )
 
     # Output tokens already include reasoning tokens from extract_usage_data
-    if usage_data["output_tokens"]:
-        span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, usage_data["output_tokens"])
+    if accumulated_response["usage_metadata"]["output_tokens"]:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS,
+            accumulated_response["usage_metadata"]["output_tokens"],
+        )
 
-    if usage_data["output_tokens_reasoning"]:
+    if accumulated_response["usage_metadata"]["output_tokens_reasoning"]:
         span.set_data(
             SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
-            usage_data["output_tokens_reasoning"],
+            accumulated_response["usage_metadata"]["output_tokens_reasoning"],
        )
 
     # Set total token count if available
-    if usage_data["total_tokens"]:
-        span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, usage_data["total_tokens"])
+    if accumulated_response["usage_metadata"]["total_tokens"]:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS,
+            accumulated_response["usage_metadata"]["total_tokens"],
+        )
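For readers outside this PR, here is a minimal, self-contained sketch of the data shapes the diff relies on. The field names are exactly the keys this code reads from `extract_usage_data()`; the real `UsageData` lives in the integration's `.utils` module and may differ, and the `accumulate_usage` helper below is purely hypothetical, included only to show the new summation semantics (every field is rolled up with `+=`, replacing the old mix of `max()` and last-chunk handling):

```python
# Standalone sketch, not SDK code. Assumes UsageData carries exactly the keys
# this diff reads; the real definitions live in the integration's .utils module.
from typing import List, Optional, TypedDict


class UsageData(TypedDict):
    input_tokens: int
    output_tokens: int
    input_tokens_cached: int
    output_tokens_reasoning: int
    total_tokens: int


class AccumulatedResponse(TypedDict, total=False):
    # Mirrors the TypedDict added in this diff; id/model are filled in later,
    # so total=False keeps them optional in this sketch.
    id: Optional[str]
    model: Optional[str]
    text: str
    finish_reasons: List[str]
    tool_calls: List[str]
    usage_metadata: UsageData


def accumulate_usage(per_chunk_usage):
    # type: (List[UsageData]) -> UsageData
    """Sum per-chunk usage the way the new loop does: plain += on every field."""
    totals = UsageData(
        input_tokens=0,
        output_tokens=0,
        input_tokens_cached=0,
        output_tokens_reasoning=0,
        total_tokens=0,
    )
    for usage in per_chunk_usage:
        totals["input_tokens"] += usage["input_tokens"]
        totals["output_tokens"] += usage["output_tokens"]
        totals["input_tokens_cached"] += usage["input_tokens_cached"]
        totals["output_tokens_reasoning"] += usage["output_tokens_reasoning"]
        totals["total_tokens"] += usage["total_tokens"]
    return totals


if __name__ == "__main__":
    # Usage extracted from two hypothetical streaming chunks.
    per_chunk = [
        UsageData(input_tokens=10, output_tokens=5, input_tokens_cached=0,
                  output_tokens_reasoning=0, total_tokens=15),
        UsageData(input_tokens=0, output_tokens=7, input_tokens_cached=0,
                  output_tokens_reasoning=2, total_tokens=9),
    ]
    response = AccumulatedResponse(
        text="Hello world",
        finish_reasons=["STOP"],
        tool_calls=[],
        usage_metadata=accumulate_usage(per_chunk),
    )
    print(response["usage_metadata"])
    # -> {'input_tokens': 10, 'output_tokens': 12, 'input_tokens_cached': 0,
    #     'output_tokens_reasoning': 2, 'total_tokens': 24}
```

The resulting `usage_metadata` is what `set_span_data_for_streaming_response` copies onto the span via the `GEN_AI_USAGE_*` attributes in the last hunk above.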