Skip to content

Commit 6191412

Browse files
xuanyang15copybara-github
authored andcommitted
fix: keep existing header values while merging tracking headers for llm_request.config.http_options in Gemini.generate_content_async
PiperOrigin-RevId: 789013693
1 parent 3be1bb3 commit 6191412

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

src/google/adk/models/google_llm.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ async def generate_content_async(
122122
if llm_request.config:
123123
if not llm_request.config.http_options:
124124
llm_request.config.http_options = types.HttpOptions()
125-
if not llm_request.config.http_options.headers:
126-
llm_request.config.http_options.headers = {}
127-
llm_request.config.http_options.headers.update(self._tracking_headers)
125+
llm_request.config.http_options.headers = self._merge_tracking_headers(
126+
llm_request.config.http_options.headers
127+
)
128128

129129
if stream:
130130
responses = await self.api_client.aio.models.generate_content_stream(
@@ -336,6 +336,23 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None:
336336
llm_request.config.system_instruction = None
337337
await self._adapt_computer_use_tool(llm_request)
338338

339+
def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]:
340+
"""Merge tracking headers to the given headers."""
341+
headers = headers or {}
342+
for key, tracking_header_value in self._tracking_headers.items():
343+
custom_value = headers.get(key, None)
344+
if not custom_value:
345+
headers[key] = tracking_header_value
346+
continue
347+
348+
# Merge tracking headers with existing headers and avoid duplicates.
349+
value_parts = tracking_header_value.split(' ')
350+
for custom_value_part in custom_value.split(' '):
351+
if custom_value_part not in value_parts:
352+
value_parts.append(custom_value_part)
353+
headers[key] = ' '.join(value_parts)
354+
return headers
355+
339356

340357
def _build_function_declaration_log(
341358
func_decl: types.FunctionDeclaration,

tests/unittests/models/test_google_llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ async def mock_coro():
403403

404404
for key, value in config_arg.http_options.headers.items():
405405
if key in gemini_llm._tracking_headers:
406-
assert value == gemini_llm._tracking_headers[key]
406+
assert value == gemini_llm._tracking_headers[key] + " custom"
407407
else:
408408
assert value == custom_headers[key]
409409

0 commit comments

Comments
 (0)