Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,12 @@ async def run( # noqa: C901
ctx.state.message_history = messages
ctx.deps.new_message_index = len(messages)

# Validate that message history starts with a user message
if messages and isinstance(messages[0], _messages.ModelResponse):
raise exceptions.UserError(
'Message history cannot start with a `ModelResponse`. Conversations must begin with a user message.'
)

if self.deferred_tool_results is not None:
return await self._handle_deferred_tool_results(self.deferred_tool_results, messages, ctx)

Expand Down
4 changes: 3 additions & 1 deletion tests/models/test_outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,7 @@ def test_input_format(transformers_multimodal_model: OutlinesModel, binary_image

# unsupported: tool calls
tool_call_message_history: list[ModelMessage] = [
ModelRequest(parts=[UserPromptPart(content='some user prompt')]),
ModelResponse(parts=[ToolCallPart(tool_call_id='1', tool_name='get_location')]),
ModelRequest(parts=[ToolReturnPart(tool_name='get_location', content='London', tool_call_id='1')]),
]
Expand All @@ -588,7 +589,8 @@ def test_input_format(transformers_multimodal_model: OutlinesModel, binary_image

# unsupported: non-image file parts
file_part_message_history: list[ModelMessage] = [
ModelResponse(parts=[FilePart(content=BinaryContent(data=b'test', media_type='text/plain'))])
ModelRequest(parts=[UserPromptPart(content='some user prompt')]),
ModelResponse(parts=[FilePart(content=BinaryContent(data=b'test', media_type='text/plain'))]),
]
with pytest.raises(
UserError, match='File parts other than `BinaryImage` are not supported for Outlines models yet.'
Expand Down
105 changes: 105 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6125,3 +6125,108 @@ def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
]
)
assert run.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",')


def test_message_history_cannot_start_with_model_response():
    """A history whose first item is a `ModelResponse` is rejected with `UserError`."""

    def never_called(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        # The run should fail validation before the model is ever invoked.
        return ModelResponse(parts=[TextPart(content='Final response')])  # pragma: no cover

    agent = Agent(FunctionModel(never_called))

    bad_history: list[ModelMessage] = [ModelResponse(parts=[TextPart(content='ai response')])]

    with pytest.raises(UserError, match='Message history cannot start with a `ModelResponse`.'):
        agent.run_sync('hello', message_history=bad_history)


async def test_message_history_starts_with_model_request():
    """A valid history beginning with a `ModelRequest` is accepted and preserved."""

    def model_func(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        return ModelResponse(parts=[TextPart('ok here is text')])

    agent = Agent(FunctionModel(model_func))

    history: list[ModelMessage] = [
        ModelRequest(parts=[UserPromptPart(content='Hello')]),
        ModelResponse(parts=[TextPart(content='Hi there!')]),
    ]

    # No error expected: the conversation opens with a user message.
    async with agent.iter('How are you?', message_history=history) as run:
        async for _ in run:
            pass

    # The two history entries plus the new request/response exchange.
    messages = run.all_messages()
    assert len(messages) >= 3
    assert isinstance(messages[0], ModelRequest)


async def test_empty_message_history_is_valid():
    """An empty `message_history` list must not trigger the start-of-history validation."""

    def model_func(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        return ModelResponse(parts=[TextPart('response text')])

    agent = Agent(FunctionModel(model_func))

    # An empty list has no first message to validate, so the run proceeds normally.
    async with agent.iter('hello', message_history=[]) as run:
        async for _ in run:
            pass

    messages = run.all_messages()
    # At minimum the new request and the model's response are recorded.
    assert len(messages) >= 2
    assert isinstance(messages[0], ModelRequest)


async def test_message_history_with_multiple_messages():
    """A multi-turn history that starts with a `ModelRequest` runs without error."""

    def model_func(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        return ModelResponse(parts=[TextPart('final response')])

    agent = Agent(FunctionModel(model_func))

    history: list[ModelMessage] = [
        ModelRequest(parts=[UserPromptPart(content='First')]),
        ModelResponse(parts=[TextPart(content='Response 1')]),
        ModelRequest(parts=[UserPromptPart(content='Second')]),
        ModelResponse(parts=[TextPart(content='Response 2')]),
    ]

    async with agent.iter('Third message', message_history=history) as run:
        async for _ in run:
            pass

    messages = run.all_messages()
    # All four history entries survive, followed by the new exchange.
    assert len(messages) >= 5
    assert isinstance(messages[0], ModelRequest)
    assert isinstance(messages[-1], ModelResponse)


def test_validation_happens_after_cleaning():
    """Validation still rejects a leading `ModelResponse` even when consecutive responses would be merged by message cleaning."""

    def never_called(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
        # Validation must fail before the model is ever invoked.
        return ModelResponse(parts=[TextPart(content='Final response')])  # pragma: no cover

    agent = Agent(FunctionModel(never_called))

    # Two consecutive responses may get merged during cleaning; the merged
    # result still starts with a ModelResponse and must be rejected.
    bad_history: list[ModelMessage] = [
        ModelResponse(parts=[TextPart(content='response 1')]),
        ModelResponse(parts=[TextPart(content='response 2')]),
    ]

    with pytest.raises(UserError, match='Message history cannot start with a `ModelResponse`.'):
        agent.run_sync('hello', message_history=bad_history)