3737)
3838from ._utils import is_given , get_async_library
3939from ._compat import cached_property
40+ from ._models import FinalRequestOptions
4041from ._version import __version__
4142from ._streaming import Stream as Stream , AsyncStream as AsyncStream
4243from ._exceptions import APIStatusError
4546 SyncAPIClient ,
4647 AsyncAPIClient ,
4748)
49+ from ._client_adapter import GeminiNextGenAPIClientAdapter , AsyncGeminiNextGenAPIClientAdapter
4850
4951if TYPE_CHECKING :
5052 from .resources import interactions
@@ -66,6 +68,7 @@ class GeminiNextGenAPIClient(SyncAPIClient):
6668 # client options
6769 api_key : str | None
6870 api_version : str
71+ client_adapter : GeminiNextGenAPIClientAdapter | None
6972
7073 def __init__ (
7174 self ,
@@ -81,6 +84,7 @@ def __init__(
8184 # We provide a `DefaultHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`.
8285 # See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
8386 http_client : httpx .Client | None = None ,
87+ client_adapter : GeminiNextGenAPIClientAdapter | None = None ,
8488 # Enable or disable schema validation for data returned by the API.
8589 # When enabled an error APIResponseValidationError is raised
8690 # if the API responds with invalid data for the expected schema.
@@ -108,6 +112,8 @@ def __init__(
108112 if base_url is None :
109113 base_url = f"https://generativelanguage.googleapis.com"
110114
115+ self .client_adapter = client_adapter
116+
111117 super ().__init__ (
112118 version = __version__ ,
113119 base_url = base_url ,
@@ -159,13 +165,35 @@ def default_headers(self) -> dict[str, str | Omit]:
159165
160166 @override
161167 def _validate_headers (self , headers : Headers , custom_headers : Headers ) -> None :
162- if headers .get ("x-goog-api-key" ) or isinstance (custom_headers .get ("x-goog-api-key" ), Omit ):
168+ if headers .get ("Authorization" ) or custom_headers .get ("Authorization" ) or isinstance (custom_headers .get ("Authorization" ), Omit ):
169+ return
170+ if self .api_key and headers .get ("x-goog-api-key" ):
171+ return
172+ if custom_headers .get ("x-goog-api-key" ) or isinstance (custom_headers .get ("x-goog-api-key" ), Omit ):
163173 return
164174
165175 raise TypeError (
166176 '"Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted"'
167177 )
168-
178+
179+ @override
180+ def _prepare_options (self , options : FinalRequestOptions ) -> FinalRequestOptions :
181+ if not self .client_adapter or not self .client_adapter .is_vertex_ai ():
182+ return options
183+
184+ headers = options .headers or {}
185+ has_auth = headers .get ("Authorization" ) or headers .get ("x-goog-api-key" ) # pytype: disable=attribute-error
186+ if has_auth :
187+ return options
188+
189+ adapted_headers = self .client_adapter .get_auth_headers ()
190+ if adapted_headers :
191+ options .headers = {
192+ ** adapted_headers ,
193+ ** headers
194+ }
195+ return options
196+
169197 def copy (
170198 self ,
171199 * ,
@@ -179,6 +207,7 @@ def copy(
179207 set_default_headers : Mapping [str , str ] | None = None ,
180208 default_query : Mapping [str , object ] | None = None ,
181209 set_default_query : Mapping [str , object ] | None = None ,
210+ client_adapter : GeminiNextGenAPIClientAdapter | None = None ,
182211 _extra_kwargs : Mapping [str , Any ] = {},
183212 ) -> Self :
184213 """
@@ -212,6 +241,7 @@ def copy(
212241 max_retries = max_retries if is_given (max_retries ) else self .max_retries ,
213242 default_headers = headers ,
214243 default_query = params ,
244+ client_adapter = self .client_adapter or client_adapter ,
215245 ** _extra_kwargs ,
216246 )
217247
@@ -260,6 +290,7 @@ class AsyncGeminiNextGenAPIClient(AsyncAPIClient):
260290 # client options
261291 api_key : str | None
262292 api_version : str
293+ client_adapter : AsyncGeminiNextGenAPIClientAdapter | None
263294
264295 def __init__ (
265296 self ,
@@ -275,6 +306,7 @@ def __init__(
275306 # We provide a `DefaultAsyncHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`.
276307 # See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details.
277308 http_client : httpx .AsyncClient | None = None ,
309+ client_adapter : AsyncGeminiNextGenAPIClientAdapter | None = None ,
278310 # Enable or disable schema validation for data returned by the API.
279311 # When enabled an error APIResponseValidationError is raised
280312 # if the API responds with invalid data for the expected schema.
@@ -302,6 +334,8 @@ def __init__(
302334 if base_url is None :
303335 base_url = f"https://generativelanguage.googleapis.com"
304336
337+ self .client_adapter = client_adapter
338+
305339 super ().__init__ (
306340 version = __version__ ,
307341 base_url = base_url ,
@@ -353,12 +387,34 @@ def default_headers(self) -> dict[str, str | Omit]:
353387
354388 @override
355389 def _validate_headers (self , headers : Headers , custom_headers : Headers ) -> None :
356- if headers .get ("x-goog-api-key" ) or isinstance (custom_headers .get ("x-goog-api-key" ), Omit ):
390+ if headers .get ("Authorization" ) or custom_headers .get ("Authorization" ) or isinstance (custom_headers .get ("Authorization" ), Omit ):
391+ return
392+ if self .api_key and headers .get ("x-goog-api-key" ):
393+ return
394+ if custom_headers .get ("x-goog-api-key" ) or isinstance (custom_headers .get ("x-goog-api-key" ), Omit ):
357395 return
358396
359397 raise TypeError (
360398 '"Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted"'
361399 )
400+
401+ @override
402+ async def _prepare_options (self , options : FinalRequestOptions ) -> FinalRequestOptions :
403+ if not self .client_adapter or not self .client_adapter .is_vertex_ai ():
404+ return options
405+
406+ headers = options .headers or {}
407+ has_auth = headers .get ("Authorization" ) or headers .get ("x-goog-api-key" ) # pytype: disable=attribute-error
408+ if has_auth :
409+ return options
410+
411+ adapted_headers = await self .client_adapter .async_get_auth_headers ()
412+ if adapted_headers :
413+ options .headers = {
414+ ** adapted_headers ,
415+ ** headers
416+ }
417+ return options
362418
363419 def copy (
364420 self ,
@@ -373,6 +429,7 @@ def copy(
373429 set_default_headers : Mapping [str , str ] | None = None ,
374430 default_query : Mapping [str , object ] | None = None ,
375431 set_default_query : Mapping [str , object ] | None = None ,
432+ client_adapter : AsyncGeminiNextGenAPIClientAdapter | None = None ,
376433 _extra_kwargs : Mapping [str , Any ] = {},
377434 ) -> Self :
378435 """
@@ -406,6 +463,7 @@ def copy(
406463 max_retries = max_retries if is_given (max_retries ) else self .max_retries ,
407464 default_headers = headers ,
408465 default_query = params ,
466+ client_adapter = self .client_adapter or client_adapter ,
409467 ** _extra_kwargs ,
410468 )
411469
0 commit comments