40
40
from ..providers import Provider , infer_provider
41
41
from ..settings import ModelSettings
42
42
from ..tools import ToolDefinition
43
- from . import (
44
- Model ,
45
- ModelRequestParameters ,
46
- StreamedResponse ,
47
- check_allow_model_requests ,
48
- download_item ,
49
- get_user_agent ,
50
- )
43
+ from . import Model , ModelRequestParameters , StreamedResponse , check_allow_model_requests , download_item , get_user_agent
51
44
52
45
LatestGeminiModelNames = Literal [
53
46
'gemini-2.0-flash' ,
@@ -108,10 +101,9 @@ class GeminiModel(Model):
108
101
client : httpx .AsyncClient = field (repr = False )
109
102
110
103
_model_name : GeminiModelName = field (repr = False )
111
- _provider : Literal [ 'google-gla' , 'google-vertex' ] | Provider [httpx .AsyncClient ] | None = field (repr = False )
104
+ _provider : Provider [httpx .AsyncClient ] = field (repr = False )
112
105
_auth : AuthProtocol | None = field (repr = False )
113
106
_url : str | None = field (repr = False )
114
- _system : str = field (default = 'gemini' , repr = False )
115
107
116
108
def __init__ (
117
109
self ,
@@ -132,11 +124,10 @@ def __init__(
132
124
settings: Default model settings for this model instance.
133
125
"""
134
126
self ._model_name = model_name
135
- self ._provider = provider
136
127
137
128
if isinstance (provider , str ):
138
129
provider = infer_provider (provider )
139
- self ._system = provider . name
130
+ self ._provider = provider
140
131
self .client = provider .client
141
132
self ._url = str (self .client .base_url )
142
133
@@ -147,6 +138,16 @@ def base_url(self) -> str:
147
138
assert self ._url is not None , 'URL not initialized' # pragma: no cover
148
139
return self ._url # pragma: no cover
149
140
141
+ @property
142
+ def model_name (self ) -> GeminiModelName :
143
+ """The model name."""
144
+ return self ._model_name
145
+
146
+ @property
147
+ def system (self ) -> str :
148
+ """The model provider."""
149
+ return self ._provider .name
150
+
150
151
async def request (
151
152
self ,
152
153
messages : list [ModelMessage ],
@@ -175,16 +176,6 @@ async def request_stream(
175
176
) as http_response :
176
177
yield await self ._process_streamed_response (http_response , model_request_parameters )
177
178
178
- @property
179
- def model_name (self ) -> GeminiModelName :
180
- """The model name."""
181
- return self ._model_name
182
-
183
- @property
184
- def system (self ) -> str :
185
- """The system / model provider."""
186
- return self ._system
187
-
188
179
def _get_tools (self , model_request_parameters : ModelRequestParameters ) -> _GeminiTools | None :
189
180
tools = [_function_from_abstract_tool (t ) for t in model_request_parameters .tool_defs .values ()]
190
181
return _GeminiTools (function_declarations = tools ) if tools else None
@@ -237,7 +228,7 @@ async def _make_request(
237
228
request_data ['safetySettings' ] = gemini_safety_settings
238
229
239
230
if gemini_labels := model_settings .get ('gemini_labels' ):
240
- if self ._system == 'google-vertex' :
231
+ if self ._provider . name == 'google-vertex' :
241
232
request_data ['labels' ] = gemini_labels # pragma: lax no cover
242
233
243
234
headers = {'Content-Type' : 'application/json' , 'User-Agent' : get_user_agent ()}
0 commit comments