Skip to content

Commit 02889ff

Browse files
committed
fix(llma): tool calls in streaming Gemini
1 parent 1ab9e9d commit 02889ff

File tree

2 files changed

+83
-17
lines changed

2 files changed

+83
-17
lines changed

posthog/ai/gemini/gemini.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -309,11 +309,11 @@ def generator():
309309
if chunk_usage:
310310
usage_stats.update(chunk_usage)
311311

312-
# Extract content from chunk
313-
content = extract_gemini_content_from_chunk(chunk)
312+
# Extract content from chunk (now returns content blocks)
313+
content_block = extract_gemini_content_from_chunk(chunk)
314314

315-
if content is not None:
316-
accumulated_content.append(content)
315+
if content_block is not None:
316+
accumulated_content.append(content_block)
317317

318318
yield chunk
319319

@@ -349,7 +349,7 @@ def _capture_streaming_event(
349349
kwargs: Dict[str, Any],
350350
usage_stats: Dict[str, int],
351351
latency: float,
352-
output: str,
352+
output: Any,
353353
):
354354
from posthog.ai.types import StreamingEventData
355355
from posthog.ai.gemini.gemini_converter import standardize_gemini_usage

posthog/ai/gemini/gemini_converter.py

Lines changed: 78 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -307,44 +307,110 @@ def extract_gemini_usage_from_chunk(chunk: Any) -> StreamingUsageStats:
307307
return usage
308308

309309

310-
def extract_gemini_content_from_chunk(chunk: Any) -> Optional[Dict[str, Any]]:
    """
    Extract content (text or function call) from a Gemini streaming chunk.

    Args:
        chunk: Streaming chunk from the Gemini API.

    Returns:
        A content-block dictionary if present, ``None`` otherwise.
        Text becomes ``{"type": "text", "text": ...}``; a function call
        becomes ``{"type": "function", "function": {"name": ..., "arguments": ...}}``.

    NOTE(review): only the FIRST matching block per chunk is returned — a
    chunk carrying both text and a function call yields a single block.
    Confirm the Gemini streaming API never mixes both in one chunk.
    """
    # Fast path: the SDK exposes aggregated text directly on the chunk.
    if getattr(chunk, "text", None):
        return {"type": "text", "text": chunk.text}

    # Otherwise walk candidates -> content -> parts looking for a function
    # call, or text that only exists on an individual part.
    for candidate in getattr(chunk, "candidates", None) or []:
        candidate_content = getattr(candidate, "content", None)
        for part in getattr(candidate_content, "parts", None) or []:
            function_call = getattr(part, "function_call", None)
            if function_call:
                return {
                    "type": "function",
                    "function": {
                        "name": function_call.name,
                        "arguments": function_call.args,
                    },
                }
            # Text present only on the part (no chunk-level .text).
            if getattr(part, "text", None):
                return {"type": "text", "text": part.text}

    return None
325346

326347

327348
def format_gemini_streaming_output(
    accumulated_content: Union[str, List[Any]],
) -> List[FormattedMessage]:
    """
    Format the final output from Gemini streaming.

    Consecutive text fragments are merged into a single text block;
    function-call blocks are kept in order and close the current text run.

    Args:
        accumulated_content: Accumulated content from streaming — a plain
            string (legacy), a list of strings (legacy), or a list of
            content-block dicts (``{"type": "text"|"function", ...}``).

    Returns:
        List of formatted assistant messages (always exactly one message).
    """
    # Legacy string input (backward compatibility).
    if isinstance(accumulated_content, str):
        return [
            {
                "role": "assistant",
                "content": [{"type": "text", "text": accumulated_content}],
            }
        ]

    if isinstance(accumulated_content, list):
        content: List[FormattedContentItem] = []
        text_parts: List[str] = []

        def _flush_text() -> None:
            # Merge any buffered text fragments into one text block.
            if text_parts:
                text_content: FormattedTextContent = {
                    "type": "text",
                    "text": "".join(text_parts),
                }
                content.append(text_content)
                text_parts.clear()

        for item in accumulated_content:
            if isinstance(item, str):
                # Legacy support: accumulate raw strings.
                text_parts.append(item)
            elif isinstance(item, dict):
                # New format: content blocks.
                if item.get("type") == "text":
                    text_parts.append(item.get("text", ""))
                elif item.get("type") == "function":
                    # A function call terminates the current text run.
                    _flush_text()
                    func_content: FormattedFunctionCall = {
                        "type": "function",
                        "function": item.get("function", {}),
                    }
                    content.append(func_content)

        # Emit any trailing text after the last function call.
        _flush_text()

        if content:
            return [{"role": "assistant", "content": content}]

    # Fallback for empty or unexpected input.
    return [{"role": "assistant", "content": [{"type": "text", "text": ""}]}]
348414

349415

350416
def standardize_gemini_usage(usage: Dict[str, Any]) -> TokenUsage:

0 commit comments

Comments
 (0)