Commit 10eb5b8

Use _provider.name instead of _system (#2596)
1 parent 72e6037 commit 10eb5b8
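
This change removes the per-model `_system` string field (e.g. `_system: str = field(default='anthropic', ...)`) and instead stores the `Provider` instance passed to the constructor, deriving the `system` property from `provider.name`. A minimal sketch of the pattern, using a hypothetical `Provider` stand-in rather than the real pydantic-ai classes:

```python
from dataclasses import dataclass, field


@dataclass
class Provider:
    """Stand-in for pydantic_ai.providers.Provider (illustrative only)."""

    name: str
    client: object = None


@dataclass(init=False)
class MyModel:
    # Before: _system: str = field(default='anthropic', repr=False)
    _provider: Provider = field(repr=False)

    def __init__(self, provider: Provider):
        self._provider = provider

    @property
    def system(self) -> str:
        """The model provider, derived from the stored provider object."""
        return self._provider.name


assert MyModel(Provider(name='anthropic')).system == 'anthropic'
```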

15 files changed: +138, -147 lines

CLAUDE.md

Lines changed: 2 additions & 2 deletions

@@ -134,7 +134,7 @@ from typing_extensions import deprecated

 class NewClass: ...  # This class was renamed from OldClass.

-@deprecated("Use `NewClass` instead")
+@deprecated("Use `NewClass` instead.")
 class OldClass(NewClass): ...
 ```

@@ -143,7 +143,7 @@ deprecation warning:

 ```python
 def test_old_class_is_deprecated():
-    with pytest.warns(DeprecationWarning, match="Use `NewClass` instead"):
+    with pytest.warns(DeprecationWarning, match="Use `NewClass` instead."):
         OldClass()
 ```
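
The CLAUDE.md guideline pairs each deprecation with a test asserting the warning message; `pytest.warns(..., match=...)` treats `match` as a regular expression searched against the message, so the updated text with a trailing period still matches. A self-contained sketch of the convention, using the hypothetical `OldClass`/`NewClass` names from the guideline:

```python
import pytest
from typing_extensions import deprecated


class NewClass: ...  # This class was renamed from OldClass.


@deprecated('Use `NewClass` instead.')
class OldClass(NewClass): ...


def test_old_class_is_deprecated():
    # `match` is a regex searched against the DeprecationWarning message.
    with pytest.warns(DeprecationWarning, match='Use `NewClass` instead.'):
        OldClass()
```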

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -480,7 +480,7 @@ def profile(self) -> ModelProfile:
     @property
     @abstractmethod
     def system(self) -> str:
-        """The system / model provider, ex: openai.
+        """The model provider, ex: openai.

         Use to populate the `gen_ai.system` OpenTelemetry semantic convention attribute,
         so should use well-known values listed in
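
The abstract `system` property's docstring notes that the value populates the `gen_ai.system` OpenTelemetry semantic-convention attribute. A hedged sketch of how such an attribute could be attached to a span with the opentelemetry-api (illustrative only, not pydantic-ai's actual instrumentation code):

```python
from opentelemetry import trace

tracer = trace.get_tracer('example-instrumentation')


def record_model_call(model) -> None:
    # Attach the provider name and model name to a span using the
    # gen_ai.* semantic-convention attribute names.
    with tracer.start_as_current_span('chat') as span:
        span.set_attribute('gen_ai.system', model.system)  # e.g. 'openai'
        span.set_attribute('gen_ai.request.model', model.model_name)
```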

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 12 additions & 11 deletions

@@ -137,7 +137,7 @@ class AnthropicModel(Model):
     client: AsyncAnthropic = field(repr=False)

     _model_name: AnthropicModelName = field(repr=False)
-    _system: str = field(default='anthropic', repr=False)
+    _provider: Provider[AsyncAnthropic] = field(repr=False)

     def __init__(
         self,

@@ -161,6 +161,7 @@ def __init__(

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)

@@ -169,6 +170,16 @@ def __init__(
     def base_url(self) -> str:
         return str(self.client.base_url)

+    @property
+    def model_name(self) -> AnthropicModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],

@@ -197,16 +208,6 @@ async def request_stream(
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)

-    @property
-    def model_name(self) -> AnthropicModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     @overload
     async def _messages_create(
         self,
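
With the provider stored on the model, the `system` property now reflects whatever provider was passed in. A small usage sketch (assuming the current `AnthropicModel`/`AnthropicProvider` constructor signatures; the key is a fake placeholder and no request is made):

```python
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.providers.anthropic import AnthropicProvider

model = AnthropicModel(
    'claude-3-5-sonnet-latest',
    provider=AnthropicProvider(api_key='fake-key-for-illustration'),
)
assert model.system == 'anthropic'  # read from the stored provider's .name
```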

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 16 additions & 15 deletions

@@ -190,17 +190,7 @@ class BedrockConverseModel(Model):
     client: BedrockRuntimeClient

     _model_name: BedrockModelName = field(repr=False)
-    _system: str = field(default='bedrock', repr=False)
-
-    @property
-    def model_name(self) -> str:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider, ex: openai."""
-        return self._system
+    _provider: Provider[BaseClient] = field(repr=False)

     def __init__(
         self,

@@ -226,10 +216,25 @@ def __init__(

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = cast('BedrockRuntimeClient', provider.client)

         super().__init__(settings=settings, profile=profile or provider.model_profile)

+    @property
+    def base_url(self) -> str:
+        return str(self.client.meta.endpoint_url)
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
         return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

@@ -245,10 +250,6 @@ def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef:

         return {'toolSpec': tool_spec}

-    @property
-    def base_url(self) -> str:
-        return str(self.client.meta.endpoint_url)
-
     async def request(
         self,
         messages: list[ModelMessage],

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 13 additions & 16 deletions

@@ -30,11 +30,7 @@
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    check_allow_model_requests,
-)
+from . import Model, ModelRequestParameters, check_allow_model_requests

 try:
     from cohere import (

@@ -106,7 +102,7 @@ class CohereModel(Model):
     client: AsyncClientV2 = field(repr=False)

     _model_name: CohereModelName = field(repr=False)
-    _system: str = field(default='cohere', repr=False)
+    _provider: Provider[AsyncClientV2] = field(repr=False)

     def __init__(
         self,

@@ -131,6 +127,7 @@ def __init__(

         if isinstance(provider, str):
             provider = infer_provider(provider)
+        self._provider = provider
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)

@@ -140,6 +137,16 @@ def base_url(self) -> str:
         client_wrapper = self.client._client_wrapper  # type: ignore
         return str(client_wrapper.get_base_url())

+    @property
+    def model_name(self) -> CohereModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],

@@ -151,16 +158,6 @@ async def request(
         model_response = self._process_response(response)
         return model_response

-    @property
-    def model_name(self) -> CohereModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     async def _chat(
         self,
         messages: list[ModelMessage],

pydantic_ai_slim/pydantic_ai/models/fallback.py

Lines changed: 13 additions & 13 deletions

@@ -52,6 +52,19 @@ def __init__(
         else:
             self._fallback_on = fallback_on

+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return f'fallback:{",".join(model.model_name for model in self.models)}'
+
+    @property
+    def system(self) -> str:
+        return f'fallback:{",".join(model.system for model in self.models)}'
+
+    @property
+    def base_url(self) -> str | None:
+        return self.models[0].base_url
+
     async def request(
         self,
         messages: list[ModelMessage],

@@ -121,19 +134,6 @@ def _set_span_attributes(self, model: Model):
                 if attributes.get('gen_ai.request.model') == self.model_name:  # pragma: no branch
                     span.set_attributes(InstrumentedModel.model_attributes(model))

-    @property
-    def model_name(self) -> str:
-        """The model name."""
-        return f'fallback:{",".join(model.model_name for model in self.models)}'
-
-    @property
-    def system(self) -> str:
-        return f'fallback:{",".join(model.system for model in self.models)}'
-
-    @property
-    def base_url(self) -> str | None:
-        return self.models[0].base_url
-

 def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]:
     """Create a default fallback condition for the given exceptions."""

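`FallbackModel` composes its `model_name` and `system` values by joining those of the wrapped models, exactly as the added f-strings show. A quick illustration with hypothetical stub models rather than real `Model` subclasses:

```python
class StubModel:
    """Hypothetical stand-in exposing only the attributes FallbackModel reads."""

    def __init__(self, model_name: str, system: str):
        self.model_name = model_name
        self.system = system


models = [StubModel('gpt-4o', 'openai'), StubModel('claude-3-5-sonnet-latest', 'anthropic')]

model_name = f'fallback:{",".join(m.model_name for m in models)}'
system = f'fallback:{",".join(m.system for m in models)}'

assert model_name == 'fallback:gpt-4o,claude-3-5-sonnet-latest'
assert system == 'fallback:openai,anthropic'
```
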
pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 14 additions & 23 deletions

@@ -40,14 +40,7 @@
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import (
-    Model,
-    ModelRequestParameters,
-    StreamedResponse,
-    check_allow_model_requests,
-    download_item,
-    get_user_agent,
-)
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent

 LatestGeminiModelNames = Literal[
     'gemini-2.0-flash',

@@ -108,10 +101,9 @@ class GeminiModel(Model):
     client: httpx.AsyncClient = field(repr=False)

     _model_name: GeminiModelName = field(repr=False)
-    _provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] | None = field(repr=False)
+    _provider: Provider[httpx.AsyncClient] = field(repr=False)
     _auth: AuthProtocol | None = field(repr=False)
     _url: str | None = field(repr=False)
-    _system: str = field(default='gemini', repr=False)

     def __init__(
         self,

@@ -132,11 +124,10 @@ def __init__(
             settings: Default model settings for this model instance.
         """
         self._model_name = model_name
-        self._provider = provider

         if isinstance(provider, str):
             provider = infer_provider(provider)
-        self._system = provider.name
+        self._provider = provider
         self.client = provider.client
         self._url = str(self.client.base_url)

@@ -147,6 +138,16 @@ def base_url(self) -> str:
         assert self._url is not None, 'URL not initialized'  # pragma: no cover
         return self._url  # pragma: no cover

+    @property
+    def model_name(self) -> GeminiModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],

@@ -175,16 +176,6 @@ async def request_stream(
         ) as http_response:
             yield await self._process_streamed_response(http_response, model_request_parameters)

-    @property
-    def model_name(self) -> GeminiModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
         tools = [_function_from_abstract_tool(t) for t in model_request_parameters.tool_defs.values()]
         return _GeminiTools(function_declarations=tools) if tools else None

@@ -237,7 +228,7 @@ async def _make_request(
             request_data['safetySettings'] = gemini_safety_settings

         if gemini_labels := model_settings.get('gemini_labels'):
-            if self._system == 'google-vertex':
+            if self._provider.name == 'google-vertex':
                 request_data['labels'] = gemini_labels  # pragma: lax no cover

         headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
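
Provider-specific branches now compare against `provider.name` instead of the removed `_system` string; here, request labels are only attached when the provider is Vertex AI. A minimal sketch of that gating logic (hypothetical helper and plain dicts, not the real `_make_request` internals):

```python
def apply_gemini_labels(request_data: dict, model_settings: dict, provider_name: str) -> None:
    # Labels are only supported on the Vertex AI provider, so gate on its name.
    if gemini_labels := model_settings.get('gemini_labels'):
        if provider_name == 'google-vertex':
            request_data['labels'] = gemini_labels


request_data: dict = {}
apply_gemini_labels(request_data, {'gemini_labels': {'team': 'search'}}, 'google-gla')
assert 'labels' not in request_data  # not Vertex AI, so labels are skipped
```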

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 11 additions & 14 deletions

@@ -144,7 +144,6 @@ class GoogleModel(Model):
     _model_name: GoogleModelName = field(repr=False)
     _provider: Provider[Client] = field(repr=False)
     _url: str | None = field(repr=False)
-    _system: str = field(default='google', repr=False)

     def __init__(
         self,

@@ -168,9 +167,7 @@ def __init__(

         if isinstance(provider, str):
             provider = GoogleProvider(vertexai=provider == 'google-vertex')
-
         self._provider = provider
-        self._system = provider.name
         self.client = provider.client

         super().__init__(settings=settings, profile=profile or provider.model_profile)

@@ -179,6 +176,16 @@ def __init__(
     def base_url(self) -> str:
         return self._provider.base_url

+    @property
+    def model_name(self) -> GoogleModelName:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return self._provider.name
+
     async def request(
         self,
         messages: list[ModelMessage],

@@ -209,7 +216,7 @@ async def count_tokens(
         config = CountTokensConfigDict(
             http_options=generation_config.get('http_options'),
         )
-        if self.system != 'google-gla':
+        if self._provider.name != 'google-gla':
             # The fields are not supported by the Gemini API per https://github.com/googleapis/python-genai/blob/7e4ec284dc6e521949626f3ed54028163ef9121d/google/genai/models.py#L1195-L1214
             config.update(
                 system_instruction=generation_config.get('system_instruction'),

@@ -255,16 +262,6 @@ async def request_stream(
         response = await self._generate_content(messages, True, model_settings, model_request_parameters)
         yield await self._process_streamed_response(response, model_request_parameters)  # type: ignore

-    @property
-    def model_name(self) -> GoogleModelName:
-        """The model name."""
-        return self._model_name
-
-    @property
-    def system(self) -> str:
-        """The system / model provider."""
-        return self._system
-
     def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolDict] | None:
         tools: list[ToolDict] = [
             ToolDict(function_declarations=[_function_declaration_from_tool(t)])

0 commit comments
