29
29
from copy import deepcopy
30
30
from dataclasses import dataclass , field
31
31
from datetime import datetime
32
- from typing import Annotated , Any , Literal , Union
32
+ from typing import Annotated , Any , Literal , Protocol , Union
33
33
34
34
import pydantic_core
35
35
from httpx import AsyncClient as AsyncHTTPClient , Response as HTTPResponse
@@ -77,17 +77,17 @@ class GeminiModel(Model):
77
77
"""
78
78
79
79
model_name : GeminiModelName
80
- api_key : str
80
+ auth : AuthProtocol
81
81
http_client : AsyncHTTPClient
82
- url_template : str
82
+ url : str
83
83
84
84
def __init__ (
85
85
self ,
86
86
model_name : GeminiModelName ,
87
87
* ,
88
88
api_key : str | None = None ,
89
89
http_client : AsyncHTTPClient | None = None ,
90
- url_template : str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:{function} ' ,
90
+ url_template : str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:' ,
91
91
):
92
92
"""Initialize a Gemini model.
93
93
@@ -97,62 +97,94 @@ def __init__(
97
97
will be used if available.
98
98
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
99
99
url_template: The URL template to use for making requests, you shouldn't need to change this,
100
- docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request).
100
+ docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request),
101
+ `model` is substituted with the model name, and `function` is added to the end of the URL.
101
102
"""
102
103
self .model_name = model_name
103
104
if api_key is None :
104
105
if env_api_key := os .getenv ('GEMINI_API_KEY' ):
105
106
api_key = env_api_key
106
107
else :
107
108
raise exceptions .UserError ('API key must be provided or set in the GEMINI_API_KEY environment variable' )
108
- self .api_key = api_key
109
+ self .auth = ApiKeyAuth ( api_key )
109
110
self .http_client = http_client or cached_async_http_client ()
110
- self .url_template = url_template
111
+ self .url = url_template . format ( model = model_name )
111
112
112
async def agent_model(
    self,
    retrievers: Mapping[str, AbstractToolDefinition],
    allow_text_result: bool,
    result_tools: Sequence[AbstractToolDefinition] | None,
) -> GeminiAgentModel:
    """Create the `GeminiAgentModel` that performs requests for one agent.

    The tool definitions are forwarded untouched; `GeminiAgentModel.__init__`
    converts them into Gemini function declarations and builds the tool config.

    Args:
        retrievers: Tool definitions keyed by name, passed through as-is.
        allow_text_result: Passed through as-is; controls whether a tool config is built.
        result_tools: Optional additional tool definitions, passed through as-is.

    Returns:
        A `GeminiAgentModel` sharing this model's HTTP client, model name, auth and URL.
    """
    gemini_agent_model = GeminiAgentModel(
        retrievers=retrievers,
        allow_text_result=allow_text_result,
        result_tools=result_tools,
        http_client=self.http_client,
        model_name=self.model_name,
        auth=self.auth,
        url=self.url,
    )
    return gemini_agent_model
136
128
137
129
def name(self) -> str:
    """Return the model name stored on this instance."""
    model_name = self.model_name
    return model_name
139
131
140
132
133
class AuthProtocol(Protocol):
    """Structural interface for objects that provide authentication headers.

    Any object with a matching async `headers` method satisfies this protocol.
    """

    async def headers(self) -> dict[str, str]:
        """Return the HTTP headers to merge into an outgoing request."""
        ...
135
+
136
+
141
137
@dataclass
class ApiKeyAuth:
    """Authentication that sends a static API key in the `X-Goog-Api-Key` header."""

    # Excluded from the generated `__repr__` so the secret does not end up in logs,
    # tracebacks or the repr of any dataclass that holds this object.
    api_key: str = field(repr=False)

    async def headers(self) -> dict[str, str]:
        """Return the header carrying the API key.

        Returns:
            A new dict with the single `X-Goog-Api-Key` entry.
        """
        # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
        return {'X-Goog-Api-Key': self.api_key}
144
+
145
+
146
+ @dataclass (init = False )
142
147
class GeminiAgentModel (AgentModel ):
143
148
"""Implementation of `AgentModel` for Gemini models."""
144
149
145
150
http_client : AsyncHTTPClient
146
151
model_name : GeminiModelName
147
- api_key : str
152
+ auth : AuthProtocol
148
153
tools : _GeminiTools | None
149
154
tool_config : _GeminiToolConfig | None
150
- url_template : str
155
+ url : str
156
+
157
def __init__(
    self,
    http_client: AsyncHTTPClient,
    model_name: GeminiModelName,
    auth: AuthProtocol,
    url: str,
    retrievers: Mapping[str, AbstractToolDefinition],
    allow_text_result: bool,
    result_tools: Sequence[AbstractToolDefinition] | None,
):
    """Store the connection settings and derive the Gemini tool declarations.

    Args:
        http_client: Client stored on the instance for making requests.
        model_name: Gemini model name stored on the instance.
        auth: Object stored on the instance to supply authentication headers.
        url: URL stored on the instance.
        retrievers: Tool definitions; each value is converted to a function declaration.
        allow_text_result: When true no tool config is set; otherwise a tool config
            is built from the names of all declared functions.
        result_tools: Optional extra tool definitions, converted after the retrievers.
    """
    check_allow_model_requests()

    # Retrievers come first, result tools (if any) are appended after them.
    declarations = [_function_from_abstract_tool(tool) for tool in retrievers.values()]
    if result_tools is not None:
        declarations.extend([_function_from_abstract_tool(tool) for tool in result_tools])

    self.http_client = http_client
    self.model_name = model_name
    self.auth = auth
    self.url = url
    # `tools` stays `None` when there is nothing to declare.
    self.tools = _GeminiTools(function_declarations=declarations) if declarations else None
    self.tool_config = (
        None if allow_text_result else _tool_config([declaration['name'] for declaration in declarations])
    )
151
183
152
184
async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, result.Cost]:
    """Make a non-streamed request and return the processed response with its cost.

    Args:
        messages: The messages passed on to `_make_request`.

    Returns:
        A tuple of the processed model response and the cost derived from the response.
    """
    async with self._make_request(messages, False) as http_response:
        raw_body = await http_response.aread()
        response = _gemini_response_ta.validate_json(raw_body)
    model_response = self._process_response(response)
    # The whole response is handed over; the helper extracts the usage metadata itself.
    cost = _metadata_as_cost(response)
    return model_response, cost
156
188
157
189
@asynccontextmanager
158
190
async def request_stream (self , messages : list [Message ]) -> AsyncIterator [EitherStreamedResponse ]:
@@ -178,16 +210,15 @@ async def _make_request(self, messages: list[Message], streamed: bool) -> AsyncI
178
210
if self .tool_config is not None :
179
211
request_data ['tool_config' ] = self .tool_config
180
212
181
- request_json = _gemini_request_ta . dump_json ( request_data , by_alias = True )
182
- # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
213
+ url = self . url + ( 'streamGenerateContent' if streamed else 'generateContent' )
214
+
183
215
headers = {
184
- 'X-Goog-Api-Key' : self .api_key ,
185
216
'Content-Type' : 'application/json' ,
186
217
'User-Agent' : get_user_agent (),
218
+ ** await self .auth .headers (),
187
219
}
188
- url = self .url_template .format (
189
- model = self .model_name , function = 'streamGenerateContent' if streamed else 'generateContent'
190
- )
220
+
221
+ request_json = _gemini_request_ta .dump_json (request_data , by_alias = True )
191
222
192
223
async with self .http_client .stream ('POST' , url , content = request_json , headers = headers ) as r :
193
224
if r .status_code != 200 :
@@ -283,7 +314,7 @@ def get(self, *, final: bool = False) -> Iterable[str]:
283
314
new_items , experimental_allow_partial = 'trailing-strings'
284
315
)
285
316
for r in new_responses :
286
- self ._cost += _metadata_as_cost (r [ 'usage_metadata' ] )
317
+ self ._cost += _metadata_as_cost (r )
287
318
parts = r ['candidates' ][0 ]['content' ]['parts' ]
288
319
if _all_text_parts (parts ):
289
320
for part in parts :
@@ -329,7 +360,7 @@ def get(self, *, final: bool = False) -> ModelStructuredResponse:
329
360
combined_parts : list [_GeminiFunctionCallPart ] = []
330
361
self ._cost = result .Cost ()
331
362
for r in responses :
332
- self ._cost += _metadata_as_cost (r [ 'usage_metadata' ] )
363
+ self ._cost += _metadata_as_cost (r )
333
364
candidate = r ['candidates' ][0 ]
334
365
parts = candidate ['content' ]['parts' ]
335
366
if _all_function_call_parts (parts ):
@@ -521,10 +552,12 @@ class _GeminiResponse(TypedDict):
521
552
"""Schema for the response from the Gemini API.
522
553
523
554
See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>
555
+ and <https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerateContentResponse>
524
556
"""
525
557
526
558
candidates : list [_GeminiCandidates ]
527
- usage_metadata : Annotated [_GeminiUsageMetaData , Field (alias = 'usageMetadata' )]
559
+ # usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
560
+ usage_metadata : NotRequired [Annotated [_GeminiUsageMetaData , Field (alias = 'usageMetadata' )]]
528
561
prompt_feedback : NotRequired [Annotated [_GeminiPromptFeedback , Field (alias = 'promptFeedback' )]]
529
562
530
563
@@ -582,7 +615,10 @@ class _GeminiUsageMetaData(TypedDict, total=False):
582
615
cached_content_token_count : NotRequired [Annotated [int , Field (alias = 'cachedContentTokenCount' )]]
583
616
584
617
585
- def _metadata_as_cost (metadata : _GeminiUsageMetaData ) -> result .Cost :
618
+ def _metadata_as_cost (response : _GeminiResponse ) -> result .Cost :
619
+ metadata = response .get ('usage_metadata' )
620
+ if metadata is None :
621
+ return result .Cost ()
586
622
details : dict [str , int ] = {}
587
623
if cached_content_token_count := metadata .get ('cached_content_token_count' ):
588
624
details ['cached_content_token_count' ] = cached_content_token_count
0 commit comments