Skip to content

Commit ae82fa4

Browse files
joseadsamuelcolvin
andauthored
Adds dynamic to system_prompt decorator, allowing reevaluation (#560)
Co-authored-by: Samuel Colvin <[email protected]>
1 parent 421ed97 commit ae82fa4

File tree

6 files changed

+233
-17
lines changed

6 files changed

+233
-17
lines changed

docs/message-history.md

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ print(result.all_messages())
4343
ModelRequest(
4444
parts=[
4545
SystemPromptPart(
46-
content='Be a helpful assistant.', part_kind='system-prompt'
46+
content='Be a helpful assistant.',
47+
dynamic_ref=None,
48+
part_kind='system-prompt',
4749
),
4850
UserPromptPart(
4951
content='Tell me a joke.',
@@ -85,7 +87,9 @@ async def main():
8587
ModelRequest(
8688
parts=[
8789
SystemPromptPart(
88-
content='Be a helpful assistant.', part_kind='system-prompt'
90+
content='Be a helpful assistant.',
91+
dynamic_ref=None,
92+
part_kind='system-prompt',
8993
),
9094
UserPromptPart(
9195
content='Tell me a joke.',
@@ -112,7 +116,9 @@ async def main():
112116
ModelRequest(
113117
parts=[
114118
SystemPromptPart(
115-
content='Be a helpful assistant.', part_kind='system-prompt'
119+
content='Be a helpful assistant.',
120+
dynamic_ref=None,
121+
part_kind='system-prompt',
116122
),
117123
UserPromptPart(
118124
content='Tell me a joke.',
@@ -166,7 +172,9 @@ print(result2.all_messages())
166172
ModelRequest(
167173
parts=[
168174
SystemPromptPart(
169-
content='Be a helpful assistant.', part_kind='system-prompt'
175+
content='Be a helpful assistant.',
176+
dynamic_ref=None,
177+
part_kind='system-prompt',
170178
),
171179
UserPromptPart(
172180
content='Tell me a joke.',
@@ -238,7 +246,9 @@ print(result2.all_messages())
238246
ModelRequest(
239247
parts=[
240248
SystemPromptPart(
241-
content='Be a helpful assistant.', part_kind='system-prompt'
249+
content='Be a helpful assistant.',
250+
dynamic_ref=None,
251+
part_kind='system-prompt',
242252
),
243253
UserPromptPart(
244254
content='Tell me a joke.',

docs/tools.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ print(dice_result.all_messages())
7272
parts=[
7373
SystemPromptPart(
7474
content="You're a dice game, you should roll the die and see if the number you get back matches the user's guess. If so, tell them they're a winner. Use the player's name in the response.",
75+
dynamic_ref=None,
7576
part_kind='system-prompt',
7677
),
7778
UserPromptPart(

pydantic_ai_slim/pydantic_ai/_system_prompt.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
@dataclass
1313
class SystemPromptRunner(Generic[AgentDeps]):
1414
function: SystemPromptFunc[AgentDeps]
15+
dynamic: bool = False
1516
_takes_ctx: bool = field(init=False)
1617
_is_async: bool = field(init=False)
1718

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ class Agent(Generic[AgentDeps, ResultData]):
107107
_function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
108108
_default_retries: int = dataclasses.field(repr=False)
109109
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
110+
_system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(
111+
repr=False
112+
)
110113
_deps_type: type[AgentDeps] = dataclasses.field(repr=False)
111114
_max_result_retries: int = dataclasses.field(repr=False)
112115
_override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
@@ -182,6 +185,7 @@ def __init__(
182185
self._register_tool(Tool(tool))
183186
self._deps_type = deps_type
184187
self._system_prompt_functions = []
188+
self._system_prompt_dynamic_functions = {}
185189
self._max_result_retries = result_retries if result_retries is not None else retries
186190
self._result_validators = []
187191

@@ -535,17 +539,37 @@ def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
535539
@overload
536540
def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...
537541

542+
@overload
543+
def system_prompt(
544+
self, /, *, dynamic: bool = False
545+
) -> Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]: ...
546+
538547
def system_prompt(
539-
self, func: _system_prompt.SystemPromptFunc[AgentDeps], /
540-
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
548+
self,
549+
func: _system_prompt.SystemPromptFunc[AgentDeps] | None = None,
550+
/,
551+
*,
552+
dynamic: bool = False,
553+
) -> (
554+
Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]
555+
| _system_prompt.SystemPromptFunc[AgentDeps]
556+
):
541557
"""Decorator to register a system prompt function.
542558
543559
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
544560
Can decorate a sync or async functions.
545561
562+
The decorator can be used either bare (`agent.system_prompt`) or as a function call
563+
(`agent.system_prompt(...)`), see the examples below.
564+
546565
Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
547566
the type of the function, see `tests/typed_agent.py` for tests.
548567
568+
Args:
569+
func: The function to decorate
570+
dynamic: If True, the system prompt will be reevaluated even when `messages_history` is provided,
571+
see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref]
572+
549573
Example:
550574
```python
551575
from pydantic_ai import Agent, RunContext
@@ -556,17 +580,27 @@ def system_prompt(
556580
def simple_system_prompt() -> str:
557581
return 'foobar'
558582
559-
@agent.system_prompt
583+
@agent.system_prompt(dynamic=True)
560584
async def async_system_prompt(ctx: RunContext[str]) -> str:
561585
return f'{ctx.deps} is the best'
562-
563-
result = agent.run_sync('foobar', deps='spam')
564-
print(result.data)
565-
#> success (no tool calls)
566586
```
567587
"""
568-
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
569-
return func
588+
if func is None:
589+
590+
def decorator(
591+
func_: _system_prompt.SystemPromptFunc[AgentDeps],
592+
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
593+
runner = _system_prompt.SystemPromptRunner(func_, dynamic=dynamic)
594+
self._system_prompt_functions.append(runner)
595+
if dynamic:
596+
self._system_prompt_dynamic_functions[func_.__qualname__] = runner
597+
return func_
598+
599+
return decorator
600+
else:
601+
assert not dynamic, "dynamic can't be True in this case"
602+
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
603+
return func
570604

571605
@overload
572606
def result_validator(
@@ -835,6 +869,23 @@ async def add_tool(tool: Tool[AgentDeps]) -> None:
835869
result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
836870
)
837871

872+
async def _reevaluate_dynamic_prompts(
873+
self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDeps]
874+
) -> None:
875+
"""Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
876+
# Only proceed if there's at least one dynamic runner.
877+
if self._system_prompt_dynamic_functions:
878+
for msg in messages:
879+
if isinstance(msg, _messages.ModelRequest):
880+
for i, part in enumerate(msg.parts):
881+
if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
882+
# Look up the runner by its ref
883+
if runner := self._system_prompt_dynamic_functions.get(part.dynamic_ref):
884+
updated_part_content = await runner.run(run_context)
885+
msg.parts[i] = _messages.SystemPromptPart(
886+
updated_part_content, dynamic_ref=part.dynamic_ref
887+
)
888+
838889
async def _prepare_messages(
839890
self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
840891
) -> list[_messages.ModelMessage]:
@@ -850,8 +901,10 @@ async def _prepare_messages(
850901
ctx_messages.used = True
851902

852903
if message_history:
853-
# shallow copy messages
904+
# Shallow copy messages
854905
messages.extend(message_history)
906+
# Reevaluate any dynamic system prompt parts
907+
await self._reevaluate_dynamic_prompts(messages, run_context)
855908
messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
856909
else:
857910
parts = await self._sys_parts(run_context)
@@ -1088,7 +1141,10 @@ async def _sys_parts(self, run_context: RunContext[AgentDeps]) -> list[_messages
10881141
messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
10891142
for sys_prompt_runner in self._system_prompt_functions:
10901143
prompt = await sys_prompt_runner.run(run_context)
1091-
messages.append(_messages.SystemPromptPart(prompt))
1144+
if sys_prompt_runner.dynamic:
1145+
messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
1146+
else:
1147+
messages.append(_messages.SystemPromptPart(prompt))
10921148
return messages
10931149

10941150
def _unknown_tool(self, tool_name: str, run_context: RunContext[AgentDeps]) -> _messages.RetryPromptPart:

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ class SystemPromptPart:
2121
content: str
2222
"""The content of the prompt."""
2323

24+
dynamic_ref: str | None = None
25+
"""The ref of the dynamic system prompt function that generated this part.
26+
27+
Only set if system prompt is dynamic, see [`system_prompt`][pydantic_ai.Agent.system_prompt] for more information.
28+
"""
29+
2430
part_kind: Literal['system-prompt'] = 'system-prompt'
2531
"""Part type identifier, this is available on all parts as a discriminator."""
2632

tests/test_agent.py

Lines changed: 143 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1260,7 +1260,6 @@ def test_double_capture_run_messages(set_event_loop: None) -> None:
12601260
assert result.data == 'success (no tool calls)'
12611261
result2 = agent.run_sync('Hello 2')
12621262
assert result2.data == 'success (no tool calls)'
1263-
12641263
assert messages == snapshot(
12651264
[
12661265
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
@@ -1269,6 +1268,149 @@ def test_double_capture_run_messages(set_event_loop: None) -> None:
12691268
)
12701269

12711270

1271+
def test_dynamic_false_no_reevaluate(set_event_loop: None):
1272+
"""When dynamic is false (default), the system prompt is not reevaluated
1273+
i.e: SystemPromptPart(
1274+
content="A", <--- Remains the same when `message_history` is passed.
1275+
part_kind='system-prompt')
1276+
"""
1277+
agent = Agent('test', system_prompt='Foobar')
1278+
1279+
dynamic_value = 'A'
1280+
1281+
@agent.system_prompt
1282+
async def func() -> str:
1283+
return dynamic_value
1284+
1285+
res = agent.run_sync('Hello')
1286+
1287+
assert res.all_messages() == snapshot(
1288+
[
1289+
ModelRequest(
1290+
parts=[
1291+
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
1292+
SystemPromptPart(content=dynamic_value, part_kind='system-prompt'),
1293+
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
1294+
],
1295+
kind='request',
1296+
),
1297+
ModelResponse(
1298+
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
1299+
timestamp=IsNow(tz=timezone.utc),
1300+
kind='response',
1301+
),
1302+
]
1303+
)
1304+
1305+
dynamic_value = 'B'
1306+
1307+
res_two = agent.run_sync('World', message_history=res.all_messages())
1308+
1309+
assert res_two.all_messages() == snapshot(
1310+
[
1311+
ModelRequest(
1312+
parts=[
1313+
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
1314+
SystemPromptPart(
1315+
content='A', # Remains the same
1316+
part_kind='system-prompt',
1317+
),
1318+
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
1319+
],
1320+
kind='request',
1321+
),
1322+
ModelResponse(
1323+
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
1324+
timestamp=IsNow(tz=timezone.utc),
1325+
kind='response',
1326+
),
1327+
ModelRequest(
1328+
parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
1329+
kind='request',
1330+
),
1331+
ModelResponse(
1332+
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
1333+
timestamp=IsNow(tz=timezone.utc),
1334+
kind='response',
1335+
),
1336+
]
1337+
)
1338+
1339+
1340+
def test_dynamic_true_reevaluate_system_prompt(set_event_loop: None):
1341+
"""When dynamic is true, the system prompt is reevaluated
1342+
i.e: SystemPromptPart(
1343+
content="B", <--- Updated value
1344+
part_kind='system-prompt')
1345+
"""
1346+
agent = Agent('test', system_prompt='Foobar')
1347+
1348+
dynamic_value = 'A'
1349+
1350+
@agent.system_prompt(dynamic=True)
1351+
async def func():
1352+
return dynamic_value
1353+
1354+
res = agent.run_sync('Hello')
1355+
1356+
assert res.all_messages() == snapshot(
1357+
[
1358+
ModelRequest(
1359+
parts=[
1360+
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
1361+
SystemPromptPart(
1362+
content=dynamic_value,
1363+
part_kind='system-prompt',
1364+
dynamic_ref=func.__qualname__,
1365+
),
1366+
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
1367+
],
1368+
kind='request',
1369+
),
1370+
ModelResponse(
1371+
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
1372+
timestamp=IsNow(tz=timezone.utc),
1373+
kind='response',
1374+
),
1375+
]
1376+
)
1377+
1378+
dynamic_value = 'B'
1379+
1380+
res_two = agent.run_sync('World', message_history=res.all_messages())
1381+
1382+
assert res_two.all_messages() == snapshot(
1383+
[
1384+
ModelRequest(
1385+
parts=[
1386+
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
1387+
SystemPromptPart(
1388+
content='B',
1389+
part_kind='system-prompt',
1390+
dynamic_ref=func.__qualname__,
1391+
),
1392+
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
1393+
],
1394+
kind='request',
1395+
),
1396+
ModelResponse(
1397+
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
1398+
timestamp=IsNow(tz=timezone.utc),
1399+
kind='response',
1400+
),
1401+
ModelRequest(
1402+
parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
1403+
kind='request',
1404+
),
1405+
ModelResponse(
1406+
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
1407+
timestamp=IsNow(tz=timezone.utc),
1408+
kind='response',
1409+
),
1410+
]
1411+
)
1412+
1413+
12721414
def test_capture_run_messages_tool_agent(set_event_loop: None) -> None:
12731415
agent_outer = Agent('test')
12741416
agent_inner = Agent(TestModel(custom_result_text='inner agent result'))

0 commit comments

Comments
 (0)