66
66
from ..toolsets .combined import CombinedToolset
67
67
from ..toolsets .function import FunctionToolset
68
68
from ..toolsets .prepared import PreparedToolset
69
- from .abstract import AbstractAgent , EventStreamHandler , RunOutputDataT
69
+ from .abstract import AbstractAgent , EventStreamHandler , Instructions , RunOutputDataT
70
70
from .wrapper import WrapperAgent
71
71
72
72
if TYPE_CHECKING :
@@ -137,8 +137,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
137
137
_deps_type : type [AgentDepsT ] = dataclasses .field (repr = False )
138
138
_output_schema : _output .BaseOutputSchema [OutputDataT ] = dataclasses .field (repr = False )
139
139
_output_validators : list [_output .OutputValidator [AgentDepsT , OutputDataT ]] = dataclasses .field (repr = False )
140
- _instructions : str | None = dataclasses .field (repr = False )
141
- _instructions_functions : list [_system_prompt .SystemPromptRunner [AgentDepsT ]] = dataclasses .field (repr = False )
140
+ _instructions : list [str | _system_prompt .SystemPromptFunc [AgentDepsT ]] = dataclasses .field (repr = False )
142
141
_system_prompts : tuple [str , ...] = dataclasses .field (repr = False )
143
142
_system_prompt_functions : list [_system_prompt .SystemPromptRunner [AgentDepsT ]] = dataclasses .field (repr = False )
144
143
_system_prompt_dynamic_functions : dict [str , _system_prompt .SystemPromptRunner [AgentDepsT ]] = dataclasses .field (
@@ -164,10 +163,7 @@ def __init__(
164
163
model : models .Model | models .KnownModelName | str | None = None ,
165
164
* ,
166
165
output_type : OutputSpec [OutputDataT ] = str ,
167
- instructions : str
168
- | _system_prompt .SystemPromptFunc [AgentDepsT ]
169
- | Sequence [str | _system_prompt .SystemPromptFunc [AgentDepsT ]]
170
- | None = None ,
166
+ instructions : Instructions [AgentDepsT ] = None ,
171
167
system_prompt : str | Sequence [str ] = (),
172
168
deps_type : type [AgentDepsT ] = NoneType ,
173
169
name : str | None = None ,
@@ -193,10 +189,7 @@ def __init__(
193
189
model : models .Model | models .KnownModelName | str | None = None ,
194
190
* ,
195
191
output_type : OutputSpec [OutputDataT ] = str ,
196
- instructions : str
197
- | _system_prompt .SystemPromptFunc [AgentDepsT ]
198
- | Sequence [str | _system_prompt .SystemPromptFunc [AgentDepsT ]]
199
- | None = None ,
192
+ instructions : Instructions [AgentDepsT ] = None ,
200
193
system_prompt : str | Sequence [str ] = (),
201
194
deps_type : type [AgentDepsT ] = NoneType ,
202
195
name : str | None = None ,
@@ -220,10 +213,7 @@ def __init__(
220
213
model : models .Model | models .KnownModelName | str | None = None ,
221
214
* ,
222
215
output_type : OutputSpec [OutputDataT ] = str ,
223
- instructions : str
224
- | _system_prompt .SystemPromptFunc [AgentDepsT ]
225
- | Sequence [str | _system_prompt .SystemPromptFunc [AgentDepsT ]]
226
- | None = None ,
216
+ instructions : Instructions [AgentDepsT ] = None ,
227
217
system_prompt : str | Sequence [str ] = (),
228
218
deps_type : type [AgentDepsT ] = NoneType ,
229
219
name : str | None = None ,
@@ -322,16 +312,7 @@ def __init__(
322
312
self ._output_schema = _output .OutputSchema [OutputDataT ].build (output_type , default_mode = default_output_mode )
323
313
self ._output_validators = []
324
314
325
- self ._instructions = ''
326
- self ._instructions_functions = []
327
- if isinstance (instructions , str | Callable ):
328
- instructions = [instructions ]
329
- for instruction in instructions or []:
330
- if isinstance (instruction , str ):
331
- self ._instructions += instruction + '\n '
332
- else :
333
- self ._instructions_functions .append (_system_prompt .SystemPromptRunner (instruction ))
334
- self ._instructions = self ._instructions .strip () or None
315
+ self ._instructions = self ._normalize_instructions (instructions )
335
316
336
317
self ._system_prompts = (system_prompt ,) if isinstance (system_prompt , str ) else tuple (system_prompt )
337
318
self ._system_prompt_functions = []
@@ -371,6 +352,9 @@ def __init__(
371
352
self ._override_tools : ContextVar [
372
353
_utils .Option [Sequence [Tool [AgentDepsT ] | ToolFuncEither [AgentDepsT , ...]]]
373
354
] = ContextVar ('_override_tools' , default = None )
355
+ self ._override_instructions : ContextVar [
356
+ _utils .Option [list [str | _system_prompt .SystemPromptFunc [AgentDepsT ]]]
357
+ ] = ContextVar ('_override_instructions' , default = None )
374
358
375
359
self ._enter_lock = Lock ()
376
360
self ._entered_count = 0
@@ -593,10 +577,12 @@ async def main():
593
577
model_settings = merge_model_settings (merged_settings , model_settings )
594
578
usage_limits = usage_limits or _usage .UsageLimits ()
595
579
580
+ instructions_literal , instructions_functions = self ._get_instructions ()
581
+
596
582
async def get_instructions (run_context : RunContext [AgentDepsT ]) -> str | None :
597
583
parts = [
598
- self . _instructions ,
599
- * [await func .run (run_context ) for func in self . _instructions_functions ],
584
+ instructions_literal ,
585
+ * [await func .run (run_context ) for func in instructions_functions ],
600
586
]
601
587
602
588
model_profile = model_used .profile
@@ -634,11 +620,12 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
634
620
get_instructions = get_instructions ,
635
621
instrumentation_settings = instrumentation_settings ,
636
622
)
623
+
637
624
start_node = _agent_graph .UserPromptNode [AgentDepsT ](
638
625
user_prompt = user_prompt ,
639
626
deferred_tool_results = deferred_tool_results ,
640
- instructions = self . _instructions ,
641
- instructions_functions = self . _instructions_functions ,
627
+ instructions = instructions_literal ,
628
+ instructions_functions = instructions_functions ,
642
629
system_prompts = self ._system_prompts ,
643
630
system_prompt_functions = self ._system_prompt_functions ,
644
631
system_prompt_dynamic_functions = self ._system_prompt_dynamic_functions ,
@@ -690,6 +677,8 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
690
677
def _run_span_end_attributes (
691
678
self , state : _agent_graph .GraphAgentState , usage : _usage .RunUsage , settings : InstrumentationSettings
692
679
):
680
+ literal_instructions , _ = self ._get_instructions ()
681
+
693
682
if settings .version == 1 :
694
683
attrs = {
695
684
'all_messages_events' : json .dumps (
@@ -702,7 +691,7 @@ def _run_span_end_attributes(
702
691
else :
703
692
attrs = {
704
693
'pydantic_ai.all_messages' : json .dumps (settings .messages_to_otel_messages (state .message_history )),
705
- ** settings .system_instructions_attributes (self . _instructions ),
694
+ ** settings .system_instructions_attributes (literal_instructions ),
706
695
}
707
696
708
697
return {
@@ -727,8 +716,9 @@ def override(
727
716
model : models .Model | models .KnownModelName | str | _utils .Unset = _utils .UNSET ,
728
717
toolsets : Sequence [AbstractToolset [AgentDepsT ]] | _utils .Unset = _utils .UNSET ,
729
718
tools : Sequence [Tool [AgentDepsT ] | ToolFuncEither [AgentDepsT , ...]] | _utils .Unset = _utils .UNSET ,
719
+ instructions : Instructions [AgentDepsT ] | _utils .Unset = _utils .UNSET ,
730
720
) -> Iterator [None ]:
731
- """Context manager to temporarily override agent dependencies, model, toolsets, or tools .
721
+ """Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions .
732
722
733
723
This is particularly useful when testing.
734
724
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
@@ -738,6 +728,7 @@ def override(
738
728
model: The model to use instead of the model passed to the agent run.
739
729
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
740
730
tools: The tools to use instead of the tools registered with the agent.
731
+ instructions: The instructions to use instead of the instructions registered with the agent.
741
732
"""
742
733
if _utils .is_set (deps ):
743
734
deps_token = self ._override_deps .set (_utils .Some (deps ))
@@ -759,6 +750,12 @@ def override(
759
750
else :
760
751
tools_token = None
761
752
753
+ if _utils .is_set (instructions ):
754
+ normalized_instructions = self ._normalize_instructions (instructions )
755
+ instructions_token = self ._override_instructions .set (_utils .Some (normalized_instructions ))
756
+ else :
757
+ instructions_token = None
758
+
762
759
try :
763
760
yield
764
761
finally :
@@ -770,6 +767,8 @@ def override(
770
767
self ._override_toolsets .reset (toolsets_token )
771
768
if tools_token is not None :
772
769
self ._override_tools .reset (tools_token )
770
+ if instructions_token is not None :
771
+ self ._override_instructions .reset (instructions_token )
773
772
774
773
@overload
775
774
def instructions (
@@ -830,12 +829,12 @@ async def async_instructions(ctx: RunContext[str]) -> str:
830
829
def decorator (
831
830
func_ : _system_prompt .SystemPromptFunc [AgentDepsT ],
832
831
) -> _system_prompt .SystemPromptFunc [AgentDepsT ]:
833
- self ._instructions_functions .append (_system_prompt . SystemPromptRunner ( func_ ) )
832
+ self ._instructions .append (func_ )
834
833
return func_
835
834
836
835
return decorator
837
836
else :
838
- self ._instructions_functions .append (_system_prompt . SystemPromptRunner ( func ) )
837
+ self ._instructions .append (func )
839
838
return func
840
839
841
840
@overload
@@ -1276,6 +1275,34 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T:
1276
1275
else :
1277
1276
return deps
1278
1277
1278
+ def _normalize_instructions (
1279
+ self ,
1280
+ instructions : Instructions [AgentDepsT ],
1281
+ ) -> list [str | _system_prompt .SystemPromptFunc [AgentDepsT ]]:
1282
+ if instructions is None :
1283
+ return []
1284
+ if isinstance (instructions , str ) or callable (instructions ):
1285
+ return [instructions ]
1286
+ return list (instructions )
1287
+
1288
+ def _get_instructions (
1289
+ self ,
1290
+ ) -> tuple [str | None , list [_system_prompt .SystemPromptRunner [AgentDepsT ]]]:
1291
+ override_instructions = self ._override_instructions .get ()
1292
+ instructions = override_instructions .value if override_instructions else self ._instructions
1293
+
1294
+ literal_parts : list [str ] = []
1295
+ functions : list [_system_prompt .SystemPromptRunner [AgentDepsT ]] = []
1296
+
1297
+ for instruction in instructions :
1298
+ if isinstance (instruction , str ):
1299
+ literal_parts .append (instruction )
1300
+ else :
1301
+ functions .append (_system_prompt .SystemPromptRunner [AgentDepsT ](instruction ))
1302
+
1303
+ literal = '\n ' .join (literal_parts ).strip () or None
1304
+ return literal , functions
1305
+
1279
1306
def _get_toolset (
1280
1307
self ,
1281
1308
output_toolset : AbstractToolset [AgentDepsT ] | None | _utils .Unset = _utils .UNSET ,
0 commit comments