
Commit ecafb25

History processor replaces message history (#2324)

Authored by AlexEnrique and Kludex
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent bc9a2fd · commit ecafb25

File tree

3 files changed: 90 additions, 6 deletions


docs/message-history.md

10 additions, 0 deletions

````diff
@@ -334,6 +334,10 @@ custom processing logic.
 Pydantic AI provides a `history_processors` parameter on `Agent` that allows you to intercept and modify
 the message history before each model request.
 
+!!! warning "History processors replace the message history"
+    History processors replace the message history in the state with the processed messages, including the new user prompt part.
+    This means that if you want to keep the original message history, you need to make a copy of it.
+
 ### Usage
 
 The `history_processors` is a list of callables that take a list of
@@ -389,6 +393,9 @@ long_conversation_history: list[ModelMessage] = [] # Your long conversation his
 # result = agent.run_sync('What did we discuss?', message_history=long_conversation_history)
 ```
 
+!!! warning "Be careful when slicing the message history"
+    When slicing the message history, you need to make sure that tool calls and returns are paired, otherwise the LLM may return an error. For more details, refer to [this GitHub issue](https://github.com/pydantic/pydantic-ai/issues/2050#issuecomment-3019976269).
+
 #### `RunContext` parameter
 
 History processors can optionally accept a [`RunContext`][pydantic_ai.tools.RunContext] parameter to access
@@ -449,6 +456,9 @@ async def summarize_old_messages(messages: list[ModelMessage]) -> list[ModelMess
 agent = Agent('openai:gpt-4o', history_processors=[summarize_old_messages])
 ```
 
+!!! warning "Be careful when summarizing the message history"
+    When summarizing the message history, you need to make sure that tool calls and returns are paired, otherwise the LLM may return an error. For more details, refer to [this GitHub issue](https://github.com/pydantic/pydantic-ai/issues/2050#issuecomment-3019976269), where you can find examples of summarizing the message history.
+
 ### Testing History Processors
 
 You can test what messages are actually sent to the model provider using
````
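Since this commit makes processors overwrite the run's stored history, callers who want to keep the original list need to copy it before the run, as the new docs warning says. A minimal sketch of that pattern — the `keep_recent` processor and its window size are illustrative, not part of this change:

```python
import copy

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage


def keep_recent(messages: list[ModelMessage]) -> list[ModelMessage]:
    # Illustrative processor: keep only the five most recent messages.
    return messages[-5:]


agent = Agent('openai:gpt-4o', history_processors=[keep_recent])

original_history: list[ModelMessage] = []  # your accumulated conversation
# Deep-copy before the run so the processed history can't clobber the original.
history_for_run = copy.deepcopy(original_history)

result = agent.run_sync('What did we discuss?', message_history=history_for_run)
# result.all_messages() now reflects the *processed* history;
# original_history is untouched.
```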

pydantic_ai_slim/pydantic_ai/_agent_graph.py

6 additions, 4 deletions

```diff
@@ -365,9 +365,7 @@ async def _prepare_request(
     model_request_parameters = await _prepare_request_parameters(ctx)
     model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
 
-    message_history = await _process_message_history(
-        ctx.state.message_history, ctx.deps.history_processors, run_context
-    )
+    message_history = await _process_message_history(ctx.state, ctx.deps.history_processors, run_context)
 
     return model_settings, model_request_parameters, message_history, run_context
 
@@ -859,11 +857,12 @@ def build_agent_graph(
 
 
 async def _process_message_history(
-    messages: list[_messages.ModelMessage],
+    state: GraphAgentState,
     processors: Sequence[HistoryProcessor[DepsT]],
     run_context: RunContext[DepsT],
 ) -> list[_messages.ModelMessage]:
     """Process message history through a sequence of processors."""
+    messages = state.message_history
     for processor in processors:
         takes_ctx = is_takes_ctx(processor)
 
@@ -880,4 +879,7 @@ async def _process_message_history(
         else:
             sync_processor = cast(_HistoryProcessorSync, processor)
             messages = await run_in_executor(sync_processor, messages)
+
+    # Replaces the message history in the state with the processed messages
+    state.message_history = messages
     return messages
```
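The slicing and summarizing warnings above both reduce to keeping tool calls and returns paired, which matters all the more now that the processed list becomes the state's canonical history. A minimal sketch of a pairing-aware trim, assuming the usual message layout (tool results arrive as `ToolReturnPart`s inside a `ModelRequest`); the function name and window size are illustrative:

```python
from pydantic_ai.messages import ModelMessage, ModelRequest, ToolReturnPart


def trim_keeping_tool_pairs(messages: list[ModelMessage], max_messages: int = 10) -> list[ModelMessage]:
    """Drop old messages without splitting a tool call from its return."""
    window = messages[-max_messages:]
    # If the window opens with a request carrying a ToolReturnPart, its matching
    # ToolCallPart was trimmed away; keep dropping until the window is self-contained.
    while (
        window
        and isinstance(window[0], ModelRequest)
        and any(isinstance(part, ToolReturnPart) for part in window[0].parts)
    ):
        window = window[1:]
    return window
```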

tests/test_history_processor.py

74 additions, 2 deletions

```diff
@@ -5,7 +5,15 @@
 from inline_snapshot import snapshot
 
 from pydantic_ai import Agent
-from pydantic_ai.messages import ModelMessage, ModelRequest, ModelRequestPart, ModelResponse, TextPart, UserPromptPart
+from pydantic_ai.messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelRequestPart,
+    ModelResponse,
+    SystemPromptPart,
+    TextPart,
+    UserPromptPart,
+)
 from pydantic_ai.models.function import AgentInfo, FunctionModel
 from pydantic_ai.tools import RunContext
 from pydantic_ai.usage import Usage
@@ -70,6 +78,71 @@ def no_op_history_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
     )
 
 
+async def test_history_processor_run_replaces_message_history(function_model: FunctionModel):
+    """Test that the history processor replaces the message history in the state."""
+
+    def process_previous_answers(messages: list[ModelMessage]) -> list[ModelMessage]:
+        # Keep the last message (last question) and add a new system prompt
+        return messages[-1:] + [ModelRequest(parts=[SystemPromptPart(content='Processed answer')])]
+
+    agent = Agent(function_model, history_processors=[process_previous_answers])
+
+    message_history = [
+        ModelRequest(parts=[UserPromptPart(content='Question 1')]),
+        ModelResponse(parts=[TextPart(content='Answer 1')]),
+        ModelRequest(parts=[UserPromptPart(content='Question 2')]),
+        ModelResponse(parts=[TextPart(content='Answer 2')]),
+    ]
+
+    result = await agent.run('Question 3', message_history=message_history)
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(parts=[UserPromptPart(content='Question 3', timestamp=IsDatetime())]),
+            ModelRequest(parts=[SystemPromptPart(content='Processed answer', timestamp=IsDatetime())]),
+            ModelResponse(
+                parts=[TextPart(content='Provider response')],
+                usage=Usage(requests=1, request_tokens=54, response_tokens=2, total_tokens=56),
+                model_name='function:capture_model_function:capture_model_stream_function',
+                timestamp=IsDatetime(),
+            ),
+        ]
+    )
+
+
+async def test_history_processor_streaming_replaces_message_history(function_model: FunctionModel):
+    """Test that the history processor replaces the message history in the state."""
+
+    def process_previous_answers(messages: list[ModelMessage]) -> list[ModelMessage]:
+        # Keep the last message (last question) and add a new system prompt
+        return messages[-1:] + [ModelRequest(parts=[SystemPromptPart(content='Processed answer')])]
+
+    agent = Agent(function_model, history_processors=[process_previous_answers])
+
+    message_history = [
+        ModelRequest(parts=[UserPromptPart(content='Question 1')]),
+        ModelResponse(parts=[TextPart(content='Answer 1')]),
+        ModelRequest(parts=[UserPromptPart(content='Question 2')]),
+        ModelResponse(parts=[TextPart(content='Answer 2')]),
+    ]
+
+    async with agent.run_stream('Question 3', message_history=message_history) as result:
+        async for _ in result.stream_text():
+            pass
+
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(parts=[UserPromptPart(content='Question 3', timestamp=IsDatetime())]),
+            ModelRequest(parts=[SystemPromptPart(content='Processed answer', timestamp=IsDatetime())]),
+            ModelResponse(
+                parts=[TextPart(content='hello')],
+                usage=Usage(request_tokens=50, response_tokens=1, total_tokens=51),
+                model_name='function:capture_model_function:capture_model_stream_function',
+                timestamp=IsDatetime(),
+            ),
+        ]
+    )
+
+
 async def test_history_processor_messages_sent_to_provider(
     function_model: FunctionModel, received_messages: list[ModelMessage]
 ):
@@ -90,7 +163,6 @@ def capture_messages_processor(messages: list[ModelMessage]) -> list[ModelMessag
     assert result.all_messages() == snapshot(
         [
             ModelRequest(parts=[UserPromptPart(content='Previous question', timestamp=IsDatetime())]),
-            ModelResponse(parts=[TextPart(content='Previous answer')], timestamp=IsDatetime()),
             ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]),
             ModelResponse(
                 parts=[TextPart(content='Provider response')],
```
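Outside the test suite, the same `FunctionModel` technique these tests rely on can verify what actually reaches the provider after processing. A minimal sketch — the processor and the canned response text are illustrative:

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel

received: list[ModelMessage] = []


def capture_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # Record what the "provider" receives, then return a canned reply.
    received.extend(messages)
    return ModelResponse(parts=[TextPart(content='Provider response')])


def keep_last_message(messages: list[ModelMessage]) -> list[ModelMessage]:
    return messages[-1:]


agent = Agent(FunctionModel(capture_model), history_processors=[keep_last_message])
result = agent.run_sync('New question')

# `received` holds the processed request; with this commit,
# result.all_messages() reflects that processed history as well.
```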
