
Commit 1a123b2

Litellm gemini cli bug fix (#14451)

* Fix gemini cli error
* Add reasoning request support
* Added better handling
* Remove other PR code
* Refactored code for better structure following

---------

Co-authored-by: [email protected] <[email protected]>

1 parent 8b338a4 · commit 1a123b2

6 files changed: +431, -13 lines

litellm/google_genai/main.py

Lines changed: 12 additions & 0 deletions
@@ -224,6 +224,9 @@ async def agenerate_content(
     loop = asyncio.get_event_loop()
     kwargs["agenerate_content"] = True

+    # Handle generationConfig parameter from kwargs for backward compatibility
+    if "generationConfig" in kwargs and config is None:
+        config = kwargs.pop("generationConfig")
     # get custom llm provider so we can use this for mapping exceptions
     if custom_llm_provider is None:
         _, custom_llm_provider, _, _ = litellm.get_llm_provider(

@@ -288,6 +291,9 @@ def generate_content(
     try:
         _is_async = kwargs.pop("agenerate_content", False) is True

+        # Handle generationConfig parameter from kwargs for backward compatibility
+        if "generationConfig" in kwargs and config is None:
+            config = kwargs.pop("generationConfig")
         # Check for mock response first
         litellm_params = GenericLiteLLMParams(**kwargs)
         if litellm_params.mock_response and isinstance(

@@ -374,6 +380,9 @@ async def agenerate_content_stream(
     try:
         kwargs["agenerate_content_stream"] = True

+        # Handle generationConfig parameter from kwargs for backward compatibility
+        if "generationConfig" in kwargs and config is None:
+            config = kwargs.pop("generationConfig")
         # get custom llm provider so we can use this for mapping exceptions
         if custom_llm_provider is None:
             _, custom_llm_provider, _, _ = litellm.get_llm_provider(

@@ -461,6 +470,9 @@ def generate_content_stream(
     # Remove any async-related flags since this is the sync function
     _is_async = kwargs.pop("agenerate_content_stream", False)

+    # Handle generationConfig parameter from kwargs for backward compatibility
+    if "generationConfig" in kwargs and config is None:
+        config = kwargs.pop("generationConfig")
     # Setup the call
     setup_result = GenerateContentHelper.setup_generate_content_call(
         model=model,
litellm/llms/gemini/count_tokens/handler.py

Lines changed: 0 additions & 1 deletion

@@ -136,4 +136,3 @@ async def acount_tokens(
     except Exception as e:
         error_msg = f"Unexpected error during token counting: {str(e)}"
         raise Exception(error_msg) from e
-
Lines changed: 67 additions & 6 deletions

@@ -1,7 +1,7 @@
 """
 Transformation for Calling Google models in their native format.
 """
-from typing import Literal, Optional, Union
+from typing import Dict, Literal, Optional, Union

 from litellm.llms.gemini.google_genai.transformation import GoogleGenAIConfig
 from litellm.types.router import GenericLiteLLMParams

@@ -11,20 +11,20 @@ class VertexAIGoogleGenAIConfig(GoogleGenAIConfig):
     """
     Configuration for calling Google models in their native format.
     """
+
     HEADER_NAME = "Authorization"
     BEARER_PREFIX = "Bearer"
-
+
     @property
     def custom_llm_provider(self) -> Literal["gemini", "vertex_ai"]:
         return "vertex_ai"
-

     def validate_environment(
-        self,
+        self,
         api_key: Optional[str],
         headers: Optional[dict],
         model: str,
-        litellm_params: Optional[Union[GenericLiteLLMParams, dict]]
+        litellm_params: Optional[Union[GenericLiteLLMParams, dict]],
     ) -> dict:
         default_headers = {
             "Content-Type": "application/json",

@@ -36,4 +36,65 @@ def validate_environment(
             default_headers.update(headers)

         return default_headers
-
+
+    def _camel_to_snake(self, camel_str: str) -> str:
+        """Convert camelCase to snake_case"""
+        import re
+
+        return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_str).lower()
+
+    def map_generate_content_optional_params(
+        self,
+        generate_content_config_dict,
+        model: str,
+    ):
+        """
+        Map Google GenAI parameters to provider-specific format.
+
+        Args:
+            generate_content_optional_params: Optional parameters for generate content
+            model: The model name
+
+        Returns:
+            Mapped parameters for the provider
+        """
+        from litellm.types.google_genai.main import GenerateContentConfigDict
+
+        _generate_content_config_dict = GenerateContentConfigDict()
+
+        for param, value in generate_content_config_dict.items():
+            camel_case_key = self._camel_to_snake(param)
+            _generate_content_config_dict[camel_case_key] = value
+        return dict(_generate_content_config_dict)
+
+    def transform_generate_content_request(
+        self,
+        model: str,
+        contents: any,
+        tools: Optional[any],
+        generate_content_config_dict: Dict,
+        system_instruction: Optional[any] = None,
+    ) -> dict:
+        """
+        Transform the generate content request for Vertex AI.
+        Since Vertex AI natively supports Google GenAI format, we can pass most fields directly.
+        """
+        # Build the request in Google GenAI format that Vertex AI expects
+        result = {
+            "model": model,
+            "contents": contents,
+        }
+
+        # Add tools if provided
+        if tools:
+            result["tools"] = tools
+
+        # Add systemInstruction if provided
+        if system_instruction:
+            result["systemInstruction"] = system_instruction
+
+        # Handle generationConfig - Vertex AI expects it in the same format
+        if generate_content_config_dict:
+            result["generationConfig"] = generate_content_config_dict
+
+        return result
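
To make the new VertexAIGoogleGenAIConfig helpers concrete, here is a small self-contained sketch: the same regex used by _camel_to_snake applied to a few generationConfig keys, plus the request shape that transform_generate_content_request assembles. The model name, message contents, and config values are made-up sample values, not taken from the diff.

import re


def camel_to_snake(camel_str: str) -> str:
    # Insert "_" before each uppercase letter that is not at the start, then lowercase
    return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_str).lower()


# camelCase keys are normalized to snake_case (what map_generate_content_optional_params does per key)
config = {"maxOutputTokens": 256, "topP": 0.9, "responseMimeType": "application/json"}
print({camel_to_snake(k): v for k, v in config.items()})
# {'max_output_tokens': 256, 'top_p': 0.9, 'response_mime_type': 'application/json'}

# transform_generate_content_request builds a request body of this shape (sample values):
request = {
    "model": "gemini-1.5-pro",
    "contents": [{"role": "user", "parts": [{"text": "hello"}]}],
    "systemInstruction": {"parts": [{"text": "Be brief."}]},
    "generationConfig": {"temperature": 0.2},
}
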
litellm/proxy/route_llm_request.py

Lines changed: 1 addition & 5 deletions

@@ -85,6 +85,7 @@ async def route_request(
     """
     team_id = get_team_id_from_data(data)
     router_model_names = llm_router.model_names if llm_router is not None else []
+
     if "api_key" in data or "api_base" in data:
         if llm_router is not None:
             return getattr(llm_router, f"{route_type}")(**data)

@@ -123,24 +124,20 @@
             data["model"] in router_model_names
             or data["model"] in llm_router.get_model_ids()
         ):
-
             return getattr(llm_router, f"{route_type}")(**data)

         elif (
             llm_router.model_group_alias is not None
             and data["model"] in llm_router.model_group_alias
         ):
-
             return getattr(llm_router, f"{route_type}")(**data)

         elif data["model"] in llm_router.deployment_names:
-
             return getattr(llm_router, f"{route_type}")(
                 **data, specific_deployment=True
             )

         elif data["model"] not in router_model_names:
-
             if llm_router.router_general_settings.pass_through_all_models:
                 return getattr(litellm, f"{route_type}")(**data)
             elif (

@@ -162,7 +159,6 @@
     elif user_model is not None:
         return getattr(litellm, f"{route_type}")(**data)
     elif route_type == "allm_passthrough_route":
-
         return getattr(litellm, f"{route_type}")(**data)

     # if no route found then it's a bad request