
Commit d7136af

mjschock and claude committed
feat: add RunState parameter support to Runner.run() methods
This commit integrates RunState into the Runner API, allowing runs to be resumed from a saved state. This is the final piece needed to make human-in-the-loop (HITL) tool approval fully functional.

**Changes:**

1. **Import NextStepInterruption** (run.py:21-32)
   - Added NextStepInterruption to the imports from _run_impl
   - Added the RunState import

2. **Updated Method Signatures** (run.py:285-444)
   - Runner.run(): added `RunState[TContext]` to the input union type
   - Runner.run_sync(): added `RunState[TContext]` to the input union type
   - Runner.run_streamed(): added `RunState[TContext]` to the input union type
   - AgentRunner.run(): added `RunState[TContext]` to the input union type
   - AgentRunner.run_sync(): added `RunState[TContext]` to the input union type
   - AgentRunner.run_streamed(): added `RunState[TContext]` to the input union type

3. **RunState Resumption Logic** (run.py:524-584)
   - Check whether the input is a RunState instance
   - Extract state fields when resuming: current_turn, original_input, generated_items, model_responses, context_wrapper
   - Prime the server conversation tracker from model_responses when resuming
   - Cast context_wrapper to the correct type after extraction

4. **Interruption Handling** (run.py:689-726)
   - Added `interruptions=[]` to the successful RunResult creation
   - Added an elif branch for NextStepInterruption
   - Return a RunResult with interruptions when tool approval is needed
   - Set final_output to None for interrupted runs

5. **RunResultStreaming Support** (run.py:879-918)
   - Handle RunState input for streaming runs
   - Added an `interruptions=[]` field to the RunResultStreaming creation
   - Extract original_input from the RunState for the result

**How It Works:**

When resuming from a RunState:

```python
run_state.approve(approval_item)
result = await Runner.run(agent, run_state)
```

When a tool needs approval:

1. The run pauses at tool execution
2. A RunResult is returned with interruptions=[ToolApprovalItem(...)]
3. The user can inspect the interruptions and approve or reject them
4. The user resumes by passing the resulting RunState back to Runner.run()

**Remaining Work:**

- Add a `state` property to RunResult for creating a RunState from results
- Add comprehensive tests
- Add documentation/examples

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
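For orientation, here is a rough end-to-end sketch of the approval loop this commit is building toward, pieced together from the description above. It is a sketch only: how a `RunState` is obtained from a `RunResult` (e.g. via a `state` property) is still listed under remaining work, so `build_run_state()` below is a hypothetical placeholder, and the agent setup is illustrative.

```python
# Hypothetical usage sketch based on this commit's description; not part of the diff.
import asyncio

from agents import Agent, Runner


def build_run_state(result):
    # Hypothetical helper: a `state` property on RunResult is listed as
    # remaining work, so creating a RunState from a result is not shown here.
    raise NotImplementedError("RunState creation from a RunResult is not part of this commit")


async def main() -> None:
    agent = Agent(name="Assistant")  # assume its tools are configured to require approval

    # First call: the run pauses when a tool needs approval and returns interruptions.
    result = await Runner.run(agent, "Clean up the temp directory")

    if result.interruptions:
        run_state = build_run_state(result)  # placeholder, see "Remaining Work" above

        # Inspect and approve (or reject) each pending tool call.
        for approval_item in result.interruptions:
            run_state.approve(approval_item)

        # Resume by passing the RunState instead of a string/list input.
        result = await Runner.run(agent, run_state)

    print(result.final_output)


asyncio.run(main())
```

The only part this commit actually wires up is the last step: `Runner.run()` (and the sync/streamed variants) now accept a `RunState` and resume from its saved turn, items, and model responses.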
1 parent 422c17d commit d7136af

File tree

1 file changed: +78 −22 lines


src/agents/run.py

Lines changed: 78 additions & 22 deletions
@@ -22,6 +22,7 @@
     AgentToolUseTracker,
     NextStepFinalOutput,
     NextStepHandoff,
+    NextStepInterruption,
     NextStepRunAgain,
     QueueCompleteSentinel,
     RunImpl,
@@ -65,6 +66,7 @@
 from .models.multi_provider import MultiProvider
 from .result import RunResult, RunResultStreaming
 from .run_context import RunContextWrapper, TContext
+from .run_state import RunState
 from .stream_events import (
     AgentUpdatedStreamEvent,
     RawResponsesStreamEvent,
@@ -296,7 +298,7 @@ class Runner:
     async def run(
         cls,
         starting_agent: Agent[TContext],
-        input: str | list[TResponseInputItem],
+        input: str | list[TResponseInputItem] | RunState[TContext],
         *,
         context: TContext | None = None,
         max_turns: int = DEFAULT_MAX_TURNS,
@@ -371,7 +373,7 @@ async def run(
     def run_sync(
         cls,
         starting_agent: Agent[TContext],
-        input: str | list[TResponseInputItem],
+        input: str | list[TResponseInputItem] | RunState[TContext],
         *,
         context: TContext | None = None,
         max_turns: int = DEFAULT_MAX_TURNS,
@@ -444,7 +446,7 @@ def run_sync(
     def run_streamed(
         cls,
         starting_agent: Agent[TContext],
-        input: str | list[TResponseInputItem],
+        input: str | list[TResponseInputItem] | RunState[TContext],
         context: TContext | None = None,
         max_turns: int = DEFAULT_MAX_TURNS,
         hooks: RunHooks[TContext] | None = None,
@@ -519,7 +521,7 @@ class AgentRunner:
     async def run(
         self,
         starting_agent: Agent[TContext],
-        input: str | list[TResponseInputItem],
+        input: str | list[TResponseInputItem] | RunState[TContext],
         **kwargs: Unpack[RunOptions[TContext]],
     ) -> RunResult:
         context = kwargs.get("context")
@@ -532,19 +534,41 @@ async def run(
         if run_config is None:
             run_config = RunConfig()
 
+        # Check if we're resuming from a RunState
+        is_resumed_state = isinstance(input, RunState)
+        run_state: RunState[TContext] | None = None
+
+        if is_resumed_state:
+            # Resuming from a saved state
+            run_state = cast(RunState[TContext], input)
+            original_user_input = run_state._original_input
+            prepared_input = run_state._original_input
+
+            # Override context with the state's context if not provided
+            if context is None and run_state._context is not None:
+                context = run_state._context.context
+        else:
+            # Keep original user input separate from session-prepared input
+            raw_input = cast(str | list[TResponseInputItem], input)
+            original_user_input = raw_input
+            prepared_input = await self._prepare_input_with_session(
                raw_input, session, run_config.session_input_callback
+            )
+
         if conversation_id is not None or previous_response_id is not None:
             server_conversation_tracker = _ServerConversationTracker(
                 conversation_id=conversation_id, previous_response_id=previous_response_id
             )
         else:
             server_conversation_tracker = None
 
-        # Keep original user input separate from session-prepared input
-        original_user_input = input
-        prepared_input = await self._prepare_input_with_session(
-            input, session, run_config.session_input_callback
-        )
+        # Prime the server conversation tracker from state if resuming
+        if server_conversation_tracker is not None and is_resumed_state and run_state is not None:
+            for response in run_state._model_responses:
+                server_conversation_tracker.track_server_items(response)
 
+        # Always create a fresh tool_use_tracker
+        # (it's rebuilt from the run state if needed during execution)
         tool_use_tracker = AgentToolUseTracker()
 
         with TraceCtxManager(
@@ -554,14 +578,23 @@ async def run(
             metadata=run_config.trace_metadata,
             disabled=run_config.tracing_disabled,
         ):
-            current_turn = 0
-            original_input: str | list[TResponseInputItem] = _copy_str_or_list(prepared_input)
-            generated_items: list[RunItem] = []
-            model_responses: list[ModelResponse] = []
-
-            context_wrapper: RunContextWrapper[TContext] = RunContextWrapper(
-                context=context,  # type: ignore
-            )
+            if is_resumed_state and run_state is not None:
+                # Restore state from RunState
+                current_turn = run_state._current_turn
+                original_input = run_state._original_input
+                generated_items = run_state._generated_items
+                model_responses = run_state._model_responses
+                # Cast to the correct type since we know this is TContext
+                context_wrapper = cast(RunContextWrapper[TContext], run_state._context)
+            else:
+                # Fresh run
+                current_turn = 0
+                original_input = _copy_str_or_list(prepared_input)
+                generated_items = []
+                model_responses = []
+                context_wrapper = RunContextWrapper(
+                    context=context,  # type: ignore
+                )
 
             input_guardrail_results: list[InputGuardrailResult] = []
             tool_input_guardrail_results: list[ToolInputGuardrailResult] = []
@@ -704,6 +737,7 @@ async def run(
                             tool_input_guardrail_results=tool_input_guardrail_results,
                             tool_output_guardrail_results=tool_output_guardrail_results,
                             context_wrapper=context_wrapper,
+                            interruptions=[],
                         )
                         if not any(
                             guardrail_result.output.tripwire_triggered
@@ -712,7 +746,22 @@ async def run(
                             await self._save_result_to_session(
                                 session, [], turn_result.new_step_items
                             )
-
+                        return result
+                    elif isinstance(turn_result.next_step, NextStepInterruption):
+                        # Tool approval is needed - return a result with interruptions
+                        result = RunResult(
+                            input=original_input,
+                            new_items=generated_items,
+                            raw_responses=model_responses,
+                            final_output=None,
+                            _last_agent=current_agent,
+                            input_guardrail_results=input_guardrail_results,
+                            output_guardrail_results=[],
+                            tool_input_guardrail_results=tool_input_guardrail_results,
+                            tool_output_guardrail_results=tool_output_guardrail_results,
+                            context_wrapper=context_wrapper,
+                            interruptions=turn_result.next_step.interruptions,
+                        )
                         return result
                     elif isinstance(turn_result.next_step, NextStepHandoff):
                         current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
@@ -756,7 +805,7 @@ async def run(
     def run_sync(
         self,
         starting_agent: Agent[TContext],
-        input: str | list[TResponseInputItem],
+        input: str | list[TResponseInputItem] | RunState[TContext],
         **kwargs: Unpack[RunOptions[TContext]],
     ) -> RunResult:
         context = kwargs.get("context")
@@ -835,7 +884,7 @@ def run_sync(
     def run_streamed(
         self,
         starting_agent: Agent[TContext],
-        input: str | list[TResponseInputItem],
+        input: str | list[TResponseInputItem] | RunState[TContext],
         **kwargs: Unpack[RunOptions[TContext]],
     ) -> RunResultStreaming:
         context = kwargs.get("context")
@@ -869,8 +918,14 @@ def run_streamed(
             context=context  # type: ignore
         )
 
+        # Handle RunState input
+        if isinstance(input, RunState):
+            input_for_result = input._original_input
+        else:
+            input_for_result = input
+
         streamed_result = RunResultStreaming(
-            input=_copy_str_or_list(input),
+            input=_copy_str_or_list(input_for_result),
            new_items=[],
            current_agent=starting_agent,
            raw_responses=[],
@@ -885,12 +940,13 @@ def run_streamed(
            _current_agent_output_schema=output_schema,
            trace=new_trace,
            context_wrapper=context_wrapper,
+            interruptions=[],
        )
 
        # Kick off the actual agent loop in the background and return the streamed result object.
        streamed_result._run_impl_task = asyncio.create_task(
            self._start_streaming(
-                starting_input=input,
+                starting_input=input_for_result,
                streamed_result=streamed_result,
                starting_agent=starting_agent,
                max_turns=max_turns,
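As a reading aid for the resumption branch above, here is a minimal stand-in showing the shape the diff assumes `RunState` has. The real class lives in `src/agents/run_state.py` and is not shown in this commit, so the field types below are loose placeholders inferred from how `AgentRunner.run()` reads them.

```python
# Illustrative stand-in only: the attributes AgentRunner.run() reads when resuming,
# as implied by this diff. Not the actual RunState implementation.
from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar

TContext = TypeVar("TContext")


@dataclass
class RunStateSketch(Generic[TContext]):
    _current_turn: int = 0
    _original_input: Any = ""                                    # str | list[TResponseInputItem]
    _generated_items: list[Any] = field(default_factory=list)    # list[RunItem]
    _model_responses: list[Any] = field(default_factory=list)    # list[ModelResponse]
    _context: Any = None                                          # RunContextWrapper[TContext] | None

    def approve(self, approval_item: Any) -> None:
        """Mark a pending tool call as approved (behaviour sketched; not shown in this diff)."""
        ...
```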