Skip to content

Commit 54052d3

Browse files
MarkDaoustcopybara-github
authored andcommitted
chore: Move prepare_options logic into interactions client.
PiperOrigin-RevId: 845944715
1 parent d98c757 commit 54052d3

File tree

6 files changed

+176
-62
lines changed

6 files changed

+176
-62
lines changed

google/genai/_interactions/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
)
5454
from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
5555
from ._utils._logs import setup_logging as _setup_logging
56+
from ._client_adapter import GeminiNextGenAPIClientAdapter, AsyncGeminiNextGenAPIClientAdapter
5657

5758
__all__ = [
5859
"types",
@@ -96,6 +97,8 @@
9697
"DefaultHttpxClient",
9798
"DefaultAsyncHttpxClient",
9899
"DefaultAioHttpClient",
100+
"AsyncGeminiNextGenAPIClientAdapter",
101+
"GeminiNextGenAPIClientAdapter"
99102
]
100103

101104
if not _t.TYPE_CHECKING:

google/genai/_interactions/_client.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
)
3838
from ._utils import is_given, get_async_library
3939
from ._compat import cached_property
40+
from ._models import FinalRequestOptions
4041
from ._version import __version__
4142
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
4243
from ._exceptions import APIStatusError
@@ -45,6 +46,7 @@
4546
SyncAPIClient,
4647
AsyncAPIClient,
4748
)
49+
from ._client_adapter import GeminiNextGenAPIClientAdapter, AsyncGeminiNextGenAPIClientAdapter
4850

4951
if TYPE_CHECKING:
5052
from .resources import interactions
@@ -66,6 +68,7 @@ class GeminiNextGenAPIClient(SyncAPIClient):
6668
# client options
6769
api_key: str | None
6870
api_version: str
71+
client_adapter: GeminiNextGenAPIClientAdapter | None
6972

7073
def __init__(
7174
self,
@@ -81,6 +84,7 @@ def __init__(
8184
# We provide a `DefaultHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`.
8285
# See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
8386
http_client: httpx.Client | None = None,
87+
client_adapter: GeminiNextGenAPIClientAdapter | None = None,
8488
# Enable or disable schema validation for data returned by the API.
8589
# When enabled an error APIResponseValidationError is raised
8690
# if the API responds with invalid data for the expected schema.
@@ -108,6 +112,8 @@ def __init__(
108112
if base_url is None:
109113
base_url = f"https://generativelanguage.googleapis.com"
110114

115+
self.client_adapter = client_adapter
116+
111117
super().__init__(
112118
version=__version__,
113119
base_url=base_url,
@@ -159,13 +165,35 @@ def default_headers(self) -> dict[str, str | Omit]:
159165

160166
@override
161167
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
162-
if headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit):
168+
if headers.get("Authorization") or custom_headers.get("Authorization") or isinstance(custom_headers.get("Authorization"), Omit):
169+
return
170+
if self.api_key and headers.get("x-goog-api-key"):
171+
return
172+
if custom_headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit):
163173
return
164174

165175
raise TypeError(
166176
'"Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted"'
167177
)
168-
178+
179+
@override
180+
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
181+
if not self.client_adapter or not self.client_adapter.is_vertex_ai():
182+
return options
183+
184+
headers = options.headers or {}
185+
has_auth = headers.get("Authorization") or headers.get("x-goog-api-key") # pytype: disable=attribute-error
186+
if has_auth:
187+
return options
188+
189+
adapted_headers = self.client_adapter.get_auth_headers()
190+
if adapted_headers:
191+
options.headers = {
192+
**adapted_headers,
193+
**headers
194+
}
195+
return options
196+
169197
def copy(
170198
self,
171199
*,
@@ -179,6 +207,7 @@ def copy(
179207
set_default_headers: Mapping[str, str] | None = None,
180208
default_query: Mapping[str, object] | None = None,
181209
set_default_query: Mapping[str, object] | None = None,
210+
client_adapter: GeminiNextGenAPIClientAdapter | None = None,
182211
_extra_kwargs: Mapping[str, Any] = {},
183212
) -> Self:
184213
"""
@@ -212,6 +241,7 @@ def copy(
212241
max_retries=max_retries if is_given(max_retries) else self.max_retries,
213242
default_headers=headers,
214243
default_query=params,
244+
client_adapter=self.client_adapter or client_adapter,
215245
**_extra_kwargs,
216246
)
217247

@@ -260,6 +290,7 @@ class AsyncGeminiNextGenAPIClient(AsyncAPIClient):
260290
# client options
261291
api_key: str | None
262292
api_version: str
293+
client_adapter: AsyncGeminiNextGenAPIClientAdapter | None
263294

264295
def __init__(
265296
self,
@@ -275,6 +306,7 @@ def __init__(
275306
# We provide a `DefaultAsyncHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`.
276307
# See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details.
277308
http_client: httpx.AsyncClient | None = None,
309+
client_adapter: AsyncGeminiNextGenAPIClientAdapter | None = None,
278310
# Enable or disable schema validation for data returned by the API.
279311
# When enabled an error APIResponseValidationError is raised
280312
# if the API responds with invalid data for the expected schema.
@@ -302,6 +334,8 @@ def __init__(
302334
if base_url is None:
303335
base_url = f"https://generativelanguage.googleapis.com"
304336

337+
self.client_adapter = client_adapter
338+
305339
super().__init__(
306340
version=__version__,
307341
base_url=base_url,
@@ -353,12 +387,34 @@ def default_headers(self) -> dict[str, str | Omit]:
353387

354388
@override
355389
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
356-
if headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit):
390+
if headers.get("Authorization") or custom_headers.get("Authorization") or isinstance(custom_headers.get("Authorization"), Omit):
391+
return
392+
if self.api_key and headers.get("x-goog-api-key"):
393+
return
394+
if custom_headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit):
357395
return
358396

359397
raise TypeError(
360398
'"Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted"'
361399
)
400+
401+
@override
402+
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
403+
if not self.client_adapter or not self.client_adapter.is_vertex_ai():
404+
return options
405+
406+
headers = options.headers or {}
407+
has_auth = headers.get("Authorization") or headers.get("x-goog-api-key") # pytype: disable=attribute-error
408+
if has_auth:
409+
return options
410+
411+
adapted_headers = await self.client_adapter.async_get_auth_headers()
412+
if adapted_headers:
413+
options.headers = {
414+
**adapted_headers,
415+
**headers
416+
}
417+
return options
362418

363419
def copy(
364420
self,
@@ -373,6 +429,7 @@ def copy(
373429
set_default_headers: Mapping[str, str] | None = None,
374430
default_query: Mapping[str, object] | None = None,
375431
set_default_query: Mapping[str, object] | None = None,
432+
client_adapter: AsyncGeminiNextGenAPIClientAdapter | None = None,
376433
_extra_kwargs: Mapping[str, Any] = {},
377434
) -> Self:
378435
"""
@@ -406,6 +463,7 @@ def copy(
406463
max_retries=max_retries if is_given(max_retries) else self.max_retries,
407464
default_headers=headers,
408465
default_query=params,
466+
client_adapter=self.client_adapter or client_adapter,
409467
**_extra_kwargs,
410468
)
411469

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
from __future__ import annotations
17+
18+
from abc import ABC, abstractmethod
19+
20+
__all__ = [
21+
"GeminiNextGenAPIClientAdapter",
22+
"AsyncGeminiNextGenAPIClientAdapter"
23+
]
24+
25+
class BaseGeminiNextGenAPIClientAdapter(ABC):
26+
@abstractmethod
27+
def is_vertex_ai(self) -> bool:
28+
...
29+
30+
@abstractmethod
31+
def get_project(self) -> str | None:
32+
...
33+
34+
@abstractmethod
35+
def get_location(self) -> str | None:
36+
...
37+
38+
39+
class AsyncGeminiNextGenAPIClientAdapter(BaseGeminiNextGenAPIClientAdapter):
40+
@abstractmethod
41+
async def async_get_auth_headers(self) -> dict[str, str] | None:
42+
...
43+
44+
45+
class GeminiNextGenAPIClientAdapter(BaseGeminiNextGenAPIClientAdapter):
46+
@abstractmethod
47+
def get_auth_headers(self) -> dict[str, str] | None:
48+
...

0 commit comments

Comments
 (0)