Skip to content

Commit cd25782

Browse files
authored
fix passing headers for gemini (#15231)
1 parent 27c64c9 commit cd25782

File tree

2 files changed

+174
-2
lines changed

2 files changed

+174
-2
lines changed

litellm/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2872,7 +2872,7 @@ def completion( # type: ignore # noqa: PLR0915
28722872
custom_llm_provider=custom_llm_provider, # type: ignore
28732873
client=client,
28742874
api_base=api_base,
2875-
extra_headers=extra_headers,
2875+
extra_headers=headers,
28762876
)
28772877

28782878
elif custom_llm_provider == "vertex_ai":
@@ -2941,7 +2941,7 @@ def completion( # type: ignore # noqa: PLR0915
29412941
custom_llm_provider=custom_llm_provider, # type: ignore
29422942
client=client,
29432943
api_base=api_base,
2944-
extra_headers=extra_headers,
2944+
extra_headers=headers,
29452945
)
29462946
elif "openai" in model:
29472947
# Vertex Model Garden - OpenAI compatible models
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""
2+
Test to verify that custom headers are correctly forwarded to Gemini/Vertex AI API calls.
3+
4+
This test verifies the fix for the issue where headers configured via
5+
forward_client_headers_to_llm_api were not being passed to Gemini/Vertex AI providers.
6+
"""
7+
8+
import pytest
9+
from unittest.mock import Mock, patch, MagicMock
10+
import litellm
11+
from litellm import completion
12+
13+
14+
class TestGeminiHeaderForwarding:
15+
"""Test cases for verifying header forwarding to Gemini/Vertex AI."""
16+
17+
@pytest.mark.parametrize(
18+
"custom_llm_provider,model",
19+
[
20+
("gemini", "gemini/gemini-1.5-pro"),
21+
("vertex_ai_beta", "gemini-1.5-pro"),
22+
("vertex_ai", "gemini-1.5-pro"),
23+
],
24+
)
25+
def test_headers_forwarded_to_gemini(self, custom_llm_provider, model):
26+
"""
27+
Test that headers from kwargs are correctly merged and passed to Gemini completion.
28+
29+
This test verifies that when headers are passed via kwargs (as the proxy does when
30+
forward_client_headers_to_llm_api is configured), they are correctly merged with
31+
extra_headers and passed to the Vertex AI completion handler.
32+
"""
33+
messages = [{"role": "user", "content": "Hello"}]
34+
35+
# Headers that would be set by the proxy when forwarding client headers
36+
custom_headers = {
37+
"X-Custom-Header": "CustomValue",
38+
"X-BYOK-Token": "secret-token",
39+
}
40+
41+
# Mock the vertex completion handler
42+
with patch(
43+
"litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini.VertexLLM.completion"
44+
) as mock_vertex_completion:
45+
# Configure the mock to return a proper response
46+
mock_response = Mock()
47+
mock_response.choices = [Mock()]
48+
mock_response.choices[0].message.content = "Hello back!"
49+
mock_vertex_completion.return_value = mock_response
50+
51+
try:
52+
# Call completion with custom headers via kwargs
53+
# This simulates what the proxy does when forward_client_headers_to_llm_api is set
54+
completion(
55+
model=model,
56+
messages=messages,
57+
headers=custom_headers, # This is how proxy passes forwarded headers
58+
custom_llm_provider=custom_llm_provider,
59+
api_key="dummy-key",
60+
)
61+
62+
# Verify that the completion handler was called
63+
assert mock_vertex_completion.called, "Vertex completion handler should be called"
64+
65+
# Get the actual call arguments
66+
call_kwargs = mock_vertex_completion.call_args.kwargs
67+
68+
# Verify that extra_headers parameter contains our custom headers
69+
assert "extra_headers" in call_kwargs, "extra_headers should be passed to completion"
70+
71+
passed_headers = call_kwargs["extra_headers"]
72+
assert passed_headers is not None, "extra_headers should not be None"
73+
74+
# Verify our custom headers are present in the passed headers
75+
for header_key, header_value in custom_headers.items():
76+
assert (
77+
header_key in passed_headers
78+
or header_key.lower() in passed_headers
79+
), f"Header {header_key} should be in extra_headers"
80+
81+
print(f"✓ Test passed for {custom_llm_provider}/{model}")
82+
print(f" Headers correctly forwarded: {passed_headers}")
83+
84+
except Exception as e:
85+
pytest.fail(
86+
f"Failed to forward headers to {custom_llm_provider}/{model}: {str(e)}"
87+
)
88+
89+
def test_extra_headers_and_headers_merge(self):
90+
"""
91+
Test that both extra_headers and headers parameters are correctly merged.
92+
93+
This ensures that headers from kwargs (forwarded by proxy) and extra_headers
94+
(passed explicitly) are both included in the final headers sent to the provider.
95+
"""
96+
messages = [{"role": "user", "content": "Hello"}]
97+
98+
# Headers from proxy (via kwargs["headers"])
99+
proxy_headers = {"X-Forwarded-Header": "ProxyValue"}
100+
101+
# Explicit extra_headers
102+
explicit_headers = {"X-Explicit-Header": "ExplicitValue"}
103+
104+
with patch(
105+
"litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini.VertexLLM.completion"
106+
) as mock_vertex_completion:
107+
mock_response = Mock()
108+
mock_response.choices = [Mock()]
109+
mock_response.choices[0].message.content = "Response"
110+
mock_vertex_completion.return_value = mock_response
111+
112+
try:
113+
completion(
114+
model="gemini/gemini-1.5-pro",
115+
messages=messages,
116+
headers=proxy_headers, # From proxy forwarding
117+
extra_headers=explicit_headers, # Explicitly passed
118+
custom_llm_provider="gemini",
119+
api_key="dummy-key",
120+
)
121+
122+
call_kwargs = mock_vertex_completion.call_args.kwargs
123+
passed_headers = call_kwargs.get("extra_headers", {})
124+
125+
# Both sets of headers should be present
126+
assert (
127+
"X-Forwarded-Header" in passed_headers
128+
or "x-forwarded-header" in passed_headers
129+
), "Proxy forwarded header should be present"
130+
131+
assert (
132+
"X-Explicit-Header" in passed_headers
133+
or "x-explicit-header" in passed_headers
134+
), "Explicitly passed header should be present"
135+
136+
print("✓ Both header sources correctly merged and forwarded")
137+
print(f" Final headers: {passed_headers}")
138+
139+
except Exception as e:
140+
pytest.fail(f"Failed to merge and forward headers: {str(e)}")
141+
142+
143+
if __name__ == "__main__":
144+
# Run the tests
145+
test_instance = TestGeminiHeaderForwarding()
146+
147+
print("\n" + "="*80)
148+
print("Testing Gemini/Vertex AI Header Forwarding")
149+
print("="*80 + "\n")
150+
151+
# Test each provider
152+
for provider, model in [
153+
("gemini", "gemini/gemini-1.5-pro"),
154+
("vertex_ai_beta", "gemini-1.5-pro"),
155+
("vertex_ai", "gemini-1.5-pro"),
156+
]:
157+
print(f"\nTesting {provider}/{model}...")
158+
try:
159+
test_instance.test_headers_forwarded_to_gemini(provider, model)
160+
except Exception as e:
161+
print(f"✗ Test failed: {e}")
162+
163+
print("\n\nTesting header merging...")
164+
try:
165+
test_instance.test_extra_headers_and_headers_merge()
166+
except Exception as e:
167+
print(f"✗ Test failed: {e}")
168+
169+
print("\n" + "="*80)
170+
print("All tests completed!")
171+
print("="*80 + "\n")
172+

0 commit comments

Comments
 (0)