Skip to content

Commit 635dc72

Browse files
Merge pull request #14604 from Sameerlite/litellm_gemini_api_base_update
Litellm gemini api base update
2 parents 701b4ff + 5b896d2 commit 635dc72

File tree

3 files changed

+315
-9
lines changed

3 files changed

+315
-9
lines changed

litellm/llms/vertex_ai/vertex_llm_base.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def get_complete_vertex_url(
239239
stream=stream,
240240
auth_header=None,
241241
url=default_api_base,
242+
model=model,
242243
)
243244
return api_base
244245

@@ -292,6 +293,7 @@ def _check_custom_proxy(
292293
stream: Optional[bool],
293294
auth_header: Optional[str],
294295
url: str,
296+
model: Optional[str] = None,
295297
) -> Tuple[Optional[str], str]:
296298
"""
297299
for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
@@ -301,7 +303,12 @@ def _check_custom_proxy(
301303
"""
302304
if api_base:
303305
if custom_llm_provider == "gemini":
304-
url = "{}:{}".format(api_base, endpoint)
306+
# For Gemini (Google AI Studio), construct the full path like other providers
307+
if model is None:
308+
raise ValueError(
309+
"Model parameter is required for Gemini custom API base URLs"
310+
)
311+
url = "{}/models/{}:{}".format(api_base, model, endpoint)
305312
if gemini_api_key is None:
306313
raise ValueError(
307314
"Missing gemini_api_key, please set `GEMINI_API_KEY`"
@@ -373,6 +380,7 @@ def _get_token_and_url(
373380
endpoint=endpoint,
374381
stream=stream,
375382
url=url,
383+
model=model,
376384
)
377385

378386
def _handle_reauthentication(
@@ -384,31 +392,31 @@ def _handle_reauthentication(
384392
) -> Tuple[str, str]:
385393
"""
386394
Handle reauthentication when credentials refresh fails.
387-
395+
388396
This method clears the cached credentials and attempts to reload them once.
389397
It should only be called when "Reauthentication is needed" error occurs.
390-
398+
391399
Args:
392400
credentials: The original credentials
393401
project_id: The project ID
394402
credential_cache_key: The cache key to clear
395403
error: The original error that triggered reauthentication
396-
404+
397405
Returns:
398406
Tuple of (access_token, project_id)
399-
407+
400408
Raises:
401409
The original error if reauthentication fails
402410
"""
403411
verbose_logger.debug(
404412
f"Handling reauthentication for project_id: {project_id}. "
405413
f"Clearing cache and retrying once."
406414
)
407-
415+
408416
# Clear the cached credentials
409417
if credential_cache_key in self._credentials_project_mapping:
410418
del self._credentials_project_mapping[credential_cache_key]
411-
419+
412420
# Retry once with _retry_reauth=True to prevent infinite recursion
413421
try:
414422
return self.get_access_token(
@@ -438,12 +446,12 @@ def get_access_token(
438446
3. Check if loaded credentials have expired
439447
4. If expired, refresh credentials
440448
5. Return access token and project id
441-
449+
442450
Args:
443451
credentials: The credentials to use for authentication
444452
project_id: The Google Cloud project ID
445453
_retry_reauth: Internal flag to prevent infinite recursion during reauthentication
446-
454+
447455
Returns:
448456
Tuple of (access_token, project_id)
449457
"""

tests/proxy_unit_tests/test_google_gemini_proxy_request.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,130 @@ async def test_generationconfig_to_config_mapping(sample_request_payload):
345345
print("✅ generationConfig to config mapping test passed")
346346

347347

348+
@pytest.mark.asyncio
349+
async def test_gemini_custom_api_base_proxy_integration():
350+
"""
351+
Test that Gemini models work correctly with custom API base URLs in proxy context.
352+
353+
This test verifies that when a custom api_base is provided for Gemini models,
354+
the URL is correctly constructed using the _check_custom_proxy method.
355+
"""
356+
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
357+
358+
# Test the _check_custom_proxy method directly
359+
vertex_base = VertexBase()
360+
361+
# Test case 1: Custom API base for Gemini
362+
custom_api_base = "https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta"
363+
model = "gemini-2.5-flash-lite"
364+
endpoint = "generateContent"
365+
366+
auth_header, result_url = vertex_base._check_custom_proxy(
367+
api_base=custom_api_base,
368+
custom_llm_provider="gemini",
369+
gemini_api_key="test-api-key",
370+
endpoint=endpoint,
371+
stream=False,
372+
auth_header=None,
373+
url=f"https://generativelanguage.googleapis.com/v1beta/models/{model}:{endpoint}",
374+
model=model,
375+
)
376+
377+
# Verify the URL is correctly constructed
378+
expected_url = f"{custom_api_base}/models/{model}:{endpoint}"
379+
assert result_url == expected_url, f"Expected {expected_url}, got {result_url}"
380+
381+
# Verify the auth header is set to the API key
382+
assert auth_header == "test-api-key", f"Expected 'test-api-key', got {auth_header}"
383+
384+
print(f"✅ Custom API base URL construction test passed: {result_url}")
385+
386+
# Test case 2: Custom API base with streaming
387+
auth_header_streaming, result_url_streaming = vertex_base._check_custom_proxy(
388+
api_base=custom_api_base,
389+
custom_llm_provider="gemini",
390+
gemini_api_key="test-api-key",
391+
endpoint=endpoint,
392+
stream=True,
393+
auth_header=None,
394+
url=f"https://generativelanguage.googleapis.com/v1beta/models/{model}:{endpoint}",
395+
model=model,
396+
)
397+
398+
# Verify streaming URL has ?alt=sse parameter
399+
expected_streaming_url = f"{custom_api_base}/models/{model}:{endpoint}?alt=sse"
400+
assert result_url_streaming == expected_streaming_url, f"Expected {expected_streaming_url}, got {result_url_streaming}"
401+
402+
print(f"✅ Custom API base streaming URL test passed: {result_url_streaming}")
403+
404+
# Test case 3: Error handling - missing API key
405+
with pytest.raises(ValueError, match="Missing gemini_api_key"):
406+
vertex_base._check_custom_proxy(
407+
api_base=custom_api_base,
408+
custom_llm_provider="gemini",
409+
gemini_api_key=None, # Missing API key
410+
endpoint=endpoint,
411+
stream=False,
412+
auth_header=None,
413+
url=f"https://generativelanguage.googleapis.com/v1beta/models/{model}:{endpoint}",
414+
model=model,
415+
)
416+
417+
print("✅ Missing API key error handling test passed")
418+
419+
420+
@pytest.mark.asyncio
421+
async def test_gemini_proxy_config_with_custom_api_base():
422+
"""
423+
Test that proxy configuration correctly handles custom API base for Gemini models.
424+
425+
This test simulates the proxy configuration scenario where a model is configured
426+
with a custom api_base in the config.yaml file.
427+
"""
428+
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
429+
430+
# Simulate proxy configuration
431+
model_config = {
432+
"model_name": "byok-gemini/*",
433+
"litellm_params": {
434+
"model": "gemini/*",
435+
"api_key": "dummy-key-for-testing",
436+
"api_base": "https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta"
437+
}
438+
}
439+
440+
vertex_base = VertexBase()
441+
442+
# Test with different Gemini models
443+
test_models = [
444+
"gemini-2.5-flash-lite",
445+
"gemini-2.5-pro",
446+
"gemini-1.5-flash",
447+
"gemini-1.5-pro"
448+
]
449+
450+
for model in test_models:
451+
# Test generateContent endpoint
452+
auth_header, result_url = vertex_base._check_custom_proxy(
453+
api_base=model_config["litellm_params"]["api_base"],
454+
custom_llm_provider="gemini",
455+
gemini_api_key=model_config["litellm_params"]["api_key"],
456+
endpoint="generateContent",
457+
stream=False,
458+
auth_header=None,
459+
url=f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent",
460+
model=model,
461+
)
462+
463+
expected_url = f"{model_config['litellm_params']['api_base']}/models/{model}:generateContent"
464+
assert result_url == expected_url, f"Expected {expected_url}, got {result_url} for model {model}"
465+
assert auth_header == model_config["litellm_params"]["api_key"], f"Expected API key, got {auth_header} for model {model}"
466+
467+
print(f"✅ Model {model} configuration test passed: {result_url}")
468+
469+
print("✅ Proxy configuration with custom API base test passed")
470+
471+
348472
if __name__ == "__main__":
349473
# Run the tests
350474
pytest.main([__file__, "-v"])

tests/test_litellm/llms/vertex_ai/test_vertex_llm_base.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,3 +704,177 @@ def test_get_api_base(self, api_base, vertex_location, expected):
704704
vertex_base.get_api_base(api_base=api_base, vertex_location=vertex_location)
705705
== expected
706706
), f"Expected {expected} with api_base {api_base} and vertex_location {vertex_location}"
707+
708+
@pytest.mark.parametrize(
709+
"api_base, custom_llm_provider, gemini_api_key, endpoint, stream, auth_header, url, model, expected_auth_header, expected_url",
710+
[
711+
# Test case 1: Gemini with custom API base
712+
(
713+
"https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta",
714+
"gemini",
715+
"test-api-key",
716+
"generateContent",
717+
False,
718+
None,
719+
"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-lite:generateContent",
720+
"gemini-2.5-flash-lite",
721+
"test-api-key",
722+
"https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-lite:generateContent"
723+
),
724+
# Test case 2: Gemini with custom API base and streaming
725+
(
726+
"https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta",
727+
"gemini",
728+
"test-api-key",
729+
"generateContent",
730+
True,
731+
None,
732+
"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-lite:generateContent",
733+
"gemini-2.5-flash-lite",
734+
"test-api-key",
735+
"https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-lite:generateContent?alt=sse"
736+
),
737+
# Test case 3: Non-Gemini provider with custom API base
738+
(
739+
"https://custom-vertex-api.com",
740+
"vertex_ai",
741+
None,
742+
"generateContent",
743+
False,
744+
"Bearer token123",
745+
"https://aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:generateContent",
746+
"gemini-pro",
747+
"Bearer token123",
748+
"https://custom-vertex-api.com:generateContent"
749+
),
750+
# Test case 4: No API base provided (should return original values)
751+
(
752+
None,
753+
"gemini",
754+
"test-api-key",
755+
"generateContent",
756+
False,
757+
"Bearer token123",
758+
"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-lite:generateContent",
759+
"gemini-2.5-flash-lite",
760+
"Bearer token123",
761+
"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-lite:generateContent"
762+
),
763+
# Test case 5: Gemini without API key (should raise ValueError)
764+
(
765+
"https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta",
766+
"gemini",
767+
None,
768+
"generateContent",
769+
False,
770+
None,
771+
"https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-lite:generateContent",
772+
"gemini-2.5-flash-lite",
773+
None, # This should raise an exception
774+
None
775+
),
776+
],
777+
)
778+
def test_check_custom_proxy(
779+
self,
780+
api_base,
781+
custom_llm_provider,
782+
gemini_api_key,
783+
endpoint,
784+
stream,
785+
auth_header,
786+
url,
787+
model,
788+
expected_auth_header,
789+
expected_url
790+
):
791+
"""Test the _check_custom_proxy method for handling custom API base URLs"""
792+
vertex_base = VertexBase()
793+
794+
if custom_llm_provider == "gemini" and api_base and gemini_api_key is None:
795+
# Test case 5: Should raise ValueError for Gemini without API key
796+
with pytest.raises(ValueError, match="Missing gemini_api_key"):
797+
vertex_base._check_custom_proxy(
798+
api_base=api_base,
799+
custom_llm_provider=custom_llm_provider,
800+
gemini_api_key=gemini_api_key,
801+
endpoint=endpoint,
802+
stream=stream,
803+
auth_header=auth_header,
804+
url=url,
805+
model=model,
806+
)
807+
else:
808+
# Test cases 1-4: Should work correctly
809+
result_auth_header, result_url = vertex_base._check_custom_proxy(
810+
api_base=api_base,
811+
custom_llm_provider=custom_llm_provider,
812+
gemini_api_key=gemini_api_key,
813+
endpoint=endpoint,
814+
stream=stream,
815+
auth_header=auth_header,
816+
url=url,
817+
model=model,
818+
)
819+
820+
assert result_auth_header == expected_auth_header, f"Expected auth_header {expected_auth_header}, got {result_auth_header}"
821+
assert result_url == expected_url, f"Expected URL {expected_url}, got {result_url}"
822+
823+
def test_check_custom_proxy_gemini_url_construction(self):
824+
"""Test that Gemini URLs are constructed correctly with custom API base"""
825+
vertex_base = VertexBase()
826+
827+
# Test various Gemini models with custom API base
828+
test_cases = [
829+
("gemini-2.5-flash-lite", "generateContent", "https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-lite:generateContent"),
830+
("gemini-2.5-pro", "generateContent", "https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro:generateContent"),
831+
("gemini-1.5-flash", "streamGenerateContent", "https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:streamGenerateContent"),
832+
]
833+
834+
for model, endpoint, expected_url in test_cases:
835+
_, result_url = vertex_base._check_custom_proxy(
836+
api_base="https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta",
837+
custom_llm_provider="gemini",
838+
gemini_api_key="test-api-key",
839+
endpoint=endpoint,
840+
stream=False,
841+
auth_header=None,
842+
url=f"https://generativelanguage.googleapis.com/v1beta/models/{model}:{endpoint}",
843+
model=model,
844+
)
845+
846+
assert result_url == expected_url, f"Expected {expected_url}, got {result_url} for model {model}"
847+
848+
def test_check_custom_proxy_streaming_parameter(self):
849+
"""Test that streaming parameter correctly adds ?alt=sse to URLs"""
850+
vertex_base = VertexBase()
851+
852+
# Test with streaming enabled
853+
_, result_url_streaming = vertex_base._check_custom_proxy(
854+
api_base="https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta",
855+
custom_llm_provider="gemini",
856+
gemini_api_key="test-api-key",
857+
endpoint="generateContent",
858+
stream=True,
859+
auth_header=None,
860+
url="https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-lite:generateContent",
861+
model="gemini-2.5-flash-lite",
862+
)
863+
864+
expected_streaming_url = "https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-lite:generateContent?alt=sse"
865+
assert result_url_streaming == expected_streaming_url, f"Expected {expected_streaming_url}, got {result_url_streaming}"
866+
867+
# Test with streaming disabled
868+
_, result_url_no_streaming = vertex_base._check_custom_proxy(
869+
api_base="https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta",
870+
custom_llm_provider="gemini",
871+
gemini_api_key="test-api-key",
872+
endpoint="generateContent",
873+
stream=False,
874+
auth_header=None,
875+
url="https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-lite:generateContent",
876+
model="gemini-2.5-flash-lite",
877+
)
878+
879+
expected_no_streaming_url = "https://proxy.zapier.com/generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-lite:generateContent"
880+
assert result_url_no_streaming == expected_no_streaming_url, f"Expected {expected_no_streaming_url}, got {result_url_no_streaming}"

0 commit comments

Comments (0)