9
9
import logfire_api
10
10
from typing_extensions import assert_never
11
11
12
- from . import _result , _retriever as _r , _system_prompt , _utils , exceptions , messages as _messages , models , result
12
+ from . import (
13
+ _result ,
14
+ _retriever as _r ,
15
+ _system_prompt ,
16
+ _utils ,
17
+ exceptions ,
18
+ messages as _messages ,
19
+ models ,
20
+ result ,
21
+ )
13
22
from .dependencies import AgentDeps , RetrieverContextFunc , RetrieverParams , RetrieverPlainFunc
14
23
from .result import ResultData
15
24
23
32
'openai:gpt-3.5-turbo' ,
24
33
'gemini-1.5-flash' ,
25
34
'gemini-1.5-pro' ,
35
+ 'test' ,
26
36
]
27
37
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
28
38
@@ -40,7 +50,7 @@ class Agent(Generic[AgentDeps, ResultData]):
40
50
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM."""
41
51
42
52
# dataclass fields mostly for my sanity — knowing what attributes are available
43
- model : models .Model | None
53
+ model : models .Model | KnownModelName | None
44
54
"""The default model configured for this agent."""
45
55
_result_schema : _result .ResultSchema [ResultData ] | None
46
56
_result_validators : list [_result .ResultValidator [AgentDeps , ResultData ]]
@@ -52,7 +62,8 @@ class Agent(Generic[AgentDeps, ResultData]):
52
62
_deps_type : type [AgentDeps ]
53
63
_max_result_retries : int
54
64
_current_result_retry : int
55
- _override_deps_stack : list [AgentDeps ]
65
+ _override_deps : _utils .Option [AgentDeps ] = None
66
+ _override_model : _utils .Option [models .Model ] = None
56
67
last_run_messages : list [_messages .Message ] | None = None
57
68
"""The messages from the last run, useful when a run raised an exception.
58
69
@@ -70,6 +81,7 @@ def __init__(
70
81
result_tool_name : str = 'final_result' ,
71
82
result_tool_description : str | None = None ,
72
83
result_retries : int | None = None ,
84
+ defer_model_check : bool = False ,
73
85
):
74
86
"""Create an agent.
75
87
@@ -87,8 +99,16 @@ def __init__(
87
99
result_tool_name: The name of the tool to use for the final result.
88
100
result_tool_description: The description of the final result tool.
89
101
result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
102
+ defer_model_check: by default, if you provide a [named][pydantic_ai.agent.KnownModelName] model,
103
+ it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
104
+ which checks for the necessary environment variables. Set this to `false`
105
+ to defer the evaluation until the first run. Useful if you want to
106
+ [override the model][pydantic_ai.Agent.override_model] for testing.
90
107
"""
91
- self .model = models .infer_model (model ) if model is not None else None
108
+ if model is None or defer_model_check :
109
+ self .model = model
110
+ else :
111
+ self .model = models .infer_model (model )
92
112
93
113
self ._result_schema = _result .ResultSchema [result_type ].build (
94
114
result_type , result_tool_name , result_tool_description
@@ -104,7 +124,6 @@ def __init__(
104
124
self ._max_result_retries = result_retries if result_retries is not None else retries
105
125
self ._current_result_retry = 0
106
126
self ._result_validators = []
107
- self ._override_deps_stack = []
108
127
109
128
async def run (
110
129
self ,
@@ -281,11 +300,26 @@ def override_deps(self, overriding_deps: AgentDeps) -> Iterator[None]:
281
300
Args:
282
301
overriding_deps: The dependencies to use instead of the dependencies passed to the agent run.
283
302
"""
284
- self ._override_deps_stack .append (overriding_deps )
303
+ override_deps_before = self ._override_deps
304
+ self ._override_deps = _utils .Some (overriding_deps )
285
305
try :
286
306
yield
287
307
finally :
288
- self ._override_deps_stack .pop ()
308
+ self ._override_deps = override_deps_before
309
+
310
+ @contextmanager
311
+ def override_model (self , overriding_model : models .Model | KnownModelName ) -> Iterator [None ]:
312
+ """Context manager to temporarily override the model used by the agent.
313
+
314
+ Args:
315
+ overriding_model: The model to use instead of the model passed to the agent run.
316
+ """
317
+ override_model_before = self ._override_model
318
+ self ._override_model = _utils .Some (models .infer_model (overriding_model ))
319
+ try :
320
+ yield
321
+ finally :
322
+ self ._override_model = override_model_before
289
323
290
324
def system_prompt (
291
325
self , func : _system_prompt .SystemPromptFunc [AgentDeps ]
@@ -386,11 +420,20 @@ async def _get_agent_model(
386
420
a tuple of `(model used, custom_model if any, agent_model)`
387
421
"""
388
422
model_ : models .Model
389
- if model is not None :
423
+ if some_model := self ._override_model :
424
+ # we don't want `override_model()` to cover up errors from the model not being defined, hence this check
425
+ if model is None and self .model is None :
426
+ raise exceptions .UserError (
427
+ '`model` must be set either when creating the agent or when calling it. '
428
+ '(Even when `override_model()` is customizing the model that will actually be called)'
429
+ )
430
+ model_ = some_model .value
431
+ custom_model = None
432
+ elif model is not None :
390
433
custom_model = model_ = models .infer_model (model )
391
434
elif self .model is not None :
392
435
# noinspection PyTypeChecker
393
- model_ = self .model
436
+ model_ = self .model = models . infer_model ( self . model )
394
437
custom_model = None
395
438
else :
396
439
raise exceptions .UserError ('`model` must be set either when creating the agent or when calling it.' )
@@ -573,9 +616,9 @@ def _get_deps(self, deps: AgentDeps) -> AgentDeps:
573
616
574
617
We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
575
618
"""
576
- try :
577
- return self . _override_deps_stack [ - 1 ]
578
- except IndexError :
619
+ if some_deps := self . _override_deps :
620
+ return some_deps . value
621
+ else :
579
622
return deps
580
623
581
624
0 commit comments