Skip to content

Commit 540409c

Browse files
committed
fix: update type hints for async message streaming in GeminiLLM
1 parent 3c5377c commit 540409c

File tree

1 file changed

+12
-36
lines changed

1 file changed

+12
-36
lines changed

plugins/gemini/vision_agents/plugins/gemini/gemini_llm.py

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import uuid
3-
from typing import Optional, List, TYPE_CHECKING, Any, Dict
3+
from typing import Optional, List, TYPE_CHECKING, Any, Dict, AsyncIterator
44

55
from google.genai.client import AsyncClient, Client
66
from google.genai import types
@@ -79,21 +79,6 @@ async def simple_response(
7979
"""
8080
return await self.send_message(message=text)
8181

82-
def _iterate_stream_blocking(self, iterator):
83-
"""Helper method to iterate over a blocking stream iterator.
84-
85-
This method runs in a thread pool to avoid blocking the async event loop.
86-
It collects all chunks and returns them as a list.
87-
"""
88-
chunks = []
89-
try:
90-
for chunk in iterator:
91-
chunks.append(chunk)
92-
except Exception as e:
93-
# Return error as last element
94-
chunks.append(e)
95-
return chunks
96-
9782
async def send_message(self, *args, **kwargs):
9883
"""
9984
send_message gives you full support/access to the native Gemini chat send message method
@@ -125,7 +110,7 @@ async def send_message(self, *args, **kwargs):
125110
kwargs["config"] = cfg
126111

127112
# Generate content using the client
128-
iterator = await self.chat.send_message_stream(*args, **kwargs)
113+
iterator: AsyncIterator[GenerateContentResponse] = self.chat.send_message_stream(*args, **kwargs) # type: ignore[assignment]
129114
text_parts : List[str] = []
130115
final_chunk = None
131116
pending_calls: List[NormalizedToolCallItem] = []
@@ -174,43 +159,34 @@ async def send_message(self, *args, **kwargs):
174159
sanitized_res = {}
175160
for k, v in res.items():
176161
sanitized_res[k] = self._sanitize_tool_output(v)
162+
177163
parts.append(
178164
types.Part.from_function_response(
179165
name=tc["name"], response=sanitized_res
180166
)
181167
)
182168

183-
# Send function responses with tools config - wrap in thread pool
184-
def _get_follow_up_iter():
185-
return chat.send_message_stream(parts, config=cfg_with_tools) # type: ignore[arg-type]
186-
187-
follow_up_iter = await asyncio.to_thread(_get_follow_up_iter)
188-
follow_up_chunks = await asyncio.to_thread(
189-
self._iterate_stream_blocking, follow_up_iter
190-
)
191-
192-
# Check if last element is an exception
193-
if follow_up_chunks and isinstance(follow_up_chunks[-1], Exception):
194-
raise follow_up_chunks[-1]
195-
169+
# Send function responses with tools config
170+
follow_up_iter: AsyncIterator[GenerateContentResponse] = self.chat.send_message_stream(parts, config=cfg_with_tools) # type: ignore[arg-type,assignment]
196171
follow_up_text_parts: List[str] = []
197172
follow_up_last = None
198173
next_calls = []
199-
200-
for idx, chk in enumerate(follow_up_chunks):
174+
follow_up_idx = 0
175+
176+
async for chk in follow_up_iter:
201177
follow_up_last = chk
202178
# TODO: unclear if this is correct (item_id and idx)
203-
self._standardize_and_emit_event(
204-
chk, follow_up_text_parts, item_id, idx
205-
)
179+
self._standardize_and_emit_event(chk, follow_up_text_parts, item_id, follow_up_idx)
206180

207181
# Check for new function calls
208182
try:
209183
chunk_calls = self._extract_tool_calls_from_stream_chunk(chk)
210184
next_calls.extend(chunk_calls)
211185
except Exception:
212186
pass
213-
187+
188+
follow_up_idx += 1
189+
214190
current_calls = next_calls
215191
rounds += 1
216192

0 commit comments

Comments (0)