Skip to content

Commit 59cacf8

Browse files
Using model_name and system model properties (#865)
1 parent 35b6470 commit 59cacf8

File tree

18 files changed: +181 additions, -120 deletions (lines changed)

docs/models.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ pip/uv-add 'pydantic-ai-slim[mistral]'
377377

378378
To use [Mistral](https://mistral.ai) through their API, go to [console.mistral.ai/api-keys/](https://console.mistral.ai/api-keys/) and follow your nose until you find the place to generate an API key.
379379

380-
[`NamedMistralModels`][pydantic_ai.models.mistral.NamedMistralModels] contains a list of the most popular Mistral models.
380+
[`MistralModelName`][pydantic_ai.models.mistral.MistralModelName] contains a list of the most popular Mistral models.
381381

382382
### Environment variable
383383

@@ -434,7 +434,7 @@ pip/uv-add 'pydantic-ai-slim[cohere]'
434434

435435
To use [Cohere](https://cohere.com/) through their API, go to [dashboard.cohere.com/api-keys](https://dashboard.cohere.com/api-keys) and follow your nose until you find the place to generate an API key.
436436

437-
[`NamedCohereModels`][pydantic_ai.models.cohere.NamedCohereModels] contains a list of the most popular Cohere models.
437+
[`CohereModelName`][pydantic_ai.models.cohere.CohereModelName] contains a list of the most popular Cohere models.
438438

439439
### Environment variable
440440

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ async def main():
309309
'{agent_name} run {prompt=}',
310310
prompt=user_prompt,
311311
agent=self,
312-
model_name=model_used.name() if model_used else 'no-model',
312+
model_name=model_used.model_name if model_used else 'no-model',
313313
agent_name=self.name or 'agent',
314314
) as run_span:
315315
# Build the deps object for the graph
@@ -554,7 +554,7 @@ async def main():
554554
'{agent_name} run stream {prompt=}',
555555
prompt=user_prompt,
556556
agent=self,
557-
model_name=model_used.name(),
557+
model_name=model_used.model_name if model_used else 'no-model',
558558
agent_name=self.name or 'agent',
559559
) as run_span:
560560
# Build the deps object for the graph

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,8 @@ class ModelRequestParameters:
173173
class Model(ABC):
174174
"""Abstract class for a model."""
175175

176-
@abstractmethod
177-
def name(self) -> str:
178-
raise NotImplementedError()
176+
_model_name: str
177+
_system: str | None
179178

180179
@abstractmethod
181180
async def request(
@@ -201,6 +200,16 @@ async def request_stream(
201200
# noinspection PyUnreachableCode
202201
yield # pragma: no cover
203202

203+
@property
204+
def model_name(self) -> str:
205+
"""The model name."""
206+
return self._model_name
207+
208+
@property
209+
def system(self) -> str | None:
210+
"""The system / model provider, ex: openai."""
211+
return self._system
212+
204213

205214
@dataclass
206215
class StreamedResponse(ABC):
@@ -311,26 +320,26 @@ def infer_model(model: Model | KnownModelName) -> Model:
311320
elif model.startswith('google-gla'):
312321
from .gemini import GeminiModel
313322

314-
return GeminiModel(model[11:]) # pyright: ignore[reportArgumentType]
323+
return GeminiModel(model[11:])
315324
# backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
316325
elif model.startswith('gemini'):
317326
from .gemini import GeminiModel
318327

319328
# noinspection PyTypeChecker
320-
return GeminiModel(model) # pyright: ignore[reportArgumentType]
329+
return GeminiModel(model)
321330
elif model.startswith('groq:'):
322331
from .groq import GroqModel
323332

324-
return GroqModel(model[5:]) # pyright: ignore[reportArgumentType]
333+
return GroqModel(model[5:])
325334
elif model.startswith('google-vertex'):
326335
from .vertexai import VertexAIModel
327336

328-
return VertexAIModel(model[14:]) # pyright: ignore[reportArgumentType]
337+
return VertexAIModel(model[14:])
329338
# backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
330339
elif model.startswith('vertexai:'):
331340
from .vertexai import VertexAIModel
332341

333-
return VertexAIModel(model[9:]) # pyright: ignore[reportArgumentType]
342+
return VertexAIModel(model[9:])
334343
elif model.startswith('mistral:'):
335344
from .mistral import MistralModel
336345

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,14 @@
6868
'claude-3-5-sonnet-latest',
6969
'claude-3-opus-latest',
7070
]
71-
"""Latest named Anthropic models."""
71+
"""Latest Anthropic models."""
7272

7373
AnthropicModelName = Union[str, LatestAnthropicModelNames]
7474
"""Possible Anthropic model names.
7575
7676
Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
7777
allow any name in the type hints.
78-
Since [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
78+
See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
7979
"""
8080

8181

@@ -101,9 +101,11 @@ class AnthropicModel(Model):
101101
We anticipate adding support for streaming responses in a near-term future release.
102102
"""
103103

104-
model_name: AnthropicModelName
105104
client: AsyncAnthropic = field(repr=False)
106105

106+
_model_name: AnthropicModelName = field(repr=False)
107+
_system: str | None = field(default='anthropic', repr=False)
108+
107109
def __init__(
108110
self,
109111
model_name: AnthropicModelName,
@@ -124,7 +126,7 @@ def __init__(
124126
client to use, if provided, `api_key` and `http_client` must be `None`.
125127
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
126128
"""
127-
self.model_name = model_name
129+
self._model_name = model_name
128130
if anthropic_client is not None:
129131
assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
130132
assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
@@ -134,9 +136,6 @@ def __init__(
134136
else:
135137
self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
136138

137-
def name(self) -> str:
138-
return f'anthropic:{self.model_name}'
139-
140139
async def request(
141140
self,
142141
messages: list[ModelMessage],
@@ -211,7 +210,7 @@ async def _messages_create(
211210
max_tokens=model_settings.get('max_tokens', 1024),
212211
system=system_prompt or NOT_GIVEN,
213212
messages=anthropic_messages,
214-
model=self.model_name,
213+
model=self._model_name,
215214
tools=tools or NOT_GIVEN,
216215
tool_choice=tool_choice or NOT_GIVEN,
217216
stream=stream,
@@ -237,7 +236,7 @@ def _process_response(self, response: AnthropicMessage) -> ModelResponse:
237236
)
238237
)
239238

240-
return ModelResponse(items, model_name=self.model_name)
239+
return ModelResponse(items, model_name=self._model_name)
241240

242241
async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
243242
peekable_response = _utils.PeekableAsyncStream(response)
@@ -247,7 +246,9 @@ async def _process_streamed_response(self, response: AsyncStream[RawMessageStrea
247246

248247
# Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
249248
timestamp = datetime.now(tz=timezone.utc)
250-
return AnthropicStreamedResponse(_model_name=self.model_name, _response=peekable_response, _timestamp=timestamp)
249+
return AnthropicStreamedResponse(
250+
_model_name=self._model_name, _response=peekable_response, _timestamp=timestamp
251+
)
251252

252253
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolParam]:
253254
tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
"you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
5353
) from _import_error
5454

55-
NamedCohereModels = Literal[
55+
LatestCohereModelNames = Literal[
5656
'c4ai-aya-expanse-32b',
5757
'c4ai-aya-expanse-8b',
5858
'command',
@@ -67,9 +67,15 @@
6767
'command-r-plus-08-2024',
6868
'command-r7b-12-2024',
6969
]
70-
"""Latest / most popular named Cohere models."""
70+
"""Latest Cohere models."""
7171

72-
CohereModelName = Union[NamedCohereModels, str]
72+
CohereModelName = Union[str, LatestCohereModelNames]
73+
"""Possible Cohere model names.
74+
75+
Since Cohere supports a variety of date-stamped models, we explicitly list the latest models but
76+
allow any name in the type hints.
77+
See [Cohere's docs](https://docs.cohere.com/v2/docs/models) for a list of all available models.
78+
"""
7379

7480

7581
class CohereModelSettings(ModelSettings):
@@ -88,9 +94,11 @@ class CohereModel(Model):
8894
Apart from `__init__`, all methods are private or match those of the base class.
8995
"""
9096

91-
model_name: CohereModelName
9297
client: AsyncClientV2 = field(repr=False)
9398

99+
_model_name: CohereModelName = field(repr=False)
100+
_system: str | None = field(default='cohere', repr=False)
101+
94102
def __init__(
95103
self,
96104
model_name: CohereModelName,
@@ -110,17 +118,14 @@ def __init__(
110118
`api_key` and `http_client` must be `None`.
111119
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
112120
"""
113-
self.model_name: CohereModelName = model_name
121+
self._model_name: CohereModelName = model_name
114122
if cohere_client is not None:
115123
assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
116124
assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
117125
self.client = cohere_client
118126
else:
119127
self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client) # type: ignore
120128

121-
def name(self) -> str:
122-
return f'cohere:{self.model_name}'
123-
124129
async def request(
125130
self,
126131
messages: list[ModelMessage],
@@ -140,7 +145,7 @@ async def _chat(
140145
tools = self._get_tools(model_request_parameters)
141146
cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
142147
return await self.client.chat(
143-
model=self.model_name,
148+
model=self._model_name,
144149
messages=cohere_messages,
145150
tools=tools or OMIT,
146151
max_tokens=model_settings.get('max_tokens', OMIT),
@@ -168,7 +173,7 @@ def _process_response(self, response: ChatResponse) -> ModelResponse:
168173
tool_call_id=c.id,
169174
)
170175
)
171-
return ModelResponse(parts=parts, model_name=self.model_name)
176+
return ModelResponse(parts=parts, model_name=self._model_name)
172177

173178
def _map_message(self, message: ModelMessage) -> Iterable[ChatMessageV2]:
174179
"""Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ class FunctionModel(Model):
4040
function: FunctionDef | None = None
4141
stream_function: StreamFunctionDef | None = None
4242

43+
_model_name: str = field(repr=False)
44+
_system: str | None = field(default=None, repr=False)
45+
4346
@overload
4447
def __init__(self, function: FunctionDef) -> None: ...
4548

@@ -63,10 +66,9 @@ def __init__(self, function: FunctionDef | None = None, *, stream_function: Stre
6366
self.function = function
6467
self.stream_function = stream_function
6568

66-
def name(self) -> str:
6769
function_name = self.function.__name__ if self.function is not None else ''
6870
stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
69-
return f'function:{function_name}:{stream_function_name}'
71+
self._model_name = f'function:{function_name}:{stream_function_name}'
7072

7173
async def request(
7274
self,
@@ -82,15 +84,14 @@ async def request(
8284
)
8385

8486
assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
85-
model_name = f'function:{self.function.__name__}'
8687

8788
if inspect.iscoroutinefunction(self.function):
8889
response = await self.function(messages, agent_info)
8990
else:
9091
response_ = await _utils.run_in_executor(self.function, messages, agent_info)
9192
assert isinstance(response_, ModelResponse), response_
9293
response = response_
93-
response.model_name = model_name
94+
response.model_name = f'function:{self.function.__name__}'
9495
# TODO is `messages` right here? Should it just be new messages?
9596
return response, _estimate_usage(chain(messages, [response]))
9697

@@ -111,15 +112,14 @@ async def request_stream(
111112
assert (
112113
self.stream_function is not None
113114
), 'FunctionModel must receive a `stream_function` to support streamed requests'
114-
model_name = f'function:{self.stream_function.__name__}'
115115

116116
response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))
117117

118118
first = await response_stream.peek()
119119
if isinstance(first, _utils.Unset):
120120
raise ValueError('Stream function must return at least one item')
121121

122-
yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream)
122+
yield FunctionStreamedResponse(_model_name=f'function:{self.stream_function.__name__}', _iter=response_stream)
123123

124124

125125
@dataclass(frozen=True)

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
get_user_agent,
4040
)
4141

42-
GeminiModelName = Literal[
42+
LatestGeminiModelNames = Literal[
4343
'gemini-1.5-flash',
4444
'gemini-1.5-flash-8b',
4545
'gemini-1.5-pro',
@@ -48,8 +48,13 @@
4848
'gemini-2.0-flash-thinking-exp-01-21',
4949
'gemini-exp-1206',
5050
]
51-
"""Named Gemini models.
51+
"""Latest Gemini models."""
5252

53+
GeminiModelName = Union[str, LatestGeminiModelNames]
54+
"""Possible Gemini model names.
55+
56+
Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
57+
allow any name in the type hints.
5358
See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
5459
"""
5560

@@ -70,11 +75,12 @@ class GeminiModel(Model):
7075
Apart from `__init__`, all methods are private or match those of the base class.
7176
"""
7277

73-
model_name: GeminiModelName
74-
http_client: AsyncHTTPClient
78+
http_client: AsyncHTTPClient = field(repr=False)
7579

76-
_auth: AuthProtocol | None
77-
_url: str | None
80+
_model_name: GeminiModelName = field(repr=False)
81+
_auth: AuthProtocol | None = field(repr=False)
82+
_url: str | None = field(repr=False)
83+
_system: str | None = field(default='google-gla', repr=False)
7884

7985
def __init__(
8086
self,
@@ -95,7 +101,7 @@ def __init__(
95101
docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request),
96102
`model` is substituted with the model name, and `function` is added to the end of the URL.
97103
"""
98-
self.model_name = model_name
104+
self._model_name = model_name
99105
if api_key is None:
100106
if env_api_key := os.getenv('GEMINI_API_KEY'):
101107
api_key = env_api_key
@@ -115,9 +121,6 @@ def url(self) -> str:
115121
assert self._url is not None, 'URL not initialized'
116122
return self._url
117123

118-
def name(self) -> str:
119-
return f'google-gla:{self.model_name}'
120-
121124
async def request(
122125
self,
123126
messages: list[ModelMessage],
@@ -228,7 +231,7 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse:
228231
else:
229232
raise UnexpectedModelBehavior('Content field missing from Gemini response', str(response))
230233
parts = response['candidates'][0]['content']['parts']
231-
return _process_response_from_parts(parts, model_name=self.model_name)
234+
return _process_response_from_parts(parts, model_name=self._model_name)
232235

233236
async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
234237
"""Process a streamed response, and prepare a streaming response to return."""
@@ -251,7 +254,7 @@ async def _process_streamed_response(self, http_response: HTTPResponse) -> Strea
251254
if start_response is None:
252255
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
253256

254-
return GeminiStreamedResponse(_model_name=self.model_name, _content=content, _stream=aiter_bytes)
257+
return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)
255258

256259
@classmethod
257260
def _message_to_gemini_content(

0 commit comments

Comments (0)