1 | 1 | import asyncio |
2 | 2 | import uuid |
3 | | -from typing import Optional, List, TYPE_CHECKING, Any, Dict |
| 3 | +from typing import Optional, List, TYPE_CHECKING, Any, Dict, AsyncIterator |
4 | 4 |
5 | 5 | from google.genai.client import AsyncClient, Client |
6 | 6 | from google.genai import types |
@@ -79,21 +79,6 @@ async def simple_response( |
79 | 79 | """ |
80 | 80 | return await self.send_message(message=text) |
81 | 81 |
82 | | - def _iterate_stream_blocking(self, iterator): |
83 | | - """Helper method to iterate over a blocking stream iterator. |
84 | | -
85 | | - This method runs in a thread pool to avoid blocking the async event loop. |
86 | | - It collects all chunks and returns them as a list. |
87 | | - """ |
88 | | - chunks = [] |
89 | | - try: |
90 | | - for chunk in iterator: |
91 | | - chunks.append(chunk) |
92 | | - except Exception as e: |
93 | | - # Return error as last element |
94 | | - chunks.append(e) |
95 | | - return chunks |
96 | | - |
97 | 82 | async def send_message(self, *args, **kwargs): |
98 | 83 | """ |
99 | 84 | send_message gives you full support/access to the native Gemini chat send message method |
@@ -125,7 +110,7 @@ async def send_message(self, *args, **kwargs): |
125 | 110 | kwargs["config"] = cfg |
126 | 111 |
127 | 112 | # Generate content using the client |
128 | | - iterator = await self.chat.send_message_stream(*args, **kwargs) |
| 113 | + iterator: AsyncIterator[GenerateContentResponse] = self.chat.send_message_stream(*args, **kwargs) # type: ignore[assignment] |
129 | 114 | text_parts : List[str] = [] |
130 | 115 | final_chunk = None |
131 | 116 | pending_calls: List[NormalizedToolCallItem] = [] |
@@ -174,43 +159,34 @@ async def send_message(self, *args, **kwargs): |
174 | 159 | sanitized_res = {} |
175 | 160 | for k, v in res.items(): |
176 | 161 | sanitized_res[k] = self._sanitize_tool_output(v) |
| 162 | + |
177 | 163 | parts.append( |
178 | 164 | types.Part.from_function_response( |
179 | 165 | name=tc["name"], response=sanitized_res |
180 | 166 | ) |
181 | 167 | ) |
182 | 168 |
183 | | - # Send function responses with tools config - wrap in thread pool |
184 | | - def _get_follow_up_iter(): |
185 | | - return chat.send_message_stream(parts, config=cfg_with_tools) # type: ignore[arg-type] |
186 | | - |
187 | | - follow_up_iter = await asyncio.to_thread(_get_follow_up_iter) |
188 | | - follow_up_chunks = await asyncio.to_thread( |
189 | | - self._iterate_stream_blocking, follow_up_iter |
190 | | - ) |
191 | | - |
192 | | - # Check if last element is an exception |
193 | | - if follow_up_chunks and isinstance(follow_up_chunks[-1], Exception): |
194 | | - raise follow_up_chunks[-1] |
195 | | - |
| 169 | + # Send function responses with tools config |
| 170 | + follow_up_iter: AsyncIterator[GenerateContentResponse] = self.chat.send_message_stream(parts, config=cfg_with_tools) # type: ignore[arg-type,assignment] |
196 | 171 | follow_up_text_parts: List[str] = [] |
197 | 172 | follow_up_last = None |
198 | 173 | next_calls = [] |
199 | | - |
200 | | - for idx, chk in enumerate(follow_up_chunks): |
| 174 | + follow_up_idx = 0 |
| 175 | + |
| 176 | + async for chk in follow_up_iter: |
201 | 177 | follow_up_last = chk |
202 | 178 | # TODO: unclear if this is correct (item_id and idx) |
203 | | - self._standardize_and_emit_event( |
204 | | - chk, follow_up_text_parts, item_id, idx |
205 | | - ) |
| 179 | + self._standardize_and_emit_event(chk, follow_up_text_parts, item_id, follow_up_idx) |
206 | 180 |
207 | 181 | # Check for new function calls |
208 | 182 | try: |
209 | 183 | chunk_calls = self._extract_tool_calls_from_stream_chunk(chk) |
210 | 184 | next_calls.extend(chunk_calls) |
211 | 185 | except Exception: |
212 | 186 | pass |
213 | | - |
| 187 | + |
| 188 | + follow_up_idx += 1 |
| 189 | + |
214 | 190 | current_calls = next_calls |
215 | 191 | rounds += 1 |
216 | 192 |
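
The change above replaces the thread-pool workaround (`_iterate_stream_blocking` driven through `asyncio.to_thread`) with direct `async for` iteration over the streams returned by `send_message_stream`, for both the initial response and the follow-up tool-response round. Below is a minimal, self-contained sketch of that pattern; `fake_stream()` is a hypothetical stand-in for the SDK call, not part of google-genai, since the exact shape of `send_message_stream` (whether it must be awaited before iteration) varies between SDK versions.

```python
# Sketch of the refactor: consume an async stream with `async for` instead of
# draining a blocking iterator inside asyncio.to_thread. fake_stream() is a
# toy stand-in for send_message_stream(); it is not the google-genai API.
import asyncio
from typing import AsyncIterator, List


async def fake_stream() -> AsyncIterator[str]:
    """Yield text chunks asynchronously, like a streaming chat response."""
    for chunk in ("Hel", "lo, ", "world"):
        await asyncio.sleep(0)  # simulate waiting on the network
        yield chunk


async def consume() -> str:
    # Old pattern (removed in this diff): collect every chunk in a worker
    # thread and smuggle any exception out as the last list element.
    # New pattern: iterate natively, so each chunk can be processed as it
    # arrives and exceptions propagate normally.
    text_parts: List[str] = []
    idx = 0
    async for chunk in fake_stream():
        text_parts.append(chunk)  # in the real code: emit an event per chunk
        idx += 1  # mirrors the follow_up_idx counter added above
    return "".join(text_parts)


if __name__ == "__main__":
    print(asyncio.run(consume()))
```

Iterating the stream natively keeps the event loop responsive without buffering the whole response, and it removes the need to encode errors as sentinel list elements.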