Skip to content

Commit 56b07ae

Browse files
HamzaFarhanKludex
andauthored
Allow multiple instructions and fix instruction concatenation in Agent (#1591)
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent b1284b6 commit 56b07ae

File tree

2 files changed

+42
-10
lines changed

2 files changed

+42
-10
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,10 @@ def __init__(
152152
model: models.Model | models.KnownModelName | str | None = None,
153153
*,
154154
output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str,
155-
instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
155+
instructions: str
156+
| _system_prompt.SystemPromptFunc[AgentDepsT]
157+
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
158+
| None = None,
156159
system_prompt: str | Sequence[str] = (),
157160
deps_type: type[AgentDepsT] = NoneType,
158161
name: str | None = None,
@@ -175,7 +178,10 @@ def __init__(
175178
model: models.Model | models.KnownModelName | str | None = None,
176179
*,
177180
result_type: type[OutputDataT] = str,
178-
instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
181+
instructions: str
182+
| _system_prompt.SystemPromptFunc[AgentDepsT]
183+
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
184+
| None = None,
179185
system_prompt: str | Sequence[str] = (),
180186
deps_type: type[AgentDepsT] = NoneType,
181187
name: str | None = None,
@@ -197,7 +203,10 @@ def __init__(
197203
*,
198204
# TODO change this back to `output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str,` when we remove the overloads
199205
output_type: Any = str,
200-
instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
206+
instructions: str
207+
| _system_prompt.SystemPromptFunc[AgentDepsT]
208+
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
209+
| None = None,
201210
system_prompt: str | Sequence[str] = (),
202211
deps_type: type[AgentDepsT] = NoneType,
203212
name: str | None = None,
@@ -296,10 +305,16 @@ def __init__(
296305
)
297306
self._output_validators = []
298307

299-
self._instructions_functions = (
300-
[_system_prompt.SystemPromptRunner(instructions)] if callable(instructions) else []
301-
)
302-
self._instructions = instructions if isinstance(instructions, str) else None
308+
self._instructions = ''
309+
self._instructions_functions = []
310+
if isinstance(instructions, (str, Callable)):
311+
instructions = [instructions]
312+
for instruction in instructions or []:
313+
if isinstance(instruction, str):
314+
self._instructions += instruction + '\n'
315+
else:
316+
self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction))
317+
self._instructions = self._instructions.strip() or None
303318

304319
self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
305320
self._system_prompt_functions = []
@@ -625,8 +640,8 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
625640

626641
instructions = self._instructions or ''
627642
for instructions_runner in self._instructions_functions:
628-
instructions += await instructions_runner.run(run_context)
629-
return instructions
643+
instructions += '\n' + await instructions_runner.run(run_context)
644+
return instructions.strip()
630645

631646
graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
632647
user_deps=deps,

tests/test_agent.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from pydantic_ai.result import Usage
3333
from pydantic_ai.tools import ToolDefinition
3434

35-
from .conftest import IsNow, IsStr, TestEnv
35+
from .conftest import IsDatetime, IsNow, IsStr, TestEnv
3636

3737
pytestmark = pytest.mark.anyio
3838

@@ -1816,6 +1816,23 @@ def test_instructions_with_message_history():
18161816
)
18171817

18181818

1819+
def test_instructions_parameter_with_sequence():
1820+
def instructions() -> str:
1821+
return 'You are a potato.'
1822+
1823+
agent = Agent('test', instructions=('You are a helpful assistant.', instructions))
1824+
result = agent.run_sync('Hello')
1825+
assert result.all_messages()[0] == snapshot(
1826+
ModelRequest(
1827+
parts=[UserPromptPart(content='Hello', timestamp=IsDatetime())],
1828+
instructions="""\
1829+
You are a helpful assistant.
1830+
You are a potato.\
1831+
""",
1832+
)
1833+
)
1834+
1835+
18191836
def test_empty_final_response():
18201837
def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
18211838
if len(messages) == 1:

0 commit comments

Comments
 (0)