
Commit 9291daa

GWeale authored and copybara-github committed
chore: Add warning for using Gemini models via LiteLLM
Recommend using Gemini outside of LiteLLM

PiperOrigin-RevId: 800971705
1 parent fcd748e commit 9291daa
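
The warning added in this commit nudges users toward ADK's native Gemini integration. A minimal sketch of the suggested switch, based on the warning text in the diff below (the import path for the native Gemini class is an assumption; the warning only names the class):

# Before: Gemini routed through LiteLLM; with this commit it emits a UserWarning.
from google.adk.models.lite_llm import LiteLlm

llm = LiteLlm(model="gemini/gemini-2.5-pro")

# After: the native integration the warning recommends.
# NOTE: the import path below is assumed for illustration; the commit only
# references the Gemini class name in the warning message.
from google.adk.models.google_llm import Gemini

llm = Gemini(model="gemini-2.5-pro")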

File tree

src/google/adk/models/lite_llm.py
tests/unittests/models/test_litellm.py

2 files changed: +115 −0 lines changed


src/google/adk/models/lite_llm.py

Lines changed: 66 additions & 0 deletions
@@ -17,6 +17,8 @@
 import base64
 import json
 import logging
+import os
+import re
 from typing import Any
 from typing import AsyncGenerator
 from typing import cast
@@ -29,6 +31,7 @@
 from typing import Tuple
 from typing import TypedDict
 from typing import Union
+import warnings
 
 from google.genai import types
 import litellm
@@ -672,6 +675,67 @@ def _build_request_log(req: LlmRequest) -> str:
   """
 
 
+def _is_litellm_gemini_model(model_string: str) -> bool:
+  """Check if the model is a Gemini model accessed via LiteLLM.
+
+  Args:
+    model_string: A LiteLLM model string (e.g., "gemini/gemini-2.5-pro" or
+      "vertex_ai/gemini-1.5-flash")
+
+  Returns:
+    True if it's a Gemini model accessed via LiteLLM, False otherwise
+  """
+  # Matches "gemini/gemini-*" (Google AI Studio) or "vertex_ai/gemini-*" (Vertex AI).
+  pattern = r"^(gemini|vertex_ai)/gemini-"
+  return bool(re.match(pattern, model_string))
+
+
+def _extract_gemini_model_from_litellm(litellm_model: str) -> str:
+  """Extract the pure Gemini model name from a LiteLLM model string.
+
+  Args:
+    litellm_model: LiteLLM model string like "gemini/gemini-2.5-pro"
+
+  Returns:
+    Pure Gemini model name like "gemini-2.5-pro"
+  """
+  # Remove LiteLLM provider prefix
+  if "/" in litellm_model:
+    return litellm_model.split("/", 1)[1]
+  return litellm_model
+
+
+def _warn_gemini_via_litellm(model_string: str) -> None:
+  """Warn if Gemini is being used via LiteLLM.
+
+  This function logs a warning suggesting users use Gemini directly rather than
+  through LiteLLM for better performance and features.
+
+  Args:
+    model_string: The LiteLLM model string to check
+  """
+  if not _is_litellm_gemini_model(model_string):
+    return
+
+  # Check if warning should be suppressed via environment variable
+  if os.environ.get(
+      "ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", ""
+  ).strip().lower() in ("1", "true", "yes", "on"):
+    return
+
+  warnings.warn(
+      f"[GEMINI_VIA_LITELLM] {model_string}: You are using Gemini via LiteLLM."
+      " For better performance, reliability, and access to latest features,"
+      " consider using Gemini directly through ADK's native Gemini"
+      f" integration. Replace LiteLlm(model='{model_string}') with"
+      f" Gemini(model='{_extract_gemini_model_from_litellm(model_string)}')."
+      " Set ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS=true to suppress this"
+      " warning.",
+      category=UserWarning,
+      stacklevel=3,
+  )
+
+
 class LiteLlm(BaseLlm):
   """Wrapper around litellm.
 
@@ -708,6 +772,8 @@ def __init__(self, model: str, **kwargs):
       **kwargs: Additional arguments to pass to the litellm completion api.
     """
     super().__init__(model=model, **kwargs)
+    # Warn if using Gemini via LiteLLM
+    _warn_gemini_via_litellm(model)
     self._additional_args = kwargs
     # preventing generation call with llm_client
     # and overriding messages, tools and stream which are managed internally
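
If routing Gemini through LiteLLM is intentional, the check above honors a suppression flag. A minimal sketch of opting out via the environment variable the new code reads (accepted values per the diff: "1", "true", "yes", "on"):

import os

# Set before constructing the model so _warn_gemini_via_litellm sees it.
os.environ["ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS"] = "true"

from google.adk.models.lite_llm import LiteLlm

llm = LiteLlm(model="gemini/gemini-2.5-pro")  # no UserWarning emitted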

tests/unittests/models/test_litellm.py

Lines changed: 49 additions & 0 deletions
@@ -16,6 +16,7 @@
 import json
 from unittest.mock import AsyncMock
 from unittest.mock import Mock
+import warnings
 
 from google.adk.models.lite_llm import _content_to_message_param
 from google.adk.models.lite_llm import _function_declaration_to_tool_param
@@ -1574,3 +1575,51 @@ def test_get_completion_inputs_generation_params():
   # Should not include max_output_tokens
   assert "max_output_tokens" not in generation_params
   assert "stop_sequences" not in generation_params
+
+
+def test_gemini_via_litellm_warning(monkeypatch):
+  """Test that Gemini via LiteLLM shows warning."""
+  # Ensure environment variable is not set
+  monkeypatch.delenv("ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", raising=False)
+  with warnings.catch_warnings(record=True) as w:
+    warnings.simplefilter("always")
+    # Test with Google AI Studio Gemini via LiteLLM
+    LiteLlm(model="gemini/gemini-2.5-pro-exp-03-25")
+    assert len(w) == 1
+    assert issubclass(w[0].category, UserWarning)
+    assert "[GEMINI_VIA_LITELLM]" in str(w[0].message)
+    assert "better performance" in str(w[0].message)
+    assert "gemini-2.5-pro-exp-03-25" in str(w[0].message)
+    assert "ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS" in str(w[0].message)
+
+
+def test_gemini_via_litellm_warning_vertex_ai(monkeypatch):
+  """Test that Vertex AI Gemini via LiteLLM shows warning."""
+  # Ensure environment variable is not set
+  monkeypatch.delenv("ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", raising=False)
+  with warnings.catch_warnings(record=True) as w:
+    warnings.simplefilter("always")
+    # Test with Vertex AI Gemini via LiteLLM
+    LiteLlm(model="vertex_ai/gemini-1.5-flash")
+    assert len(w) == 1
+    assert issubclass(w[0].category, UserWarning)
+    assert "[GEMINI_VIA_LITELLM]" in str(w[0].message)
+    assert "vertex_ai/gemini-1.5-flash" in str(w[0].message)
+
+
+def test_gemini_via_litellm_warning_suppressed(monkeypatch):
+  """Test that Gemini via LiteLLM warning can be suppressed."""
+  monkeypatch.setenv("ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", "true")
+  with warnings.catch_warnings(record=True) as w:
+    warnings.simplefilter("always")
+    LiteLlm(model="gemini/gemini-2.5-pro-exp-03-25")
+    assert len(w) == 0
+
+
+def test_non_gemini_litellm_no_warning():
+  """Test that non-Gemini models via LiteLLM don't show warning."""
+  with warnings.catch_warnings(record=True) as w:
+    warnings.simplefilter("always")
+    # Test with non-Gemini model
+    LiteLlm(model="openai/gpt-4o")
+    assert len(w) == 0
