Skip to content

Commit 4232cc4

Browse files
authored
Genai docs refactor & fixes (#22175)
* Improve GenAI docs
* Clarify
* Fix config updating
* Implement streaming for other providers
* Set openai base url if applied
* Cast context size
1 parent 6a21b29 commit 4232cc4

File tree

7 files changed

+610
-94
lines changed

7 files changed

+610
-94
lines changed

docs/docs/configuration/genai/config.md

Lines changed: 103 additions & 86 deletions
Large diffs are not rendered by default.

frigate/genai/azure-openai.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,123 @@ def chat_with_tools(
167167
"tool_calls": None,
168168
"finish_reason": "error",
169169
}
170+
171+
async def chat_with_tools_stream(
    self,
    messages: list[dict[str, Any]],
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
):
    """
    Stream chat with tools; yields content deltas then final message.

    Yields ("content_delta", str) tuples as text arrives, followed by one
    ("message", dict) tuple with keys "content", "tool_calls", and
    "finish_reason". On any error a ("message", ...) tuple with
    finish_reason "error" is yielded instead of raising.

    Implements streaming function calling/tool usage for Azure OpenAI models.
    """
    try:
        # Only the literals the OpenAI API accepts are forwarded; any other
        # value leaves tool_choice unset so the request omits it.
        openai_tool_choice = None
        if tool_choice in ("none", "auto", "required"):
            openai_tool_choice = tool_choice

        request_params = {
            "model": self.genai_config.model,
            "messages": messages,
            "timeout": self.timeout,
            "stream": True,
        }

        if tools:
            request_params["tools"] = tools
        if openai_tool_choice is not None:
            request_params["tool_choice"] = openai_tool_choice

        # Accumulators for the final "message" event.
        content_parts: list[str] = []
        tool_calls_by_index: dict[int, dict[str, Any]] = {}
        finish_reason = "stop"

        stream = self.provider.chat.completions.create(**request_params)

        # NOTE(review): the stream is iterated synchronously inside an async
        # generator, so the event loop is blocked between chunks — confirm
        # whether the async client should be used here.
        for chunk in stream:
            if not chunk or not chunk.choices:
                continue

            choice = chunk.choices[0]
            delta = choice.delta

            # The last chunk carries the model's finish reason.
            if choice.finish_reason:
                finish_reason = choice.finish_reason

            # Surface text deltas immediately while accumulating the reply.
            if delta.content:
                content_parts.append(delta.content)
                yield ("content_delta", delta.content)

            # Tool-call fragments arrive incrementally, keyed by index;
            # merge them into one entry per call, concatenating the
            # streamed argument text.
            if delta.tool_calls:
                for tc in delta.tool_calls:
                    fn = tc.function
                    entry = tool_calls_by_index.setdefault(
                        tc.index,
                        {
                            "id": tc.id or "",
                            "name": fn.name if fn and fn.name else "",
                            "arguments": "",
                        },
                    )
                    if tc.id:
                        entry["id"] = tc.id
                    if fn and fn.name:
                        entry["name"] = fn.name
                    if fn and fn.arguments:
                        entry["arguments"] += fn.arguments

        # Build the final message.
        full_content = "".join(content_parts).strip() or None

        # Convert accumulated tool calls to list format.
        tool_calls_list = None
        if tool_calls_by_index:
            tool_calls_list = []
            for call in tool_calls_by_index.values():
                try:
                    # Arguments stream in as JSON text fragments; parse the
                    # concatenation once the stream is complete.
                    parsed_args = json.loads(call["arguments"])
                except (json.JSONDecodeError, TypeError):
                    # Fall back to the raw accumulated value when it is not
                    # valid JSON text.
                    parsed_args = call["arguments"]

                tool_calls_list.append(
                    {
                        "id": call["id"],
                        "name": call["name"],
                        "arguments": parsed_args,
                    }
                )
            finish_reason = "tool_calls"

        yield (
            "message",
            {
                "content": full_content,
                "tool_calls": tool_calls_list,
                "finish_reason": finish_reason,
            },
        )

    except Exception as e:
        logger.warning("Azure OpenAI streaming returned an error: %s", str(e))
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )

frigate/genai/gemini.py

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Gemini Provider for Frigate AI."""
22

3+
import json
34
import logging
45
from typing import Any, Optional
56

@@ -273,3 +274,239 @@ def chat_with_tools(
273274
"tool_calls": None,
274275
"finish_reason": "error",
275276
}
277+
278+
async def chat_with_tools_stream(
    self,
    messages: list[dict[str, Any]],
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
):
    """
    Stream chat with tools; yields content deltas then final message.

    Yields ("content_delta", str) tuples as text arrives, followed by one
    ("message", dict) tuple with keys "content", "tool_calls", and
    "finish_reason". On any error a ("message", ...) tuple with
    finish_reason "error" is yielded instead of raising.

    Implements streaming function calling/tool usage for Gemini models.
    """
    try:
        # Convert OpenAI-style messages to Gemini Content objects.
        gemini_messages = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                # Gemini has no system role; fold the system text into the
                # first user message when one already exists.
                if gemini_messages and gemini_messages[0].role == "user":
                    gemini_messages[0].parts[
                        0
                    ].text = f"{content}\n\n{gemini_messages[0].parts[0].text}"
                else:
                    gemini_messages.append(
                        types.Content(
                            role="user", parts=[types.Part.from_text(text=content)]
                        )
                    )
            elif role == "assistant":
                gemini_messages.append(
                    types.Content(
                        role="model", parts=[types.Part.from_text(text=content)]
                    )
                )
            elif role == "tool":
                # Tool results are sent back as function responses.
                function_response = {
                    "name": msg.get("name", ""),
                    "response": content,
                }
                gemini_messages.append(
                    types.Content(
                        role="function",
                        parts=[
                            types.Part.from_function_response(function_response)
                        ],
                    )
                )
            else:  # user
                gemini_messages.append(
                    types.Content(
                        role="user", parts=[types.Part.from_text(text=content)]
                    )
                )

        # Convert OpenAI-style tool specs to Gemini declarations.
        gemini_tools = None
        if tools:
            gemini_tools = []
            for tool in tools:
                if tool.get("type") == "function":
                    func = tool.get("function", {})
                    gemini_tools.append(
                        types.Tool(
                            function_declarations=[
                                types.FunctionDeclaration(
                                    name=func.get("name", ""),
                                    description=func.get("description", ""),
                                    parameters=func.get("parameters", {}),
                                )
                            ]
                        )
                    )

        # Map the generic tool_choice onto Gemini's function-calling modes
        # (none -> NONE, auto -> AUTO, required -> ANY).
        tool_config = None
        if tool_choice == "none":
            tool_config = types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(mode="NONE")
            )
        elif tool_choice == "auto":
            tool_config = types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(mode="AUTO")
            )
        elif tool_choice == "required":
            tool_config = types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(mode="ANY")
            )

        # Build request config.
        config_params: dict[str, Any] = {"candidate_count": 1}

        if gemini_tools:
            config_params["tools"] = gemini_tools

        if tool_config:
            config_params["tool_config"] = tool_config

        # Merge user-supplied runtime options last so they win.
        if isinstance(self.genai_config.runtime_options, dict):
            config_params.update(self.genai_config.runtime_options)

        # Accumulators for the final "message" event.
        content_parts: list[str] = []
        tool_calls_by_index: dict[int, dict[str, Any]] = {}
        finish_reason = "stop"

        response = self.provider.models.generate_content_stream(
            model=self.genai_config.model,
            contents=gemini_messages,
            config=types.GenerateContentConfig(**config_params),
        )

        # NOTE(review): `async for` assumes generate_content_stream returns
        # an async iterator; the sync client's iterator would raise TypeError
        # here (caught below as an error message) — confirm against how
        # self.provider is initialized.
        async for chunk in response:
            if not chunk or not chunk.candidates:
                continue

            candidate = chunk.candidates[0]

            # Translate Gemini finish reasons to the generic vocabulary.
            if hasattr(candidate, "finish_reason") and candidate.finish_reason:
                if candidate.finish_reason == types.FinishReason.STOP:
                    finish_reason = "stop"
                elif candidate.finish_reason == types.FinishReason.MAX_TOKENS:
                    finish_reason = "length"
                elif candidate.finish_reason in (
                    types.FinishReason.SAFETY,
                    types.FinishReason.RECITATION,
                ):
                    finish_reason = "error"

            # Extract content and tool calls from the chunk.
            if candidate.content and candidate.content.parts:
                for part in candidate.content.parts:
                    if part.text:
                        content_parts.append(part.text)
                        yield ("content_delta", part.text)
                    elif part.function_call:
                        try:
                            arguments = (
                                dict(part.function_call.args)
                                if part.function_call.args
                                else {}
                            )
                        except Exception:
                            arguments = {}

                        tool_call_name = part.function_call.name or ""
                        # Gemini supplies no call id; the name doubles as id.
                        tool_call_id = tool_call_name

                        # Merge chunks belonging to the same call.
                        # NOTE(review): matching by name collapses two distinct
                        # calls to the same function into one entry — confirm
                        # that is acceptable for this provider.
                        found_index = None
                        for idx, entry in tool_calls_by_index.items():
                            if entry["name"] == tool_call_name:
                                found_index = idx
                                break

                        if found_index is None:
                            found_index = len(tool_calls_by_index)
                            tool_calls_by_index[found_index] = {
                                "id": tool_call_id,
                                "name": tool_call_name,
                                "arguments": "",
                            }

                        # Accumulate arguments as JSON text so the final
                        # parse below mirrors the other providers.
                        if arguments:
                            tool_calls_by_index[found_index]["arguments"] += (
                                json.dumps(arguments)
                                if isinstance(arguments, dict)
                                else str(arguments)
                            )

        # Build the final message.
        full_content = "".join(content_parts).strip() or None

        # Convert accumulated tool calls to list format.
        tool_calls_list = None
        if tool_calls_by_index:
            tool_calls_list = []
            for entry in tool_calls_by_index.values():
                try:
                    # Parse accumulated arguments as JSON.
                    parsed_args = json.loads(entry["arguments"])
                except (json.JSONDecodeError, TypeError):
                    # Fall back to the raw accumulated value when it is not
                    # valid JSON text.
                    parsed_args = entry["arguments"]

                tool_calls_list.append(
                    {
                        "id": entry["id"],
                        "name": entry["name"],
                        "arguments": parsed_args,
                    }
                )
            finish_reason = "tool_calls"

        yield (
            "message",
            {
                "content": full_content,
                "tool_calls": tool_calls_list,
                "finish_reason": finish_reason,
            },
        )

    except errors.APIError as e:
        logger.warning("Gemini API error during streaming: %s", str(e))
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )
    except Exception as e:
        logger.warning(
            "Gemini returned an error during chat_with_tools_stream: %s", str(e)
        )
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )

frigate/genai/llama_cpp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _send(self, prompt: str, images: list[bytes]) -> Optional[str]:
102102

103103
def get_context_size(self) -> int:
    """Get the context window size for llama.cpp.

    Reads "context_size" from the provider options, defaulting to 4096,
    and casts it to int in case the configured value is not already an
    integer.
    """
    configured = self.provider_options.get("context_size", 4096)
    return int(configured)
106106

107107
def _build_payload(
108108
self,

0 commit comments

Comments (0)