Skip to content

Commit c5fa125

Browse files
Allow to pass both session and input list + tests
1 parent 114b320 commit c5fa125

File tree

2 files changed

+159
-15
lines changed

2 files changed

+159
-15
lines changed

src/agents/run.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import copy
55
import inspect
66
from dataclasses import dataclass, field
7-
from typing import Any, Generic, cast
7+
from typing import Any, Generic, Literal, cast
88

99
from openai.types.responses import ResponseCompletedEvent
1010
from openai.types.responses.response_prompt_param import (
@@ -139,6 +139,11 @@ class RunConfig:
139139
An optional dictionary of additional metadata to include with the trace.
140140
"""
141141

142+
session_input_handling: Literal["replace", "append"] | None = None
143+
"""If a custom input list is given together with the Session, it will
144+
be appended to the session messages or it will replace them.
145+
"""
146+
142147

143148
class RunOptions(TypedDict, Generic[TContext]):
144149
"""Arguments for ``AgentRunner`` methods."""
@@ -343,7 +348,9 @@ async def run(
343348
run_config = RunConfig()
344349

345350
# Prepare input with session if enabled
346-
prepared_input = await self._prepare_input_with_session(input, session)
351+
prepared_input = await self._prepare_input_with_session(
352+
input, session, run_config.session_input_handling
353+
)
347354

348355
tool_use_tracker = AgentToolUseTracker()
349356

@@ -468,7 +475,9 @@ async def run(
468475
)
469476

470477
# Save the conversation to session if enabled
471-
await self._save_result_to_session(session, input, result)
478+
await self._save_result_to_session(
479+
session, input, result, run_config.session_input_handling
480+
)
472481

473482
return result
474483
elif isinstance(turn_result.next_step, NextStepHandoff):
@@ -662,7 +671,9 @@ async def _start_streaming(
662671

663672
try:
664673
# Prepare input with session if enabled
665-
prepared_input = await AgentRunner._prepare_input_with_session(starting_input, session)
674+
prepared_input = await AgentRunner._prepare_input_with_session(
675+
starting_input, session, run_config.session_input_handling
676+
)
666677

667678
# Update the streamed result with the prepared input
668679
streamed_result.input = prepared_input
@@ -781,7 +792,7 @@ async def _start_streaming(
781792
context_wrapper=context_wrapper,
782793
)
783794
await AgentRunner._save_result_to_session(
784-
session, starting_input, temp_result
795+
session, starting_input, temp_result, run_config.session_input_handling
785796
)
786797

787798
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
@@ -1191,18 +1202,18 @@ async def _prepare_input_with_session(
11911202
cls,
11921203
input: str | list[TResponseInputItem],
11931204
session: Session | None,
1205+
session_input_handling: Literal["replace", "append"] | None,
11941206
) -> str | list[TResponseInputItem]:
11951207
"""Prepare input by combining it with session history if enabled."""
11961208
if session is None:
11971209
return input
11981210

1199-
# Validate that we don't have both a session and a list input, as this creates
1200-
# ambiguity about whether the list should append to or replace existing session history
1201-
if isinstance(input, list):
1211+
# If the user doesn't explicitly specify a mode, raise an error
1212+
if isinstance(input, list) and not session_input_handling:
12021213
raise UserError(
1203-
"Cannot provide both a session and a list of input items. "
1204-
"When using session memory, provide only a string input to append to the "
1205-
"conversation, or use session=None and provide a list to manually manage "
1214+
"You must specify the `session_input_handling` in the `RunConfig`. "
1215+
"Otherwise, when using session memory, provide only a string input to append to "
1216+
"the conversation, or use session=None and provide a list to manually manage "
12061217
"conversation history."
12071218
)
12081219

@@ -1212,8 +1223,17 @@ async def _prepare_input_with_session(
12121223
# Convert input to list format
12131224
new_input_list = ItemHelpers.input_to_new_input_list(input)
12141225

1215-
# Combine history with new input
1216-
combined_input = history + new_input_list
1226+
if session_input_handling == "append" or session_input_handling is None:
1227+
# Append new input to history
1228+
combined_input = history + new_input_list
1229+
elif session_input_handling == "replace":
1230+
# Replace history with new input
1231+
combined_input = new_input_list
1232+
else:
1233+
raise UserError(
1234+
"The specified `session_input_handling` is not available. "
1235+
"Choose between `append`, `replace` or `None`."
1236+
)
12171237

12181238
return combined_input
12191239

@@ -1223,11 +1243,16 @@ async def _save_result_to_session(
12231243
session: Session | None,
12241244
original_input: str | list[TResponseInputItem],
12251245
result: RunResult,
1246+
saving_mode: Literal["replace", "append"] | None = None,
12261247
) -> None:
12271248
"""Save the conversation turn to session."""
12281249
if session is None:
12291250
return
12301251

1252+
# Remove old history
1253+
if saving_mode == "replace":
1254+
await session.clear_session()
1255+
12311256
# Convert original input to list format if needed
12321257
input_list = ItemHelpers.input_to_new_input_list(original_input)
12331258

tests/test_session.py

Lines changed: 121 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import pytest
88

9-
from agents import Agent, Runner, SQLiteSession, TResponseInputItem
9+
from agents import Agent, RunConfig, Runner, SQLiteSession, TResponseInputItem
1010
from agents.exceptions import UserError
1111

1212
from .fake_model import FakeModel
@@ -394,7 +394,126 @@ async def test_session_memory_rejects_both_session_and_list_input(runner_method)
394394
await run_agent_async(runner_method, agent, list_input, session=session)
395395

396396
# Verify the error message explains the issue
397-
assert "Cannot provide both a session and a list of input items" in str(exc_info.value)
397+
assert "You must specify the `session_input_handling` in" in str(exc_info.value)
398398
assert "manually manage conversation history" in str(exc_info.value)
399399

400400
session.close()
401+
402+
403+
@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
404+
@pytest.mark.asyncio
405+
async def test_session_memory_append_list(runner_method):
406+
"""Test if the user passes a list of items and want to append them."""
407+
with tempfile.TemporaryDirectory() as temp_dir:
408+
db_path = Path(temp_dir) / "test_memory.db"
409+
410+
model = FakeModel()
411+
agent = Agent(name="test", model=model)
412+
413+
# Session
414+
session_id = "session_1"
415+
session = SQLiteSession(session_id, db_path)
416+
417+
model.set_next_output([get_text_message("I like cats")])
418+
_ = await run_agent_async(runner_method, agent, "I like cats", session=session)
419+
420+
append_input = [
421+
{"role": "user", "content": "Some random user text"},
422+
{"role": "assistant", "content": "You're right"},
423+
{"role": "user", "content": "What did I say I like?"},
424+
]
425+
second_model_response = {"role": "assistant", "content": "Yes, you mentioned cats"}
426+
model.set_next_output([get_text_message(second_model_response.get("content", ""))])
427+
428+
_ = await run_agent_async(
429+
runner_method,
430+
agent,
431+
append_input,
432+
session=session,
433+
run_config=RunConfig(session_input_handling="append"),
434+
)
435+
436+
session_items = await session.get_items()
437+
438+
# Check the items has been appended
439+
assert len(session_items) == 6
440+
441+
# Check the items are the last 4 elements
442+
append_input.append(second_model_response)
443+
for sess_item, orig_item in zip(session_items[-4:], append_input):
444+
assert sess_item.get("role") == orig_item.get("role")
445+
446+
sess_content = sess_item.get("content")
447+
# Narrow to list or str for mypy
448+
assert isinstance(sess_content, (list, str))
449+
450+
if isinstance(sess_content, list):
451+
# now mypy knows `content: list[Any]`
452+
assert isinstance(sess_content[0], dict) and "text" in sess_content[0]
453+
val_sess = sess_content[0]["text"]
454+
else:
455+
# here content is str
456+
val_sess = sess_content
457+
458+
assert val_sess == orig_item["content"]
459+
460+
session.close()
461+
462+
463+
@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
464+
@pytest.mark.asyncio
465+
async def test_session_memory_replace_list(runner_method):
466+
"""Test if the user passes a list of items and want to replace the history."""
467+
with tempfile.TemporaryDirectory() as temp_dir:
468+
db_path = Path(temp_dir) / "test_memory.db"
469+
470+
model = FakeModel()
471+
agent = Agent(name="test", model=model)
472+
473+
# Session
474+
session_id = "session_1"
475+
session = SQLiteSession(session_id, db_path)
476+
477+
model.set_next_output([get_text_message("I like cats")])
478+
_ = await run_agent_async(runner_method, agent, "I like cats", session=session)
479+
480+
replace_input = [
481+
{"role": "user", "content": "Some random user text"},
482+
{"role": "assistant", "content": "You're right"},
483+
{"role": "user", "content": "What did I say I like?"},
484+
]
485+
second_model_response = {"role": "assistant", "content": "Yes, you mentioned cats"}
486+
model.set_next_output([get_text_message(second_model_response.get("content", ""))])
487+
488+
_ = await run_agent_async(
489+
runner_method,
490+
agent,
491+
replace_input,
492+
session=session,
493+
run_config=RunConfig(session_input_handling="replace"),
494+
)
495+
496+
session_items = await session.get_items()
497+
498+
# Check the new items replaced the history
499+
assert len(session_items) == 4
500+
501+
# Check the items are the last 4 elements
502+
replace_input.append(second_model_response)
503+
for sess_item, orig_item in zip(session_items, replace_input):
504+
assert sess_item.get("role") == orig_item.get("role")
505+
sess_content = sess_item.get("content")
506+
# Narrow to list or str for mypy
507+
assert isinstance(sess_content, (list, str))
508+
509+
if isinstance(sess_content, list):
510+
# now mypy knows `content: list[Any]`
511+
assert isinstance(sess_content[0], dict) and "text" in sess_content[0]
512+
val_sess = sess_content[0]["text"]
513+
else:
514+
# here content is str
515+
val_sess = sess_content
516+
517+
assert val_sess == orig_item["content"]
518+
519+
session.close()

0 commit comments

Comments
 (0)