Skip to content

Commit b65db2a

Browse files
committed
Merge main
2 parents d38464e + 322e092 commit b65db2a

File tree

10 files changed, with 216 additions and 13 deletions.

docs/evals.md

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -653,10 +653,12 @@ async def main():
653653
print(output_file.read_text())
654654
"""
655655
# yaml-language-server: $schema=questions_cases_schema.json
656+
name: null
656657
cases:
657658
- name: Easy Capital Question
658659
inputs:
659660
question: What is the capital of France?
661+
context: null
660662
metadata:
661663
difficulty: easy
662664
category: Geography
@@ -668,6 +670,7 @@ async def main():
668670
- name: Challenging Landmark Question
669671
inputs:
670672
question: Which world-famous landmark is located on the banks of the Seine River?
673+
context: null
671674
metadata:
672675
difficulty: hard
673676
category: Landmarks
@@ -676,6 +679,7 @@ async def main():
676679
confidence: 0.9
677680
evaluators:
678681
- EqualsExpected
682+
evaluators: []
679683
"""
680684
```
681685

@@ -713,11 +717,13 @@ async def main():
713717
"""
714718
{
715719
"$schema": "questions_cases_schema.json",
720+
"name": null,
716721
"cases": [
717722
{
718723
"name": "Easy Capital Question",
719724
"inputs": {
720-
"question": "What is the capital of France?"
725+
"question": "What is the capital of France?",
726+
"context": null
721727
},
722728
"metadata": {
723729
"difficulty": "easy",
@@ -734,7 +740,8 @@ async def main():
734740
{
735741
"name": "Challenging Landmark Question",
736742
"inputs": {
737-
"question": "Which world-famous landmark is located on the banks of the Seine River?"
743+
"question": "Which world-famous landmark is located on the banks of the Seine River?",
744+
"context": null
738745
},
739746
"metadata": {
740747
"difficulty": "hard",
@@ -748,7 +755,8 @@ async def main():
748755
"EqualsExpected"
749756
]
750757
}
751-
]
758+
],
759+
"evaluators": []
752760
}
753761
"""
754762
```

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -459,15 +459,13 @@ async def _prepare_request(
459459

460460
original_history = ctx.state.message_history[:]
461461
message_history = await _process_message_history(original_history, ctx.deps.history_processors, run_context)
462-
# Never merge the new `ModelRequest` with the one preceding it, to keep `new_messages()` from accidentally including part of the existing message history
463-
message_history = [*_clean_message_history(message_history[:-1]), message_history[-1]]
464462
# `ctx.state.message_history` is the same list used by `capture_run_messages`, so we should replace its contents, not the reference
465463
ctx.state.message_history[:] = message_history
466464
# Update the new message index to ensure `result.new_messages()` returns the correct messages
467465
ctx.deps.new_message_index -= len(original_history) - len(message_history)
468466

469-
# Do one more cleaning pass to merge possible consecutive trailing `ModelRequest`s into one, with tool call parts before user parts,
470-
# but don't store it in the message history on state.
467+
# Merge possible consecutive trailing `ModelRequest`s into one, with tool call parts before user parts,
468+
# but don't store it in the message history on state. This is just for the benefit of model classes that want clear user/assistant boundaries.
471469
# See `tests/test_tools.py::test_parallel_tool_return_with_deferred` for an example where this is necessary
472470
message_history = _clean_message_history(message_history)
473471

pydantic_ai_slim/pydantic_ai/format_prompt.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from collections.abc import Iterable, Iterator, Mapping
44
from dataclasses import asdict, dataclass, field, fields, is_dataclass
55
from datetime import date
6+
from enum import Enum
67
from typing import Any, Literal
78
from xml.etree import ElementTree
89

@@ -26,8 +27,8 @@ def format_as_xml(
2627
This is useful since LLMs often find it easier to read semi-structured data (e.g. examples) as XML,
2728
rather than JSON etc.
2829
29-
Supports: `str`, `bytes`, `bytearray`, `bool`, `int`, `float`, `date`, `datetime`, `Mapping`,
30-
`Iterable`, `dataclass`, and `BaseModel`.
30+
Supports: `str`, `bytes`, `bytearray`, `bool`, `int`, `float`, `date`, `datetime`, `Enum`,
31+
`Mapping`, `Iterable`, `dataclass`, and `BaseModel`.
3132
3233
Args:
3334
obj: Python Object to serialize to XML.
@@ -101,7 +102,7 @@ def _to_xml(self, value: Any, path: str, tag: str | None = None) -> ElementTree.
101102
element.text = value
102103
elif isinstance(value, bytes | bytearray):
103104
element.text = value.decode(errors='ignore')
104-
elif isinstance(value, bool | int | float):
105+
elif isinstance(value, bool | int | float | Enum):
105106
element.text = str(value)
106107
elif isinstance(value, date):
107108
element.text = value.isoformat()

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -141,12 +141,20 @@
141141
'google-gla:gemini-2.0-flash',
142142
'google-gla:gemini-2.0-flash-lite',
143143
'google-gla:gemini-2.5-flash',
144+
'google-gla:gemini-2.5-flash-preview-09-2025',
145+
'google-gla:gemini-flash-latest',
144146
'google-gla:gemini-2.5-flash-lite',
147+
'google-gla:gemini-2.5-flash-lite-preview-09-2025',
148+
'google-gla:gemini-flash-lite-latest',
145149
'google-gla:gemini-2.5-pro',
146150
'google-vertex:gemini-2.0-flash',
147151
'google-vertex:gemini-2.0-flash-lite',
148152
'google-vertex:gemini-2.5-flash',
153+
'google-vertex:gemini-2.5-flash-preview-09-2025',
154+
'google-vertex:gemini-flash-latest',
149155
'google-vertex:gemini-2.5-flash-lite',
156+
'google-vertex:gemini-2.5-flash-lite-preview-09-2025',
157+
'google-vertex:gemini-flash-lite-latest',
150158
'google-vertex:gemini-2.5-pro',
151159
'grok:grok-4',
152160
'grok:grok-4-0709',

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,11 @@
4646
'gemini-2.0-flash',
4747
'gemini-2.0-flash-lite',
4848
'gemini-2.5-flash',
49+
'gemini-2.5-flash-preview-09-2025',
4950
'gemini-2.5-flash-lite',
51+
'gemini-2.5-flash-lite-preview-09-2025',
52+
'gemini-flash-latest',
53+
'gemini-flash-lite-latest',
5054
'gemini-2.5-pro',
5155
]
5256
"""Latest Gemini models."""

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,11 @@
9595
'gemini-2.0-flash',
9696
'gemini-2.0-flash-lite',
9797
'gemini-2.5-flash',
98+
'gemini-2.5-flash-preview-09-2025',
99+
'gemini-flash-latest',
98100
'gemini-2.5-flash-lite',
101+
'gemini-2.5-flash-lite-preview-09-2025',
102+
'gemini-flash-lite-latest',
99103
'gemini-2.5-pro',
100104
]
101105
"""Latest Gemini models."""

pydantic_evals/pydantic_evals/dataset.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -646,15 +646,15 @@ def to_file(
646646

647647
context: dict[str, Any] = {'use_short_form': True}
648648
if fmt == 'yaml':
649-
dumped_data = self.model_dump(mode='json', by_alias=True, exclude_defaults=True, context=context)
649+
dumped_data = self.model_dump(mode='json', by_alias=True, context=context)
650650
content = yaml.dump(dumped_data, sort_keys=False)
651651
if schema_ref: # pragma: no branch
652652
yaml_language_server_line = f'{_YAML_SCHEMA_LINE_PREFIX}{schema_ref}'
653653
content = f'{yaml_language_server_line}\n{content}'
654654
path.write_text(content)
655655
else:
656656
context['$schema'] = schema_ref
657-
json_data = self.model_dump_json(indent=2, by_alias=True, exclude_defaults=True, context=context)
657+
json_data = self.model_dump_json(indent=2, by_alias=True, context=context)
658658
path.write_text(json_data + '\n')
659659

660660
@classmethod
@@ -724,6 +724,7 @@ class Case(BaseModel, extra='forbid'): # pyright: ignore[reportUnusedClass] #
724724
evaluators: list[Union[tuple(evaluator_schema_types)]] = [] # pyright: ignore # noqa UP007
725725

726726
class Dataset(BaseModel, extra='forbid'):
727+
name: str | None = None
727728
cases: list[Case]
728729
if evaluator_schema_types: # pragma: no branch
729730
evaluators: list[Union[tuple(evaluator_schema_types)]] = [] # pyright: ignore # noqa UP007

tests/evals/test_dataset.py

Lines changed: 33 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import sys
55
from dataclasses import dataclass, field
66
from pathlib import Path
7-
from typing import Any, cast
7+
from typing import Any, Literal, cast
88

99
import pytest
1010
import yaml
@@ -864,6 +864,38 @@ async def test_serialization_to_json(example_dataset: Dataset[TaskInput, TaskOut
864864
assert (tmp_path / schema).exists()
865865

866866

867+
def test_serializing_parts_with_discriminators(tmp_path: Path):
868+
class Foo(BaseModel):
869+
foo: str
870+
kind: Literal['foo'] = 'foo'
871+
872+
class Bar(BaseModel):
873+
bar: str
874+
kind: Literal['bar'] = 'bar'
875+
876+
items = [Foo(foo='foo'), Bar(bar='bar')]
877+
878+
dataset = Dataset[list[Foo | Bar]](cases=[Case(inputs=items)])
879+
yaml_path = tmp_path / 'test_cases.yaml'
880+
dataset.to_file(yaml_path)
881+
882+
loaded_dataset = Dataset[list[Foo | Bar]].from_file(yaml_path)
883+
assert loaded_dataset == snapshot(
884+
Dataset(
885+
name='test_cases',
886+
cases=[
887+
Case(
888+
name=None,
889+
inputs=[
890+
Foo(foo='foo'),
891+
Bar(bar='bar'),
892+
],
893+
)
894+
],
895+
)
896+
)
897+
898+
867899
def test_serialization_errors(tmp_path: Path):
868900
with pytest.raises(ValueError) as exc_info:
869901
Dataset[TaskInput, TaskOutput, TaskMetadata].from_file(tmp_path / 'test_cases.abc')

tests/test_agent.py

Lines changed: 134 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5377,3 +5377,137 @@ def dynamic_instr() -> str:
53775377
sys_texts = [p.content for p in req.parts if isinstance(p, SystemPromptPart)]
53785378
# The dynamic system prompt should still be present since overrides target instructions only
53795379
assert dynamic_value in sys_texts
5380+
5381+
5382+
def test_continue_conversation_that_ended_in_output_tool_call(allow_model_requests: None):
5383+
def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
5384+
if any(isinstance(p, ToolReturnPart) and p.tool_name == 'roll_dice' for p in messages[-1].parts):
5385+
return ModelResponse(
5386+
parts=[
5387+
ToolCallPart(
5388+
tool_name='final_result',
5389+
args={'dice_roll': 4},
5390+
tool_call_id='pyd_ai_tool_call_id__final_result',
5391+
)
5392+
]
5393+
)
5394+
return ModelResponse(
5395+
parts=[ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice')]
5396+
)
5397+
5398+
class Result(BaseModel):
5399+
dice_roll: int
5400+
5401+
agent = Agent(FunctionModel(llm), output_type=Result)
5402+
5403+
@agent.tool_plain
5404+
def roll_dice() -> int:
5405+
return 4
5406+
5407+
result = agent.run_sync('Roll me a dice.')
5408+
messages = result.all_messages()
5409+
assert messages == snapshot(
5410+
[
5411+
ModelRequest(
5412+
parts=[
5413+
UserPromptPart(
5414+
content='Roll me a dice.',
5415+
timestamp=IsDatetime(),
5416+
)
5417+
]
5418+
),
5419+
ModelResponse(
5420+
parts=[ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice')],
5421+
usage=RequestUsage(input_tokens=55, output_tokens=2),
5422+
model_name='function:llm:',
5423+
timestamp=IsDatetime(),
5424+
),
5425+
ModelRequest(
5426+
parts=[
5427+
ToolReturnPart(
5428+
tool_name='roll_dice',
5429+
content=4,
5430+
tool_call_id='pyd_ai_tool_call_id__roll_dice',
5431+
timestamp=IsDatetime(),
5432+
)
5433+
]
5434+
),
5435+
ModelResponse(
5436+
parts=[
5437+
ToolCallPart(
5438+
tool_name='final_result',
5439+
args={'dice_roll': 4},
5440+
tool_call_id='pyd_ai_tool_call_id__final_result',
5441+
)
5442+
],
5443+
usage=RequestUsage(input_tokens=56, output_tokens=6),
5444+
model_name='function:llm:',
5445+
timestamp=IsDatetime(),
5446+
),
5447+
ModelRequest(
5448+
parts=[
5449+
ToolReturnPart(
5450+
tool_name='final_result',
5451+
content='Final result processed.',
5452+
tool_call_id='pyd_ai_tool_call_id__final_result',
5453+
timestamp=IsDatetime(),
5454+
)
5455+
]
5456+
),
5457+
]
5458+
)
5459+
5460+
result = agent.run_sync('Roll me a dice again.', message_history=messages)
5461+
new_messages = result.new_messages()
5462+
assert new_messages == snapshot(
5463+
[
5464+
ModelRequest(
5465+
parts=[
5466+
UserPromptPart(
5467+
content='Roll me a dice again.',
5468+
timestamp=IsDatetime(),
5469+
)
5470+
]
5471+
),
5472+
ModelResponse(
5473+
parts=[ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice')],
5474+
usage=RequestUsage(input_tokens=66, output_tokens=8),
5475+
model_name='function:llm:',
5476+
timestamp=IsDatetime(),
5477+
),
5478+
ModelRequest(
5479+
parts=[
5480+
ToolReturnPart(
5481+
tool_name='roll_dice',
5482+
content=4,
5483+
tool_call_id='pyd_ai_tool_call_id__roll_dice',
5484+
timestamp=IsDatetime(),
5485+
)
5486+
]
5487+
),
5488+
ModelResponse(
5489+
parts=[
5490+
ToolCallPart(
5491+
tool_name='final_result',
5492+
args={'dice_roll': 4},
5493+
tool_call_id='pyd_ai_tool_call_id__final_result',
5494+
)
5495+
],
5496+
usage=RequestUsage(input_tokens=67, output_tokens=12),
5497+
model_name='function:llm:',
5498+
timestamp=IsDatetime(),
5499+
),
5500+
ModelRequest(
5501+
parts=[
5502+
ToolReturnPart(
5503+
tool_name='final_result',
5504+
content='Final result processed.',
5505+
tool_call_id='pyd_ai_tool_call_id__final_result',
5506+
timestamp=IsDatetime(),
5507+
)
5508+
]
5509+
),
5510+
]
5511+
)
5512+
5513+
assert not any(isinstance(p, ToolReturnPart) and p.tool_name == 'final_result' for p in new_messages[0].parts)

0 commit comments

Comments
 (0)