Skip to content

Commit aea25e3

Browse files
authored
Better test functionality (#50)
1 parent d4a52ef commit aea25e3

File tree

11 files changed

+189
-44
lines changed

11 files changed

+189
-44
lines changed

docs/api/agent.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
- run_stream
1010
- model
1111
- override_deps
12+
- override_model
1213
- last_run_messages
1314
- system_prompt
1415
- retriever_plain

docs/api/models/base.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
11
# `pydantic_ai.models`
22

33
::: pydantic_ai.models
4+
options:
5+
members:
6+
- Model
7+
- AgentModel
8+
- AbstractToolDefinition
9+
- StreamTextResponse
10+
- StreamStructuredResponse
11+
- ALLOW_MODEL_REQUESTS
12+
- check_allow_model_requests
13+
- override_allow_model_requests

pydantic_ai/agent.py

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,16 @@
99
import logfire_api
1010
from typing_extensions import assert_never
1111

12-
from . import _result, _retriever as _r, _system_prompt, _utils, exceptions, messages as _messages, models, result
12+
from . import (
13+
_result,
14+
_retriever as _r,
15+
_system_prompt,
16+
_utils,
17+
exceptions,
18+
messages as _messages,
19+
models,
20+
result,
21+
)
1322
from .dependencies import AgentDeps, RetrieverContextFunc, RetrieverParams, RetrieverPlainFunc
1423
from .result import ResultData
1524

@@ -23,6 +32,7 @@
2332
'openai:gpt-3.5-turbo',
2433
'gemini-1.5-flash',
2534
'gemini-1.5-pro',
35+
'test',
2636
]
2737
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
2838
@@ -40,7 +50,7 @@ class Agent(Generic[AgentDeps, ResultData]):
4050
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM."""
4151

4252
# dataclass fields mostly for my sanity — knowing what attributes are available
43-
model: models.Model | None
53+
model: models.Model | KnownModelName | None
4454
"""The default model configured for this agent."""
4555
_result_schema: _result.ResultSchema[ResultData] | None
4656
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
@@ -52,7 +62,8 @@ class Agent(Generic[AgentDeps, ResultData]):
5262
_deps_type: type[AgentDeps]
5363
_max_result_retries: int
5464
_current_result_retry: int
55-
_override_deps_stack: list[AgentDeps]
65+
_override_deps: _utils.Option[AgentDeps] = None
66+
_override_model: _utils.Option[models.Model] = None
5667
last_run_messages: list[_messages.Message] | None = None
5768
"""The messages from the last run, useful when a run raised an exception.
5869
@@ -70,6 +81,7 @@ def __init__(
7081
result_tool_name: str = 'final_result',
7182
result_tool_description: str | None = None,
7283
result_retries: int | None = None,
84+
defer_model_check: bool = False,
7385
):
7486
"""Create an agent.
7587
@@ -87,8 +99,16 @@ def __init__(
8799
result_tool_name: The name of the tool to use for the final result.
88100
result_tool_description: The description of the final result tool.
89101
result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
102+
defer_model_check: by default, if you provide a [named][pydantic_ai.agent.KnownModelName] model,
103+
it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
104+
which checks for the necessary environment variables. Set this to `false`
105+
to defer the evaluation until the first run. Useful if you want to
106+
[override the model][pydantic_ai.Agent.override_model] for testing.
90107
"""
91-
self.model = models.infer_model(model) if model is not None else None
108+
if model is None or defer_model_check:
109+
self.model = model
110+
else:
111+
self.model = models.infer_model(model)
92112

93113
self._result_schema = _result.ResultSchema[result_type].build(
94114
result_type, result_tool_name, result_tool_description
@@ -104,7 +124,6 @@ def __init__(
104124
self._max_result_retries = result_retries if result_retries is not None else retries
105125
self._current_result_retry = 0
106126
self._result_validators = []
107-
self._override_deps_stack = []
108127

109128
async def run(
110129
self,
@@ -281,11 +300,26 @@ def override_deps(self, overriding_deps: AgentDeps) -> Iterator[None]:
281300
Args:
282301
overriding_deps: The dependencies to use instead of the dependencies passed to the agent run.
283302
"""
284-
self._override_deps_stack.append(overriding_deps)
303+
override_deps_before = self._override_deps
304+
self._override_deps = _utils.Some(overriding_deps)
285305
try:
286306
yield
287307
finally:
288-
self._override_deps_stack.pop()
308+
self._override_deps = override_deps_before
309+
310+
@contextmanager
311+
def override_model(self, overriding_model: models.Model | KnownModelName) -> Iterator[None]:
312+
"""Context manager to temporarily override the model used by the agent.
313+
314+
Args:
315+
overriding_model: The model to use instead of the model passed to the agent run.
316+
"""
317+
override_model_before = self._override_model
318+
self._override_model = _utils.Some(models.infer_model(overriding_model))
319+
try:
320+
yield
321+
finally:
322+
self._override_model = override_model_before
289323

290324
def system_prompt(
291325
self, func: _system_prompt.SystemPromptFunc[AgentDeps]
@@ -386,11 +420,20 @@ async def _get_agent_model(
386420
a tuple of `(model used, custom_model if any, agent_model)`
387421
"""
388422
model_: models.Model
389-
if model is not None:
423+
if some_model := self._override_model:
424+
# we don't want `override_model()` to cover up errors from the model not being defined, hence this check
425+
if model is None and self.model is None:
426+
raise exceptions.UserError(
427+
'`model` must be set either when creating the agent or when calling it. '
428+
'(Even when `override_model()` is customizing the model that will actually be called)'
429+
)
430+
model_ = some_model.value
431+
custom_model = None
432+
elif model is not None:
390433
custom_model = model_ = models.infer_model(model)
391434
elif self.model is not None:
392435
# noinspection PyTypeChecker
393-
model_ = self.model
436+
model_ = self.model = models.infer_model(self.model)
394437
custom_model = None
395438
else:
396439
raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
@@ -573,9 +616,9 @@ def _get_deps(self, deps: AgentDeps) -> AgentDeps:
573616
574617
We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
575618
"""
576-
try:
577-
return self._override_deps_stack[-1]
578-
except IndexError:
619+
if some_deps := self._override_deps:
620+
return some_deps.value
621+
else:
579622
return deps
580623

581624

pydantic_ai/models/__init__.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from __future__ import annotations as _annotations
88

99
from abc import ABC, abstractmethod
10-
from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
11-
from contextlib import asynccontextmanager
10+
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence
11+
from contextlib import asynccontextmanager, contextmanager
1212
from datetime import datetime
1313
from functools import cache
1414
from typing import TYPE_CHECKING, Protocol, Union
@@ -151,10 +151,54 @@ def timestamp(self) -> datetime:
151151
EitherStreamedResponse = Union[StreamTextResponse, StreamStructuredResponse]
152152

153153

154+
ALLOW_MODEL_REQUESTS = False
155+
"""Whether to allow requests to models.
156+
157+
This global setting allows you to disable request to most models, e.g. to make sure you don't accidentally
158+
make costly requests to a model during tests.
159+
160+
The testing models [`TestModel`][pydantic_ai.models.test.TestModel] and
161+
[`FunctionModel`][pydantic_ai.models.function.FunctionModel] are no affected by this setting.
162+
"""
163+
164+
165+
def check_allow_model_requests() -> None:
166+
"""Check if model requests are allowed.
167+
168+
If you're defining your own models that have cost or latency associated with their use, you should call this in
169+
[`Model.agent_model`][pydantic_ai.models.Model.agent_model].
170+
171+
Raises:
172+
RuntimeError: If model requests are not allowed.
173+
"""
174+
if not ALLOW_MODEL_REQUESTS:
175+
raise RuntimeError('Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False')
176+
177+
178+
@contextmanager
179+
def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]:
180+
"""Context manager to temporarily override [`ALLOW_MODEL_REQUESTS`][pydantic_ai.models.ALLOW_MODEL_REQUESTS].
181+
182+
Args:
183+
allow_model_requests: Whether to allow model requests within the context.
184+
"""
185+
global ALLOW_MODEL_REQUESTS
186+
old_value = ALLOW_MODEL_REQUESTS
187+
ALLOW_MODEL_REQUESTS = allow_model_requests # pyright: ignore[reportConstantRedefinition]
188+
try:
189+
yield
190+
finally:
191+
ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition]
192+
193+
154194
def infer_model(model: Model | KnownModelName) -> Model:
155195
"""Infer the model from the name."""
156196
if isinstance(model, Model):
157197
return model
198+
elif model == 'test':
199+
from .test import TestModel
200+
201+
return TestModel()
158202
elif model.startswith('openai:'):
159203
from .openai import OpenAIModel
160204

pydantic_ai/models/gemini.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
StreamStructuredResponse,
5656
StreamTextResponse,
5757
cached_async_http_client,
58+
check_allow_model_requests,
5859
)
5960

6061
GeminiModelName = Literal['gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', 'gemini-1.0-pro']
@@ -113,6 +114,7 @@ def agent_model(
113114
allow_text_result: bool,
114115
result_tools: Sequence[AbstractToolDefinition] | None,
115116
) -> GeminiAgentModel:
117+
check_allow_model_requests()
116118
tools = [_function_from_abstract_tool(t) for t in retrievers.values()]
117119
if result_tools is not None:
118120
tools += [_function_from_abstract_tool(t) for t in result_tools]

pydantic_ai/models/openai.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
StreamStructuredResponse,
3434
StreamTextResponse,
3535
cached_async_http_client,
36+
check_allow_model_requests,
3637
)
3738

3839

@@ -85,6 +86,7 @@ def agent_model(
8586
allow_text_result: bool,
8687
result_tools: Sequence[AbstractToolDefinition] | None,
8788
) -> AgentModel:
89+
check_allow_model_requests()
8890
tools = [self._map_tool_definition(r) for r in retrievers.values()]
8991
if result_tools is not None:
9092
tools += [self._map_tool_definition(r) for r in result_tools]

tests/conftest.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import re
66
import secrets
77
import sys
8+
from collections.abc import Iterator
89
from datetime import datetime
910
from pathlib import Path
1011
from types import ModuleType
@@ -15,8 +16,13 @@
1516
from _pytest.assertion.rewrite import AssertionRewritingHook
1617
from typing_extensions import TypeAlias
1718

19+
import pydantic_ai.models
20+
1821
__all__ = 'IsNow', 'TestEnv', 'ClientWithHandler'
1922

23+
24+
pydantic_ai.models.ALLOW_MODEL_REQUESTS = False
25+
2026
if TYPE_CHECKING:
2127

2228
def IsNow(*args: Any, **kwargs: Any) -> datetime: ...
@@ -38,35 +44,43 @@ class TestEnv:
3844
__test__ = False
3945

4046
def __init__(self):
41-
self.envars: set[str] = set()
47+
self.envars: dict[str, str | None] = {}
4248

4349
def set(self, name: str, value: str) -> None:
44-
self.envars.add(name)
50+
self.envars[name] = os.getenv(name)
4551
os.environ[name] = value
4652

47-
def pop(self, name: str) -> None: # pragma: no cover
48-
self.envars.remove(name)
49-
os.environ.pop(name)
53+
def remove(self, name: str) -> None:
54+
self.envars[name] = os.environ.pop(name, None)
5055

51-
def clear(self) -> None:
52-
for n in self.envars:
53-
os.environ.pop(n)
56+
def reset(self) -> None:
57+
for name, value in self.envars.items():
58+
if value is None:
59+
os.environ.pop(name, None)
60+
else:
61+
os.environ[name] = value
5462

5563

5664
@pytest.fixture
57-
def env():
65+
def env() -> Iterator[TestEnv]:
5866
test_env = TestEnv()
5967

6068
yield test_env
6169

62-
test_env.clear()
70+
test_env.reset()
6371

6472

6573
@pytest.fixture
6674
def anyio_backend():
6775
return 'asyncio'
6876

6977

78+
@pytest.fixture
79+
def allow_model_requests():
80+
with pydantic_ai.models.override_allow_model_requests(True):
81+
yield
82+
83+
7084
@pytest.fixture
7185
async def client_with_handler():
7286
client: httpx.AsyncClient | None = None

0 commit comments

Comments
 (0)