diff --git a/docs/ja/usage.md b/docs/ja/usage.md index a45385ce7..81948b5e5 100644 --- a/docs/ja/usage.md +++ b/docs/ja/usage.md @@ -89,6 +89,37 @@ class MyHooks(RunHooks): print(f"{agent.name} → {u.requests} requests, {u.total_tokens} total tokens") ``` +## フックから会話履歴を編集する + +`RunContextWrapper` には `message_history` も含まれており、フックから会話を直接読み書きできます。 + +- `get_messages()` は元の入力・モデル出力・保留中の挿入を含む完全な履歴を `ResponseInputItem` のリストとして返します。 +- `add_message(agent=..., message=...)` は任意のメッセージ(文字列、辞書、または `ResponseInputItem` のリスト)をキューに追加します。追加されたメッセージは即座に LLM への入力に連結され、実行結果やストリームイベントでは `InjectedInputItem` として公開されます。 +- `override_next_turn(messages)` は次の LLM 呼び出しに送信される履歴全体を置き換えます。ガードレールや外部レビュー後に履歴を書き換えたい場合に使用できます。 + +```python +class BroadcastHooks(RunHooks): + def __init__(self, reviewer_name: str): + self.reviewer_name = reviewer_name + + async def on_llm_start( + self, + context: RunContextWrapper, + agent: Agent, + _instructions: str | None, + _input_items: list[TResponseInputItem], + ) -> None: + context.message_history.add_message( + agent=agent, + message={ + "role": "user", + "content": f"{self.reviewer_name}: 先に付録のデータを引用してください。", + }, + ) +``` + +> **注意:** `conversation_id` または `previous_response_id` を指定して実行している場合、履歴はサーバー側のスレッドで管理されるため、そのランでは `message_history.override_next_turn()` を使用できません。 + ## API リファレンス 詳細な API ドキュメントは次を参照してください: @@ -96,4 +127,5 @@ class MyHooks(RunHooks): - [`Usage`][agents.usage.Usage] - 使用状況トラッキングのデータ構造 - [`RequestUsage`][agents.usage.RequestUsage] - リクエストごとの使用状況の詳細 - [`RunContextWrapper`][agents.run.RunContextWrapper] - 実行コンテキストから使用状況にアクセス +- [`MessageHistory`][agents.run_context.MessageHistory] - フックから会話履歴を閲覧・編集 - [`RunHooks`][agents.run.RunHooks] - 使用状況トラッキングのライフサイクルにフック \ No newline at end of file diff --git a/docs/ko/usage.md b/docs/ko/usage.md index b153ce468..f2d93ed2d 100644 --- a/docs/ko/usage.md +++ b/docs/ko/usage.md @@ -89,6 +89,37 @@ class MyHooks(RunHooks): print(f"{agent.name} → {u.requests} requests, {u.total_tokens} total tokens") ``` +## 훅에서 대화 기록 수정하기 + +`RunContextWrapper`에는 `message_history`도 포함되어 있어 훅에서 현재 대화를 바로 읽거나 수정할 수 있습니다. + +- `get_messages()`는 원본 입력, 모델 출력, 보류 중인 삽입 항목을 모두 포함한 전체 기록을 `ResponseInputItem` 리스트로 반환합니다. +- `add_message(agent=..., message=...)`는 사용자 정의 메시지(문자열, 딕셔너리 또는 `ResponseInputItem` 리스트)를 큐에 추가합니다. 추가된 메시지는 즉시 LLM 입력에 이어 붙고 실행 결과/스트림 이벤트에서는 `InjectedInputItem`으로 노출됩니다. +- `override_next_turn(messages)`는 다음 LLM 호출의 입력 전체를 교체할 때 사용합니다. 가드레일이나 외부 검토 결과에 따라 히스토리를 다시 작성해야 할 때 유용합니다. + +```python +class BroadcastHooks(RunHooks): + def __init__(self, reviewer_name: str): + self.reviewer_name = reviewer_name + + async def on_llm_start( + self, + context: RunContextWrapper, + agent: Agent, + _instructions: str | None, + _input_items: list[TResponseInputItem], + ) -> None: + context.message_history.add_message( + agent=agent, + message={ + "role": "user", + "content": f"{self.reviewer_name}: 답변 전에 부록 데이터를 인용하세요.", + }, + ) +``` + +> **참고:** `conversation_id` 또는 `previous_response_id`와 함께 실행하는 경우 서버 측 대화 스레드가 입력을 관리하므로 해당 런에서는 `message_history.override_next_turn()`을 사용할 수 없습니다. 
+ ## API 레퍼런스 자세한 API 문서는 다음을 참조하세요: @@ -96,4 +127,5 @@ class MyHooks(RunHooks): - [`Usage`][agents.usage.Usage] - 사용량 추적 데이터 구조 - [`RequestUsage`][agents.usage.RequestUsage] - 요청별 사용량 세부 정보 - [`RunContextWrapper`][agents.run.RunContextWrapper] - 실행 컨텍스트에서 사용량 접근 +- [`MessageHistory`][agents.run_context.MessageHistory] - 훅에서 대화 기록을 조회/편집 - [`RunHooks`][agents.run.RunHooks] - 사용량 추적 수명 주기에 훅 연결 \ No newline at end of file diff --git a/docs/usage.md b/docs/usage.md index bedae99b3..7cf7db679 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -85,6 +85,44 @@ class MyHooks(RunHooks): print(f"{agent.name} → {u.requests} requests, {u.total_tokens} total tokens") ``` +## Modifying chat history in hooks + +`RunContextWrapper` also exposes `message_history`, giving hooks a mutable view of the +conversation: + +- `get_messages()` returns the full transcript (original input, model outputs, and pending + injections) as a list of `ResponseInputItem` dictionaries. +- `add_message(agent=..., message=...)` queues custom messages (string, dict, or list of + `ResponseInputItem`s). Pending messages are appended to the current LLM input immediately and are + emitted as `InjectedInputItem`s in the run result or stream events. +- `override_next_turn(messages)` replaces the entire input for the upcoming LLM call. Use this to + rewrite history after a guardrail or external reviewer intervenes. + +```python +class BroadcastHooks(RunHooks): + def __init__(self, reviewer_name: str): + self.reviewer_name = reviewer_name + + async def on_llm_start( + self, + context: RunContextWrapper, + agent: Agent, + _instructions: str | None, + _input_items: list[TResponseInputItem], + ) -> None: + context.message_history.add_message( + agent=agent, + message={ + "role": "user", + "content": f"{self.reviewer_name}: Please cite the appendix before answering.", + }, + ) +``` + +> **Note:** When running with `conversation_id` or `previous_response_id`, the conversation history is managed by +> the server-side conversation thread, and `message_history.override_next_turn()` is disabled for +> that run.
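As a complement to the `add_message()` example above, here is a sketch of `override_next_turn()` driven from the same hook. This block is editorial rather than part of the patch: `TrimmingHooks` and its keep-last-eight policy are hypothetical, and the snippet assumes the same names as the example above (`RunHooks`, `RunContextWrapper`, `Agent`, `TResponseInputItem`). Per the note above, it only applies to runs that do not use `conversation_id` or `previous_response_id`.

```python
class TrimmingHooks(RunHooks):
    """Hypothetical hook that bounds prompt size by truncating the transcript."""

    def __init__(self, max_items: int = 8):
        self.max_items = max_items

    async def on_llm_start(
        self,
        context: RunContextWrapper,
        agent: Agent,
        _instructions: str | None,
        _input_items: list[TResponseInputItem],
    ) -> None:
        transcript = context.message_history.get_messages()
        if len(transcript) > self.max_items:
            # Keep only the most recent items for the upcoming model call.
            context.message_history.override_next_turn(transcript[-self.max_items :])
```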
+ ## API Reference For detailed API documentation, see: @@ -92,4 +130,5 @@ For detailed API documentation, see: - [`Usage`][agents.usage.Usage] - Usage tracking data structure - [`RequestUsage`][agents.usage.RequestUsage] - Per-request usage details - [`RunContextWrapper`][agents.run.RunContextWrapper] - Access usage from run context +- [`MessageHistory`][agents.run_context.MessageHistory] - Inspect or edit the conversation from hooks - [`RunHooks`][agents.run.RunHooks] - Hook into usage tracking lifecycle \ No newline at end of file diff --git a/docs/zh/usage.md b/docs/zh/usage.md index 990071f57..7a50e262a 100644 --- a/docs/zh/usage.md +++ b/docs/zh/usage.md @@ -89,6 +89,37 @@ class MyHooks(RunHooks): print(f"{agent.name} → {u.requests} requests, {u.total_tokens} total tokens") ``` +## 在钩子中修改对话历史 + +`RunContextWrapper` 还暴露了 `message_history`，允许钩子直接读取或修改当前会话： + +- `get_messages()` 以 `ResponseInputItem` 列表的形式返回完整对话（原始输入、模型输出以及所有待插入消息）。 +- `add_message(agent=..., message=...)` 将自定义消息（字符串、字典或 `ResponseInputItem` 列表）加入队列。消息会立即追加到本次 LLM 调用的输入，并作为 `InjectedInputItem` 出现在运行结果或流式事件中。 +- `override_next_turn(messages)` 用自定义内容完全替换下一次 LLM 调用的输入，适用于在护栏（guardrail）或人工审核后重写上下文的场景。 + +```python +class BroadcastHooks(RunHooks): + def __init__(self, reviewer_name: str): + self.reviewer_name = reviewer_name + + async def on_llm_start( + self, + context: RunContextWrapper, + agent: Agent, + _instructions: str | None, + _input_items: list[TResponseInputItem], + ) -> None: + context.message_history.add_message( + agent=agent, + message={ + "role": "user", + "content": f"{self.reviewer_name}: 回答前请先引用附录中的数据。", + }, + ) +``` + +> **注意:** 当运行时指定了 `conversation_id` 或 `previous_response_id` 时，会话由服务器端线程维护，此时无法调用 `message_history.override_next_turn()`。 + ## API 参考 如需详细的 API 文档，请参见: @@ -96,4 +127,5 @@ class MyHooks(RunHooks): - [`Usage`][agents.usage.Usage] - 用量跟踪数据结构 - [`RequestUsage`][agents.usage.RequestUsage] - 按请求的用量详情 - [`RunContextWrapper`][agents.run.RunContextWrapper] - 从运行上下文访问用量 +- [`MessageHistory`][agents.run_context.MessageHistory] - 在钩子中查看或编辑对话 - [`RunHooks`][agents.run.RunHooks] - 接入用量跟踪的生命周期 \ No newline at end of file diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 6f4d0815d..30dc4a408 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -50,6 +50,7 @@ from .items import ( HandoffCallItem, HandoffOutputItem, + InjectedInputItem, ItemHelpers, MessageOutputItem, ModelResponse, @@ -66,6 +67,7 @@ SessionABC, SQLiteSession, ) +from .message_history import MessageHistory from .model_settings import ModelSettings from .models.interface import Model, ModelProvider, ModelTracing from .models.multi_provider import MultiProvider @@ -276,6 +278,7 @@ def enable_verbose_stdout_logging(): "RunItem", "HandoffCallItem", "HandoffOutputItem", + "InjectedInputItem", "ToolCallItem", "ToolCallOutputItem", "ReasoningItem", @@ -287,6 +290,7 @@ def enable_verbose_stdout_logging(): "SQLiteSession", "OpenAIConversationsSession", "RunContextWrapper", + "MessageHistory", "TContext", "RunErrorDetails", "RunResult", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 3f3f2b916..3bddf280f 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -59,6 +59,7 @@ from .items import ( HandoffCallItem, HandoffOutputItem, + InjectedInputItem, ItemHelpers, MCPApprovalRequestItem, MCPApprovalResponseItem, @@ -891,14 +892,20 @@ async def _execute_tool_with_hooks( Returns: The result from the tool execution.
""" - await asyncio.gather( - hooks.on_tool_start(tool_context, agent, func_tool), - ( - agent.hooks.on_tool_start(tool_context, agent, func_tool) - if agent.hooks - else _coro.noop_coroutine() - ), + marker = tool_context.message_history.begin_injection_stage( + "before_tool", tool_call.call_id ) + try: + await asyncio.gather( + hooks.on_tool_start(tool_context, agent, func_tool), + ( + agent.hooks.on_tool_start(tool_context, agent, func_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + finally: + tool_context.message_history.end_injection_stage(marker) return await func_tool.on_invoke_tool(tool_context, tool_call.arguments) @@ -961,16 +968,22 @@ async def run_single_tool( ) # 4) Tool end hooks (with final result, which may have been overridden) - await asyncio.gather( - hooks.on_tool_end(tool_context, agent, func_tool, final_result), - ( - agent.hooks.on_tool_end( - tool_context, agent, func_tool, final_result - ) - if agent.hooks - else _coro.noop_coroutine() - ), + end_marker = tool_context.message_history.begin_injection_stage( + "after_tool", tool_call.call_id ) + try: + await asyncio.gather( + hooks.on_tool_end(tool_context, agent, func_tool, final_result), + ( + agent.hooks.on_tool_end( + tool_context, agent, func_tool, final_result + ) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + finally: + tool_context.message_history.end_injection_stage(end_marker) result = final_result except Exception as e: _error_tracing.attach_error_to_current_span( @@ -1401,7 +1414,7 @@ def stream_step_items_to_queue( queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel], ): for item in new_step_items: - if isinstance(item, MessageOutputItem): + if isinstance(item, MessageOutputItem) or isinstance(item, InjectedInputItem): event = RunItemStreamEvent(item=item, name="message_output_created") elif isinstance(item, HandoffCallItem): event = RunItemStreamEvent(item=item, name="handoff_requested") @@ -1535,24 +1548,37 @@ async def execute( else cls._get_screenshot_sync(action.computer_tool.computer, action.tool_call) ) - _, _, output = await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, action.computer_tool), - ( - agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool) - if agent.hooks - else _coro.noop_coroutine() - ), - output_func, + action_call_id = _get_tool_call_id(action.tool_call) + marker = context_wrapper.message_history.begin_injection_stage( + "before_tool", action_call_id ) + try: + _, _, output = await asyncio.gather( + hooks.on_tool_start(context_wrapper, agent, action.computer_tool), + ( + agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + output_func, + ) + finally: + context_wrapper.message_history.end_injection_stage(marker) - await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output), - ( - agent.hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output) - if agent.hooks - else _coro.noop_coroutine() - ), + end_marker = context_wrapper.message_history.begin_injection_stage( + "after_tool", action_call_id ) + try: + await asyncio.gather( + hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output), + ( + agent.hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + finally: + context_wrapper.message_history.end_injection_stage(end_marker) # TODO: don't send a screenshot every single time, use references image_url = 
f"data:image/png;base64,{output}" @@ -1638,14 +1664,19 @@ async def execute( context_wrapper: RunContextWrapper[TContext], config: RunConfig, ) -> RunItem: - await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool), - ( - agent.hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool) - if agent.hooks - else _coro.noop_coroutine() - ), - ) + call_id = _get_tool_call_id(call.tool_call) + marker = context_wrapper.message_history.begin_injection_stage("before_tool", call_id) + try: + await asyncio.gather( + hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool), + ( + agent.hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + finally: + context_wrapper.message_history.end_injection_stage(marker) request = LocalShellCommandRequest( ctx_wrapper=context_wrapper, @@ -1657,18 +1688,22 @@ async def execute( else: result = output - await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result), - ( - agent.hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result) - if agent.hooks - else _coro.noop_coroutine() - ), - ) + end_marker = context_wrapper.message_history.begin_injection_stage("after_tool", call_id) + try: + await asyncio.gather( + hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result), + ( + agent.hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + finally: + context_wrapper.message_history.end_injection_stage(end_marker) raw_payload: dict[str, Any] = { "type": "local_shell_call_output", - "call_id": call.tool_call.call_id, + "call_id": call_id, "output": result, } return ToolCallOutputItem( @@ -1689,14 +1724,19 @@ async def execute( context_wrapper: RunContextWrapper[TContext], config: RunConfig, ) -> RunItem: - await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, call.shell_tool), - ( - agent.hooks.on_tool_start(context_wrapper, agent, call.shell_tool) - if agent.hooks - else _coro.noop_coroutine() - ), - ) + call_id = _get_tool_call_id(call.tool_call) + marker = context_wrapper.message_history.begin_injection_stage("before_tool", call_id) + try: + await asyncio.gather( + hooks.on_tool_start(context_wrapper, agent, call.shell_tool), + ( + agent.hooks.on_tool_start(context_wrapper, agent, call.shell_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + finally: + context_wrapper.message_history.end_injection_stage(marker) shell_call = _coerce_shell_call(call.tool_call) request = ShellCommandRequest(ctx_wrapper=context_wrapper, data=shell_call) @@ -1725,14 +1765,18 @@ async def execute( output_text = _format_shell_error(exc) logger.error("Shell executor failed: %s", exc, exc_info=True) - await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text), - ( - agent.hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text) - if agent.hooks - else _coro.noop_coroutine() - ), - ) + end_marker = context_wrapper.message_history.begin_injection_stage("after_tool", call_id) + try: + await asyncio.gather( + hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text), + ( + agent.hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + finally: + context_wrapper.message_history.end_injection_stage(end_marker) raw_entries: list[dict[str, Any]] | None = None if 
shell_output_payload: @@ -1813,14 +1857,19 @@ async def execute( config: RunConfig, ) -> RunItem: apply_patch_tool = call.apply_patch_tool - await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, apply_patch_tool), - ( - agent.hooks.on_tool_start(context_wrapper, agent, apply_patch_tool) - if agent.hooks - else _coro.noop_coroutine() - ), - ) + call_id = _extract_apply_patch_call_id(call.tool_call) + marker = context_wrapper.message_history.begin_injection_stage("before_tool", call_id) + try: + await asyncio.gather( + hooks.on_tool_start(context_wrapper, agent, apply_patch_tool), + ( + agent.hooks.on_tool_start(context_wrapper, agent, apply_patch_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + finally: + context_wrapper.message_history.end_injection_stage(marker) status: Literal["completed", "failed"] = "completed" output_text = "" @@ -1849,18 +1898,22 @@ async def execute( output_text = _format_shell_error(exc) logger.error("Apply patch editor failed: %s", exc, exc_info=True) - await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text), - ( - agent.hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text) - if agent.hooks - else _coro.noop_coroutine() - ), - ) + end_marker = context_wrapper.message_history.begin_injection_stage("after_tool", call_id) + try: + await asyncio.gather( + hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text), + ( + agent.hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + finally: + context_wrapper.message_history.end_injection_stage(end_marker) raw_item: dict[str, Any] = { "type": "apply_patch_call_output", - "call_id": _extract_apply_patch_call_id(call.tool_call), + "call_id": call_id, "status": status, } if output_text: @@ -1960,6 +2013,12 @@ def _normalize_exit_code(value: Any) -> int | None: return None +def _get_tool_call_id(tool_call: Any) -> str | None: + if isinstance(tool_call, Mapping): + return cast(str | None, tool_call.get("call_id")) + return cast(str | None, getattr(tool_call, "call_id", None)) + + def _render_shell_outputs(outputs: Sequence[ShellCommandOutput]) -> str: if not outputs: return "(no output)" diff --git a/src/agents/items.py b/src/agents/items.py index 991a7f877..6f0ddab1f 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -153,6 +153,16 @@ class MessageOutputItem(RunItemBase[ResponseOutputMessage]): type: Literal["message_output_item"] = "message_output_item" +@dataclass +class InjectedInputItem(RunItemBase[TResponseInputItem]): + """Represents a manually injected input item added via hooks.""" + + raw_item: TResponseInputItem + """The injected input item that should be treated as part of the conversation.""" + + type: Literal["injected_input_item"] = "injected_input_item" + + @dataclass class HandoffCallItem(RunItemBase[ResponseFunctionToolCall]): """Represents a tool call for a handoff from one agent to another.""" @@ -329,6 +339,7 @@ class MCPApprovalResponseItem(RunItemBase[McpApprovalResponse]): RunItem: TypeAlias = Union[ MessageOutputItem, + InjectedInputItem, HandoffCallItem, HandoffOutputItem, ToolCallItem, diff --git a/src/agents/message_history.py b/src/agents/message_history.py new file mode 100644 index 000000000..0bb0631c5 --- /dev/null +++ b/src/agents/message_history.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from collections.abc import Iterable, Sequence as ABCSequence +from copy import deepcopy +from 
dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal, cast + +if TYPE_CHECKING: + from .agent import Agent + from .items import InjectedInputItem, RunItem, TResponseInputItem +else: # pragma: no cover - runtime fallbacks to break import cycles + Agent = Any # type: ignore[assignment] + InjectedInputItem = RunItem = TResponseInputItem = Any # type: ignore[assignment] + + +def _input_to_new_input_list( + value: str | TResponseInputItem | ABCSequence[TResponseInputItem], +) -> list[TResponseInputItem]: + from .items import ItemHelpers + + if isinstance(value, str): + return ItemHelpers.input_to_new_input_list(value) + if isinstance(value, list): + return ItemHelpers.input_to_new_input_list(value) + if isinstance(value, ABCSequence): + sequence_value = cast(Iterable[TResponseInputItem], value) + return ItemHelpers.input_to_new_input_list(list(sequence_value)) + return ItemHelpers.input_to_new_input_list([value]) + + +InjectedMessageStage = Literal[ + "agent_start", + "before_llm", + "after_llm", + "before_tool", + "after_tool", + "unspecified", +] + + +@dataclass +class _StageMarker: + stage: InjectedMessageStage + call_id: str | None + + +@dataclass +class InjectedMessageRecord: + item: InjectedInputItem + stage: InjectedMessageStage + call_id: str | None + order: int + + +@dataclass +class MessageHistory: + """Tracks the conversation history visible to hooks and allows modifications.""" + + _original_input: str | list[TResponseInputItem] | None = None + _generated_items: list[RunItem] | None = None + _pending_injected_items: list[InjectedMessageRecord] = field(default_factory=list) + _next_turn_override: list[TResponseInputItem] | None = None + _live_input_buffer: list[TResponseInputItem] | None = field(default=None, repr=False) + _stage_markers: list[_StageMarker] = field(default_factory=list, repr=False) + _next_order: int = 0 + + def set_original_input(self, original_input: str | list[TResponseInputItem]) -> None: + """Update the original input reference for the current run.""" + + self._original_input = original_input + + def bind_generated_items(self, generated_items: list[RunItem]) -> None: + """Bind the list of generated items accumulated so far.""" + + self._generated_items = generated_items + + def get_messages(self) -> list[TResponseInputItem]: + """Return a snapshot of the current transcript, including pending injections.""" + + messages: list[TResponseInputItem] = [] + if self._original_input is not None: + messages.extend(_input_to_new_input_list(self._original_input)) + if self._generated_items: + messages.extend(item.to_input_item() for item in self._generated_items) + if self._pending_injected_items: + messages.extend(record.item.to_input_item() for record in self._pending_injected_items) + return messages + + def add_message( + self, + *, + agent: Agent[Any], + message: str | TResponseInputItem | ABCSequence[TResponseInputItem], + ) -> None: + """Queue one or more messages to be appended to the history.""" + + from .items import InjectedInputItem + + new_items = _input_to_new_input_list(message) + for item in new_items: + normalized = deepcopy(item) + injected_item = InjectedInputItem(agent=agent, raw_item=normalized) + stage = self._stage_markers[-1].stage if self._stage_markers else "unspecified" + call_id = self._stage_markers[-1].call_id if self._stage_markers else None + self._pending_injected_items.append( + InjectedMessageRecord( + item=injected_item, + stage=stage, + call_id=call_id, + order=self._next_order, + ) + ) + self._next_order 
+= 1 + if self._live_input_buffer is not None: + self._live_input_buffer.append(deepcopy(normalized)) + + def pending_input_items(self) -> list[TResponseInputItem]: + """Return pending injected messages as input items without clearing them.""" + + return [record.item.to_input_item() for record in self._pending_injected_items] + + def flush_pending_items(self) -> list[InjectedMessageRecord]: + """Return and clear pending injected messages with metadata.""" + + pending = self._pending_injected_items + self._pending_injected_items = [] + return pending + + def override_next_turn(self, messages: ABCSequence[TResponseInputItem]) -> None: + """Replace the next model call's input history with a custom list.""" + + override_messages = [deepcopy(message) for message in messages] + self._next_turn_override = override_messages + if self._live_input_buffer is not None: + self._live_input_buffer.clear() + self._live_input_buffer.extend(deepcopy(message) for message in override_messages) + + def consume_next_turn_override(self) -> list[TResponseInputItem] | None: + """Pop the next turn override if set.""" + + if self._next_turn_override is None: + return None + override = [deepcopy(message) for message in self._next_turn_override] + self._next_turn_override = None + return override + + def clear(self) -> None: + """Reset all pending state.""" + + self._pending_injected_items.clear() + self._next_turn_override = None + self._live_input_buffer = None + + def bind_live_input_buffer(self, buffer: list[TResponseInputItem]) -> None: + """Bind the live model input list so hook mutations apply immediately.""" + + self._live_input_buffer = buffer + + def release_live_input_buffer(self) -> None: + """Stop tracking the live model input once the LLM call completes.""" + + self._live_input_buffer = None + + def begin_injection_stage( + self, stage: InjectedMessageStage, call_id: str | None = None + ) -> _StageMarker: + marker = _StageMarker(stage=stage, call_id=call_id) + self._stage_markers.append(marker) + return marker + + def end_injection_stage(self, marker: _StageMarker | None) -> None: + if marker is None: + return + if not self._stage_markers: + return + if self._stage_markers and self._stage_markers[-1] is marker: + self._stage_markers.pop() + return + try: + self._stage_markers.remove(marker) + except ValueError: + pass diff --git a/src/agents/run.py b/src/agents/run.py index fce7b4840..ca50c8c3a 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -55,11 +55,13 @@ RunItem, ToolCallItem, ToolCallItemTypes, + ToolCallOutputItem, TResponseInputItem, ) from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase from .logger import logger from .memory import Session, SessionInputCallback +from .message_history import InjectedMessageRecord from .model_settings import ModelSettings from .models.interface import Model, ModelProvider from .models.multi_provider import MultiProvider @@ -171,6 +173,55 @@ def prepare_input( return input_items +def _extract_call_id_from_tool_output(item: ToolCallOutputItem) -> str | None: + raw_item = item.raw_item + if isinstance(raw_item, dict): + return cast(str | None, raw_item.get("call_id")) + return cast(str | None, getattr(raw_item, "call_id", None)) + + +def _insert_injected_items( + target_items: list[RunItem], injected_records: list[InjectedMessageRecord] +) -> None: + if not injected_records: + return + + front_insertions = 0 + for record in sorted(injected_records, key=lambda r: r.order): + stage = record.stage + if stage in ("agent_start", "before_llm"): + 
insert_at = front_insertions + target_items.insert(insert_at, record.item) + front_insertions += 1 + continue + + if stage in ("after_llm", "unspecified"): + target_items.append(record.item) + continue + + if stage in ("before_tool", "after_tool") and record.call_id: + tool_index = None + for idx, item in enumerate(target_items): + if isinstance(item, ToolCallOutputItem): + call_id = _extract_call_id_from_tool_output(item) + if call_id == record.call_id: + tool_index = idx + break + if tool_index is None: + target_items.append(record.item) + continue + if stage == "before_tool": + target_items.insert(tool_index, record.item) + if tool_index <= front_insertions: + front_insertions += 1 + else: + target_items.insert(tool_index + 1, record.item) + continue + + # Fallback when we don't have enough metadata + target_items.append(record.item) + + # Type alias for the optional input filter callback CallModelInputFilter = Callable[[CallModelData[Any]], MaybeAwaitable[ModelInputData]] @@ -562,6 +613,8 @@ async def run( context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( context=context, # type: ignore ) + context_wrapper.message_history.set_original_input(original_input) + context_wrapper.message_history.bind_generated_items(generated_items) input_guardrail_results: list[InputGuardrailResult] = [] tool_input_guardrail_results: list[ToolInputGuardrailResult] = [] @@ -676,6 +729,8 @@ async def run( model_responses.append(turn_result.model_response) original_input = turn_result.original_input generated_items = turn_result.generated_items + context_wrapper.message_history.set_original_input(original_input) + context_wrapper.message_history.bind_generated_items(generated_items) if server_conversation_tracker is not None: server_conversation_tracker.track_server_items(turn_result.model_response) @@ -886,6 +941,8 @@ def run_streamed( trace=new_trace, context_wrapper=context_wrapper, ) + context_wrapper.message_history.set_original_input(streamed_result.input) + context_wrapper.message_history.bind_generated_items(streamed_result.new_items) # Kick off the actual agent loop in the background and return the streamed result object. 
streamed_result._run_impl_task = asyncio.create_task( @@ -1055,6 +1112,8 @@ async def _start_streaming( # Update the streamed result with the prepared input streamed_result.input = prepared_input + context_wrapper.message_history.set_original_input(streamed_result.input) + context_wrapper.message_history.bind_generated_items(streamed_result.new_items) await AgentRunner._save_result_to_session(session, starting_input, []) @@ -1160,6 +1219,8 @@ async def _start_streaming( ] streamed_result.input = turn_result.original_input streamed_result.new_items = turn_result.generated_items + context_wrapper.message_history.set_original_input(streamed_result.input) + context_wrapper.message_history.bind_generated_items(streamed_result.new_items) if server_conversation_tracker is not None: server_conversation_tracker.track_server_items(turn_result.model_response) @@ -1300,16 +1361,21 @@ async def _run_single_turn_streamed( ) -> SingleStepResult: emitted_tool_call_ids: set[str] = set() emitted_reasoning_item_ids: set[str] = set() + history = context_wrapper.message_history if should_run_agent_start_hooks: - await asyncio.gather( - hooks.on_agent_start(context_wrapper, agent), - ( - agent.hooks.on_start(context_wrapper, agent) - if agent.hooks - else _coro.noop_coroutine() - ), - ) + marker = history.begin_injection_stage("agent_start") + try: + await asyncio.gather( + hooks.on_agent_start(context_wrapper, agent), + ( + agent.hooks.on_start(context_wrapper, agent) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + finally: + history.end_injection_stage(marker) output_schema = cls._get_output_schema(agent) @@ -1328,13 +1394,30 @@ async def _run_single_turn_streamed( final_response: ModelResponse | None = None + history.set_original_input(streamed_result.input) + history.bind_generated_items(streamed_result.new_items) + + override_messages = history.consume_next_turn_override() + pending_input_items = history.pending_input_items() + if server_conversation_tracker is not None: + if override_messages is not None: + raise UserError( + "message_history overrides are not supported when using conversation_id or " + "previous_response_id." + ) input = server_conversation_tracker.prepare_input( streamed_result.input, streamed_result.new_items ) else: - input = ItemHelpers.input_to_new_input_list(streamed_result.input) - input.extend([item.to_input_item() for item in streamed_result.new_items]) + if override_messages is not None: + input = override_messages + else: + input = ItemHelpers.input_to_new_input_list(streamed_result.input) + input.extend([item.to_input_item() for item in streamed_result.new_items]) + + if pending_input_items: + input.extend(pending_input_items) filtered = await cls._maybe_filter_model_input( agent=agent, run_config=run_config, model_input=input, system_instructions=system_prompt, ) @@ -1345,105 +1428,121 @@ - # Call hook just before the model is invoked, with the correct system_prompt. - await asyncio.gather( - hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input), - ( - agent.hooks.on_llm_start( - context_wrapper, agent, filtered.instructions, filtered.input + history.bind_live_input_buffer(filtered.input) + try: + # Call hook just before the model is invoked, with the correct system_prompt.
+ llm_start_marker = history.begin_injection_stage("before_llm") + try: + await asyncio.gather( + hooks.on_llm_start( + context_wrapper, agent, filtered.instructions, filtered.input + ), + ( + agent.hooks.on_llm_start( + context_wrapper, agent, filtered.instructions, filtered.input + ) + if agent.hooks + else _coro.noop_coroutine() + ), ) - if agent.hooks - else _coro.noop_coroutine() - ), - ) + finally: + history.end_injection_stage(llm_start_marker) - previous_response_id = ( - server_conversation_tracker.previous_response_id - if server_conversation_tracker - else None - ) - conversation_id = ( - server_conversation_tracker.conversation_id if server_conversation_tracker else None - ) + previous_response_id = ( + server_conversation_tracker.previous_response_id + if server_conversation_tracker + else None + ) + conversation_id = ( + server_conversation_tracker.conversation_id if server_conversation_tracker else None + ) - # 1. Stream the output events - async for event in model.stream_response( - filtered.instructions, - filtered.input, - model_settings, - all_tools, - output_schema, - handoffs, - get_model_tracing_impl( - run_config.tracing_disabled, run_config.trace_include_sensitive_data - ), - previous_response_id=previous_response_id, - conversation_id=conversation_id, - prompt=prompt_config, - ): - # Emit the raw event ASAP - streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) - - if isinstance(event, ResponseCompletedEvent): - usage = ( - Usage( - requests=1, - input_tokens=event.response.usage.input_tokens, - output_tokens=event.response.usage.output_tokens, - total_tokens=event.response.usage.total_tokens, - input_tokens_details=event.response.usage.input_tokens_details, - output_tokens_details=event.response.usage.output_tokens_details, + # 1. 
Stream the output events + async for event in model.stream_response( + filtered.instructions, + filtered.input, + model_settings, + all_tools, + output_schema, + handoffs, + get_model_tracing_impl( + run_config.tracing_disabled, run_config.trace_include_sensitive_data + ), + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt_config, + ): + # Emit the raw event ASAP + streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) + + if isinstance(event, ResponseCompletedEvent): + usage = ( + Usage( + requests=1, + input_tokens=event.response.usage.input_tokens, + output_tokens=event.response.usage.output_tokens, + total_tokens=event.response.usage.total_tokens, + input_tokens_details=event.response.usage.input_tokens_details, + output_tokens_details=event.response.usage.output_tokens_details, + ) + if event.response.usage + else Usage() ) - if event.response.usage - else Usage() - ) - final_response = ModelResponse( - output=event.response.output, - usage=usage, - response_id=event.response.id, - ) - context_wrapper.usage.add(usage) - - if isinstance(event, ResponseOutputItemDoneEvent): - output_item = event.item - - if isinstance(output_item, _TOOL_CALL_TYPES): - call_id: str | None = getattr( - output_item, "call_id", getattr(output_item, "id", None) + final_response = ModelResponse( + output=event.response.output, + usage=usage, + response_id=event.response.id, ) + context_wrapper.usage.add(usage) - if call_id and call_id not in emitted_tool_call_ids: - emitted_tool_call_ids.add(call_id) + if isinstance(event, ResponseOutputItemDoneEvent): + output_item = event.item - tool_item = ToolCallItem( - raw_item=cast(ToolCallItemTypes, output_item), - agent=agent, - ) - streamed_result._event_queue.put_nowait( - RunItemStreamEvent(item=tool_item, name="tool_called") + if isinstance(output_item, _TOOL_CALL_TYPES): + call_id: str | None = getattr( + output_item, "call_id", getattr(output_item, "id", None) ) - elif isinstance(output_item, ResponseReasoningItem): - reasoning_id: str | None = getattr(output_item, "id", None) + if call_id and call_id not in emitted_tool_call_ids: + emitted_tool_call_ids.add(call_id) - if reasoning_id and reasoning_id not in emitted_reasoning_item_ids: - emitted_reasoning_item_ids.add(reasoning_id) + tool_item = ToolCallItem( + raw_item=cast(ToolCallItemTypes, output_item), + agent=agent, + ) + streamed_result._event_queue.put_nowait( + RunItemStreamEvent(item=tool_item, name="tool_called") + ) - reasoning_item = ReasoningItem(raw_item=output_item, agent=agent) - streamed_result._event_queue.put_nowait( - RunItemStreamEvent(item=reasoning_item, name="reasoning_item_created") - ) + elif isinstance(output_item, ResponseReasoningItem): + reasoning_id: str | None = getattr(output_item, "id", None) + + if reasoning_id and reasoning_id not in emitted_reasoning_item_ids: + emitted_reasoning_item_ids.add(reasoning_id) + + reasoning_item = ReasoningItem(raw_item=output_item, agent=agent) + streamed_result._event_queue.put_nowait( + RunItemStreamEvent( + item=reasoning_item, name="reasoning_item_created" + ) + ) + finally: + history.release_live_input_buffer() # Call hook just after the model response is finalized. 
if final_response is not None: - await asyncio.gather( - ( - agent.hooks.on_llm_end(context_wrapper, agent, final_response) - if agent.hooks - else _coro.noop_coroutine() - ), - hooks.on_llm_end(context_wrapper, agent, final_response), - ) + llm_end_marker = history.begin_injection_stage("after_llm") + try: + await asyncio.gather( + ( + agent.hooks.on_llm_end(context_wrapper, agent, final_response) + if agent.hooks + else _coro.noop_coroutine() + ), + hooks.on_llm_end(context_wrapper, agent, final_response), + ) + finally: + history.end_injection_stage(llm_end_marker) # 2. At this point, the streaming is complete for this turn of the agent loop. if not final_response: @@ -1465,6 +1564,10 @@ async def _run_single_turn_streamed( event_queue=streamed_result._event_queue, ) + injected_records = history.flush_pending_items() + if injected_records: + _insert_injected_items(single_step_result.new_step_items, injected_records) + import dataclasses as _dc # Filter out items that have already been sent to avoid duplicates @@ -1523,34 +1626,56 @@ async def _run_single_turn( tool_use_tracker: AgentToolUseTracker, server_conversation_tracker: _ServerConversationTracker | None = None, ) -> SingleStepResult: - # Ensure we run the hooks before anything else + history = context_wrapper.message_history if should_run_agent_start_hooks: - await asyncio.gather( - hooks.on_agent_start(context_wrapper, agent), - ( - agent.hooks.on_start(context_wrapper, agent) - if agent.hooks - else _coro.noop_coroutine() - ), - ) + marker = history.begin_injection_stage("agent_start") + try: + await asyncio.gather( + hooks.on_agent_start(context_wrapper, agent), + ( + agent.hooks.on_start(context_wrapper, agent) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + finally: + history.end_injection_stage(marker) system_prompt, prompt_config = await asyncio.gather( agent.get_system_prompt(context_wrapper), agent.get_prompt(context_wrapper), ) + history.set_original_input(original_input) + history.bind_generated_items(generated_items) + + override_messages = history.consume_next_turn_override() + pending_input_items = history.pending_input_items() + output_schema = cls._get_output_schema(agent) handoffs = await cls._get_handoffs(agent, context_wrapper) + if server_conversation_tracker is not None: - input = server_conversation_tracker.prepare_input(original_input, generated_items) + if override_messages is not None: + raise UserError( + "message_history overrides are not supported when using conversation_id or " + "previous_response_id." 
+ ) + model_input = server_conversation_tracker.prepare_input(original_input, generated_items) else: - input = ItemHelpers.input_to_new_input_list(original_input) - input.extend([generated_item.to_input_item() for generated_item in generated_items]) + if override_messages is not None: + model_input = override_messages + else: + model_input = ItemHelpers.input_to_new_input_list(original_input) + model_input.extend(item.to_input_item() for item in generated_items) + + if pending_input_items: + model_input.extend(pending_input_items) new_response = await cls._get_new_response( agent, system_prompt, - input, + model_input, output_schema, all_tools, handoffs, @@ -1562,7 +1687,7 @@ async def _run_single_turn( prompt_config, ) - return await cls._get_single_step_result_from_response( + single_step_result = await cls._get_single_step_result_from_response( agent=agent, original_input=original_input, pre_step_items=generated_items, @@ -1576,6 +1701,12 @@ async def _run_single_turn( tool_use_tracker=tool_use_tracker, ) + injected_records = history.flush_pending_items() + if injected_records: + _insert_injected_items(single_step_result.new_step_items, injected_records) + + return single_step_result + @classmethod async def _get_single_step_result_from_response( cls, @@ -1767,6 +1898,8 @@ async def _get_new_response( server_conversation_tracker: _ServerConversationTracker | None, prompt_config: ResponsePromptParam | None, ) -> ModelResponse: + history = context_wrapper.message_history + # Allow user to modify model input right before the call, if configured filtered = await cls._maybe_filter_model_input( agent=agent, @@ -1780,56 +1913,71 @@ async def _get_new_response( model_settings = agent.model_settings.resolve(run_config.model_settings) model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) - # If we have run hooks, or if the agent has hooks, we need to call them before the LLM call - await asyncio.gather( - hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input), - ( - agent.hooks.on_llm_start( - context_wrapper, - agent, - filtered.instructions, # Use filtered instructions - filtered.input, # Use filtered input + history.bind_live_input_buffer(filtered.input) + try: + # If we have run hooks, or if the agent has hooks, + # we need to call them before the LLM call + llm_start_marker = history.begin_injection_stage("before_llm") + try: + await asyncio.gather( + hooks.on_llm_start( + context_wrapper, agent, filtered.instructions, filtered.input + ), + ( + agent.hooks.on_llm_start( + context_wrapper, + agent, + filtered.instructions, # Use filtered instructions + filtered.input, # Use filtered input + ) + if agent.hooks + else _coro.noop_coroutine() + ), ) - if agent.hooks - else _coro.noop_coroutine() - ), - ) + finally: + history.end_injection_stage(llm_start_marker) - previous_response_id = ( - server_conversation_tracker.previous_response_id - if server_conversation_tracker - else None - ) - conversation_id = ( - server_conversation_tracker.conversation_id if server_conversation_tracker else None - ) + previous_response_id = ( + server_conversation_tracker.previous_response_id + if server_conversation_tracker + else None + ) + conversation_id = ( + server_conversation_tracker.conversation_id if server_conversation_tracker else None + ) - new_response = await model.get_response( - system_instructions=filtered.instructions, - input=filtered.input, - model_settings=model_settings, - tools=all_tools, - output_schema=output_schema, - 
handoffs=handoffs, - tracing=get_model_tracing_impl( - run_config.tracing_disabled, run_config.trace_include_sensitive_data - ), - previous_response_id=previous_response_id, - conversation_id=conversation_id, - prompt=prompt_config, - ) + new_response = await model.get_response( + system_instructions=filtered.instructions, + input=filtered.input, + model_settings=model_settings, + tools=all_tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=get_model_tracing_impl( + run_config.tracing_disabled, run_config.trace_include_sensitive_data + ), + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt_config, + ) + finally: + history.release_live_input_buffer() context_wrapper.usage.add(new_response.usage) # If we have run hooks, or if the agent has hooks, we need to call them after the LLM call - await asyncio.gather( - ( - agent.hooks.on_llm_end(context_wrapper, agent, new_response) - if agent.hooks - else _coro.noop_coroutine() - ), - hooks.on_llm_end(context_wrapper, agent, new_response), - ) + llm_end_marker = history.begin_injection_stage("after_llm") + try: + await asyncio.gather( + ( + agent.hooks.on_llm_end(context_wrapper, agent, new_response) + if agent.hooks + else _coro.noop_coroutine() + ), + hooks.on_llm_end(context_wrapper, agent, new_response), + ) + finally: + history.end_injection_stage(llm_end_marker) return new_response diff --git a/src/agents/run_context.py b/src/agents/run_context.py index 579a215f2..b0fd74ebe 100644 --- a/src/agents/run_context.py +++ b/src/agents/run_context.py @@ -3,6 +3,7 @@ from typing_extensions import TypeVar +from .message_history import MessageHistory from .usage import Usage TContext = TypeVar("TContext", default=Any) @@ -24,3 +25,6 @@ class RunContextWrapper(Generic[TContext]): """The usage of the agent run so far. For streamed responses, the usage will be stale until the last chunk of the stream is processed. 
""" + + message_history: MessageHistory = field(default_factory=MessageHistory) + """Mutable conversation history that hooks can inspect and update.""" diff --git a/tests/test_message_history.py b/tests/test_message_history.py new file mode 100644 index 000000000..27c844f39 --- /dev/null +++ b/tests/test_message_history.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from collections.abc import MutableMapping +from typing import Any, cast + +import pytest + +from agents import Agent, Runner +from agents.items import InjectedInputItem, TResponseInputItem +from agents.lifecycle import AgentHooks, RunHooks +from agents.run_context import RunContextWrapper +from agents.stream_events import RunItemStreamEvent + +from .fake_model import FakeModel +from .test_responses import ( + get_function_tool, + get_function_tool_call, + get_text_message, +) + + +class MessageInjectionHooks(RunHooks): + def __init__(self, injected_text: str): + self.injected_text = injected_text + + async def on_llm_start( + self, + context: RunContextWrapper[Any], + agent: Agent[Any], + _instructions: str | None, + _input_items: list[TResponseInputItem], + ) -> None: + context.message_history.add_message(agent=agent, message=self.injected_text) + + +class OrderingAgentHooks(AgentHooks): + def __init__(self) -> None: + self.agent_start_text = "agent-start" + self.before_tool_text = "before-tool" + self.after_tool_text = "after-tool" + + async def on_start(self, context: RunContextWrapper[Any], agent: Agent[Any]) -> None: + context.message_history.add_message( + agent=agent, + message={"role": "developer", "content": self.agent_start_text}, + ) + + async def on_tool_start( + self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any + ) -> None: + context.message_history.add_message( + agent=agent, + message={"role": "developer", "content": self.before_tool_text}, + ) + + async def on_tool_end( + self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: Any + ) -> None: + context.message_history.add_message( + agent=agent, + message={"role": "developer", "content": self.after_tool_text}, + ) + + +@pytest.mark.asyncio +async def test_run_hooks_can_inject_messages_into_llm_input() -> None: + hooks = MessageInjectionHooks("Moderator: cite your sources.") + model = FakeModel() + model.set_next_output([get_text_message("done")]) + + agent = Agent(name="editor", model=model) + result = await Runner.run(agent, input="original prompt", hooks=hooks) + + assert model.last_turn_args["input"] == [ + {"role": "user", "content": "original prompt"}, + {"role": "user", "content": hooks.injected_text}, + ] + + injected_items = [item for item in result.new_items if isinstance(item, InjectedInputItem)] + assert injected_items + + first_injected_raw = cast(MutableMapping[str, Any], injected_items[0].raw_item) + assert first_injected_raw["content"] == hooks.injected_text + + +@pytest.mark.asyncio +async def test_streamed_runs_emit_injected_input_items() -> None: + hooks = MessageInjectionHooks("Moderator: cite your sources.") + model = FakeModel() + model.set_next_output([get_text_message("done")]) + + agent = Agent(name="editor", model=model) + streamed_result = Runner.run_streamed(agent, input="streaming prompt", hooks=hooks) + + events: list[RunItemStreamEvent] = [] + async for event in streamed_result.stream_events(): + if isinstance(event, RunItemStreamEvent): + events.append(event) + + assert model.last_turn_args["input"] == [ + {"role": "user", "content": "streaming prompt"}, + {"role": "user", 
"content": hooks.injected_text}, + ] + + injected_items = [ + item for item in streamed_result.new_items if isinstance(item, InjectedInputItem) + ] + assert injected_items + assert any(isinstance(event.item, InjectedInputItem) for event in events) + + +def _find_index(items: list[TResponseInputItem], predicate) -> int: + for idx, item in enumerate(items): + if predicate(item): + return idx + raise AssertionError("predicate did not match any item") + + +@pytest.mark.asyncio +async def test_injected_messages_preserve_order_around_tool_calls() -> None: + tool = get_function_tool(name="helper", return_value="ok") + call_id = "call-123" + model = FakeModel() + model.add_multiple_turn_outputs( + [ + [get_function_tool_call(tool.name, "{}", call_id=call_id)], + [get_text_message("done")], + ] + ) + + hooks = OrderingAgentHooks() + agent = Agent(name="tester", model=model, tools=[tool], hooks=hooks) + result = await Runner.run(agent, input="run") + + transcript = result.to_input_list() + agent_start_idx = _find_index( + transcript, + lambda item: item.get("role") == "developer" + and item.get("content") == hooks.agent_start_text, + ) + before_tool_idx = _find_index( + transcript, + lambda item: item.get("role") == "developer" + and item.get("content") == hooks.before_tool_text, + ) + tool_output_idx = _find_index( + transcript, + lambda item: item.get("type") == "function_call_output" and item.get("call_id") == call_id, + ) + after_tool_idx = _find_index( + transcript, + lambda item: item.get("role") == "developer" + and item.get("content") == hooks.after_tool_text, + ) + + assert agent_start_idx < before_tool_idx < tool_output_idx < after_tool_idx