Skip to content

Commit 7740f09

Browse files
authored
Support different content inputs in TestModel (#1015)
1 parent 28cc2af commit 7740f09

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,16 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
282282
return 0
283283
if isinstance(content, str):
284284
return len(re.split(r'[\s",.:]+', content.strip()))
285-
# TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
286285
else: # pragma: no cover
287-
assert isinstance(content, (AudioUrl, ImageUrl, BinaryContent))
288-
return 0
286+
tokens = 0
287+
for part in content:
288+
if isinstance(part, str):
289+
tokens += len(re.split(r'[\s",.:]+', part.strip()))
290+
# TODO(Marcelo): We need to study how we can estimate the tokens for these types of content.
291+
if isinstance(part, (AudioUrl, ImageUrl)):
292+
tokens += 0
293+
elif isinstance(part, BinaryContent):
294+
tokens += len(part.data)
295+
else:
296+
tokens += 0
297+
return tokens

tests/models/test_model_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
from pydantic_ai import Agent, ModelRetry, RunContext
1414
from pydantic_ai.exceptions import UnexpectedModelBehavior
1515
from pydantic_ai.messages import (
16+
AudioUrl,
17+
BinaryContent,
18+
ImageUrl,
1619
ModelRequest,
1720
ModelResponse,
1821
RetryPromptPart,
@@ -22,6 +25,7 @@
2225
UserPromptPart,
2326
)
2427
from pydantic_ai.models.test import TestModel, _chars, _JsonSchemaTestData # pyright: ignore[reportPrivateUsage]
28+
from pydantic_ai.usage import Usage
2529

2630
from ..conftest import IsNow
2731

@@ -271,3 +275,18 @@ def test_max_items():
271275
}
272276
data = _JsonSchemaTestData(json_schema).generate()
273277
assert data == snapshot([])
278+
279+
280+
@pytest.mark.parametrize(
281+
'content',
282+
[
283+
AudioUrl(url='https://example.com'),
284+
ImageUrl(url='https://example.com'),
285+
BinaryContent(data=b'', media_type='image/png'),
286+
],
287+
)
288+
def test_different_content_input(content: AudioUrl | ImageUrl | BinaryContent):
289+
agent = Agent()
290+
result = agent.run_sync('x', model=TestModel(custom_result_text='custom'))
291+
assert result.data == snapshot('custom')
292+
assert result.usage() == snapshot(Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52))

0 commit comments

Comments
 (0)