Merged
1 change: 0 additions & 1 deletion examples/basic/usage_tracking.py
@@ -43,4 +43,3 @@ async def main() -> None:
 
 if __name__ == "__main__":
     asyncio.run(main())
-
2 changes: 2 additions & 0 deletions examples/model_providers/litellm_auto.py
@@ -15,11 +15,13 @@
 # import logging
 # logging.basicConfig(level=logging.DEBUG)
 
+
 @function_tool
 def get_weather(city: str):
     print(f"[debug] getting weather for {city}")
     return f"The weather in {city} is sunny."
 
+
 class Result(BaseModel):
     output_text: str
     tool_results: list[str]
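
Note: for orientation, a minimal sketch of how the tool and structured output defined in this example plug into an agent. The "litellm/..." model id and the prompt are illustrative assumptions, not part of this diff.

import asyncio

from agents import Agent, Runner

agent = Agent(
    name="weather_agent",
    instructions="Answer weather questions with the get_weather tool.",
    model="litellm/anthropic/claude-3-5-sonnet-20240620",  # assumed example id
    tools=[get_weather],
    output_type=Result,  # validated against the Result BaseModel above
)

async def demo() -> None:
    result = await Runner.run(agent, "What's the weather in Tokyo?")
    print(result.final_output)  # a Result instance

asyncio.run(demo())
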
10 changes: 6 additions & 4 deletions examples/realtime/app/server.py
@@ -198,9 +198,7 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
                         {"type": "input_text", "text": prompt_text},
                     ]
                     if prompt_text
-                    else [
-                        {"type": "input_image", "image_url": data_url, "detail": "high"}
-                    ]
+                    else [{"type": "input_image", "image_url": data_url, "detail": "high"}]
                 ),
             }
         await manager.send_user_message(session_id, user_msg)
@@ -271,7 +269,11 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
                 "role": "user",
                 "content": (
                     [
-                        {"type": "input_image", "image_url": data_url, "detail": "high"},
+                        {
+                            "type": "input_image",
+                            "image_url": data_url,
+                            "detail": "high",
+                        },
                         {"type": "input_text", "text": prompt_text},
                     ]
                     if prompt_text
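
Note: reduced to a standalone sketch, the payload both hunks build looks like this (the "type": "message" envelope is assumed from surrounding code; field values are illustrative):

def build_user_message(prompt_text: str | None, data_url: str) -> dict:
    image_part = {"type": "input_image", "image_url": data_url, "detail": "high"}
    text_part = {"type": "input_text", "text": prompt_text}
    return {
        "type": "message",  # assumed envelope, not shown in the hunks above
        "role": "user",
        # Attach the text part only when a prompt accompanies the image.
        "content": [image_part, text_part] if prompt_text else [image_part],
    }
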
17 changes: 12 additions & 5 deletions examples/realtime/cli/demo.py
@@ -23,8 +23,8 @@
 FORMAT = np.int16
 CHANNELS = 1
 ENERGY_THRESHOLD = 0.015  # RMS threshold for barge‑in while assistant is speaking
-PREBUFFER_CHUNKS = 3  # initial jitter buffer (~120ms with 40ms chunks)
-FADE_OUT_MS = 12  # short fade to avoid clicks when interrupting
+PREBUFFER_CHUNKS = 3  # initial jitter buffer (~120ms with 40ms chunks)
+FADE_OUT_MS = 12  # short fade to avoid clicks when interrupting
 
 # Set up logging for OpenAI agents SDK
 # logging.basicConfig(
@@ -108,14 +108,18 @@ def _output_callback(self, outdata, frames: int, time, status) -> None:
 
         samples, item_id, content_index = self.current_audio_chunk
         samples_filled = 0
-        while samples_filled < len(outdata) and self.fade_done_samples < self.fade_total_samples:
+        while (
+            samples_filled < len(outdata) and self.fade_done_samples < self.fade_total_samples
+        ):
             remaining_output = len(outdata) - samples_filled
             remaining_fade = self.fade_total_samples - self.fade_done_samples
             n = min(remaining_output, remaining_fade)
 
             src = samples[self.chunk_position : self.chunk_position + n].astype(np.float32)
             # Linear ramp from current level down to 0 across remaining fade samples
-            idx = np.arange(self.fade_done_samples, self.fade_done_samples + n, dtype=np.float32)
+            idx = np.arange(
+                self.fade_done_samples, self.fade_done_samples + n, dtype=np.float32
+            )
             gain = 1.0 - (idx / float(self.fade_total_samples))
             ramped = np.clip(src * gain, -32768.0, 32767.0).astype(np.int16)
             outdata[samples_filled : samples_filled + n, 0] = ramped
@@ -155,7 +159,10 @@ def _output_callback(self, outdata, frames: int, time, status) -> None:
         if self.current_audio_chunk is None:
             try:
                 # Respect a small jitter buffer before starting playback
-                if self.prebuffering and self.output_queue.qsize() < self.prebuffer_target_chunks:
+                if (
+                    self.prebuffering
+                    and self.output_queue.qsize() < self.prebuffer_target_chunks
+                ):
                     break
                 self.prebuffering = False
                 self.current_audio_chunk = self.output_queue.get_nowait()
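
Note: the reflowed fade code applies a linear gain ramp so interrupted audio tapers to silence instead of clicking. A self-contained sketch of the same arithmetic (sample rate and chunk contents are illustrative):

import numpy as np

SAMPLE_RATE = 24000  # assumed; the demo defines its own rate elsewhere
FADE_OUT_MS = 12

fade_total = int(SAMPLE_RATE * FADE_OUT_MS / 1000)  # samples in the fade window
chunk = (np.sin(np.linspace(0, 440, 480)) * 8000).astype(np.int16)  # fake audio

idx = np.arange(0, min(fade_total, len(chunk)), dtype=np.float32)
gain = 1.0 - (idx / float(fade_total))  # linear ramp from 1.0 down to 0.0
faded = np.clip(chunk[: len(idx)] * gain, -32768.0, 32767.0).astype(np.int16)
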
2 changes: 1 addition & 1 deletion src/agents/extensions/memory/__init__.py
@@ -1,11 +1,11 @@
-
 """Session memory backends living in the extensions namespace.
 
 This package contains optional, production-grade session implementations that
 introduce extra third-party dependencies (database drivers, ORMs, etc.). They
 conform to the :class:`agents.memory.session.Session` protocol so they can be
 used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`.
 """
 
 from __future__ import annotations
 
 from .sqlalchemy_session import SQLAlchemySession  # noqa: F401
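
Note: per the docstring, these backends are drop-in replacements for SQLiteSession. A hedged usage sketch (the from_url signature and connection URL are assumptions about the extension's API):

from agents import Agent, Runner
from agents.extensions.memory import SQLAlchemySession

async def main() -> None:
    session = SQLAlchemySession.from_url(  # assumed constructor
        "user-123",
        url="sqlite+aiosqlite:///:memory:",
        create_tables=True,
    )
    agent = Agent(name="assistant", instructions="Be concise.")
    result = await Runner.run(agent, "Hello", session=session)
    print(result.final_output)
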
4 changes: 2 additions & 2 deletions src/agents/extensions/models/litellm_model.py
@@ -413,9 +413,9 @@ def convert_message_to_openai(
         else:
             # Convert object to dict by accessing its attributes
             block_dict: dict[str, Any] = {}
-            if hasattr(block, '__dict__'):
+            if hasattr(block, "__dict__"):
                 block_dict = dict(block.__dict__.items())
-            elif hasattr(block, 'model_dump'):
+            elif hasattr(block, "model_dump"):
                 block_dict = block.model_dump()
             else:
                 # Last resort: convert to string representation
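
Note: the quote changes are formatting-only. The duck-typing fallback they sit in, reduced to a sketch (the last-resort shape is an assumption, not the module's exact code):

from typing import Any

def block_to_dict(block: Any) -> dict[str, Any]:
    # Prefer the plain attribute dict, then pydantic's model_dump(),
    # then fall back to a string representation.
    if hasattr(block, "__dict__"):
        return dict(block.__dict__.items())
    if hasattr(block, "model_dump"):
        return block.model_dump()
    return {"text": str(block)}  # assumed last-resort shape
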
5 changes: 2 additions & 3 deletions src/agents/models/chatcmpl_converter.py
@@ -106,6 +106,7 @@ def message_to_output_items(cls, message: ChatCompletionMessage) -> list[TRespon
             # Store thinking blocks in the reasoning item's content
             # Convert thinking blocks to Content objects
             from openai.types.responses.response_reasoning_item import Content
+
             reasoning_item.content = [
                 Content(text=str(block.get("thinking", "")), type="reasoning_text")
                 for block in message.thinking_blocks
@@ -282,9 +283,7 @@ def extract_all_content(
                         f"Only file_data is supported for input_file {casted_file_param}"
                     )
                 if "filename" not in casted_file_param or not casted_file_param["filename"]:
-                    raise UserError(
-                        f"filename must be provided for input_file {casted_file_param}"
-                    )
+                    raise UserError(f"filename must be provided for input_file {casted_file_param}")
                 out.append(
                     File(
                         type="file",
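
Note: what the first hunk's conversion does, as a standalone sketch (the sample thinking block is illustrative):

from openai.types.responses.response_reasoning_item import Content

thinking_blocks = [{"type": "thinking", "thinking": "First, check the filename."}]
reasoning_content = [
    Content(text=str(block.get("thinking", "")), type="reasoning_text")
    for block in thinking_blocks
]
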
2 changes: 2 additions & 0 deletions src/agents/realtime/model_events.py
@@ -84,6 +84,7 @@ class RealtimeModelInputAudioTranscriptionCompletedEvent:
 
     type: Literal["input_audio_transcription_completed"] = "input_audio_transcription_completed"
 
+
 @dataclass
 class RealtimeModelInputAudioTimeoutTriggeredEvent:
     """Input audio timeout triggered."""
@@ -94,6 +95,7 @@ class RealtimeModelInputAudioTimeoutTriggeredEvent:
 
     type: Literal["input_audio_timeout_triggered"] = "input_audio_timeout_triggered"
 
+
 @dataclass
 class RealtimeModelTranscriptDeltaEvent:
     """Partial transcript update."""
20 changes: 20 additions & 0 deletions src/agents/result.py
@@ -201,7 +201,11 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
                 break
 
             if isinstance(item, QueueCompleteSentinel):
+                # Await input guardrails if they are still running, so late exceptions are captured.
+                await self._await_task_safely(self._input_guardrails_task)
+
                 self._event_queue.task_done()
+
                 # Check for errors, in case the queue was completed due to an exception
                 self._check_errors()
                 break
@@ -274,3 +278,19 @@ def _cleanup_tasks(self):
 
     def __str__(self) -> str:
         return pretty_print_run_result_streaming(self)
+
+    async def _await_task_safely(self, task: asyncio.Task[Any] | None) -> None:
+        """Await a task if present, ignoring cancellation and storing exceptions elsewhere.
+
+        This ensures we do not lose late guardrail exceptions while not surfacing
+        CancelledError to callers of stream_events.
+        """
+        if task and not task.done():
+            try:
+                await task
+            except asyncio.CancelledError:
+                # Task was cancelled (e.g., due to result.cancel()). Nothing to do here.
+                pass
+            except Exception:
+                # The exception will be surfaced via _check_errors() if needed.
+                pass
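
Note: the new helper's swallow-and-defer pattern in isolation, as a runnable sketch; the task's exception stays retrievable for a later error check:

import asyncio

async def await_safely(task: asyncio.Task | None) -> None:
    # Drain a possibly-running task without raising here; the caller
    # inspects the task's exception later (cf. _check_errors()).
    if task and not task.done():
        try:
            await task
        except asyncio.CancelledError:
            pass  # cancelled elsewhere, e.g. by result.cancel()
        except Exception:
            pass  # surfaced later from the task object

async def main() -> None:
    async def slow_guardrail() -> None:
        await asyncio.sleep(0.05)
        raise RuntimeError("tripwire")

    task = asyncio.create_task(slow_guardrail())
    await await_safely(task)
    print(task.exception())  # RuntimeError('tripwire'), not lost

asyncio.run(main())
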
10 changes: 3 additions & 7 deletions src/agents/run.py
@@ -1127,14 +1127,11 @@ async def _run_single_turn_streamed(
 
             # Filter out HandoffCallItem to avoid duplicates (already sent earlier)
             items_to_filter = [
-                item for item in items_to_filter
-                if not isinstance(item, HandoffCallItem)
+                item for item in items_to_filter if not isinstance(item, HandoffCallItem)
             ]
 
             # Create filtered result and send to queue
-            filtered_result = _dc.replace(
-                single_step_result, new_step_items=items_to_filter
-            )
+            filtered_result = _dc.replace(single_step_result, new_step_items=items_to_filter)
             RunImpl.stream_step_result_to_queue(filtered_result, streamed_result._event_queue)
             return single_step_result
 
@@ -1235,8 +1232,7 @@ async def _get_single_step_result_from_response(
         # Send handoff items immediately for streaming, but avoid duplicates
         if event_queue is not None and processed_response.new_items:
             handoff_items = [
-                item for item in processed_response.new_items
-                if isinstance(item, HandoffCallItem)
+                item for item in processed_response.new_items if isinstance(item, HandoffCallItem)
             ]
             if handoff_items:
                 RunImpl.stream_step_items_to_queue(cast(list[RunItem], handoff_items), event_queue)
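
Note: both hunks only rewrap lines. The underlying pattern, copying a result dataclass with a filtered item list, in a sketch (StepResult is a stand-in type):

import dataclasses as _dc
from dataclasses import dataclass, field

@dataclass
class StepResult:  # stand-in for the SDK's single-step result type
    new_step_items: list[str] = field(default_factory=list)

result = StepResult(new_step_items=["handoff_call", "message"])
filtered = _dc.replace(
    result,
    new_step_items=[i for i in result.new_step_items if i != "handoff_call"],
)
print(filtered.new_step_items)  # ['message']; the original is unchanged
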
6 changes: 6 additions & 0 deletions tests/test_agent_instructions_signature.py
@@ -16,6 +16,7 @@ def mock_run_context(self):
     @pytest.mark.asyncio
     async def test_valid_async_signature_passes(self, mock_run_context):
         """Test that async function with correct signature works"""
+
         async def valid_instructions(context, agent):
             return "Valid async instructions"
 
@@ -26,6 +27,7 @@ async def valid_instructions(context, agent):
     @pytest.mark.asyncio
     async def test_valid_sync_signature_passes(self, mock_run_context):
         """Test that sync function with correct signature works"""
+
         def valid_instructions(context, agent):
             return "Valid sync instructions"
 
@@ -36,6 +38,7 @@ def valid_instructions(context, agent):
     @pytest.mark.asyncio
     async def test_one_parameter_raises_error(self, mock_run_context):
         """Test that function with only one parameter raises TypeError"""
+
         def invalid_instructions(context):
             return "Should fail"
 
@@ -50,6 +53,7 @@ def invalid_instructions(context):
     @pytest.mark.asyncio
     async def test_three_parameters_raises_error(self, mock_run_context):
         """Test that function with three parameters raises TypeError"""
+
         def invalid_instructions(context, agent, extra):
             return "Should fail"
 
@@ -64,6 +68,7 @@ def invalid_instructions(context, agent, extra):
     @pytest.mark.asyncio
     async def test_zero_parameters_raises_error(self, mock_run_context):
         """Test that function with no parameters raises TypeError"""
+
         def invalid_instructions():
             return "Should fail"
 
@@ -78,6 +83,7 @@ def invalid_instructions():
     @pytest.mark.asyncio
     async def test_function_with_args_kwargs_fails(self, mock_run_context):
         """Test that function with *args/**kwargs fails validation"""
+
         def flexible_instructions(context, agent, *args, **kwargs):
             return "Flexible instructions"
 
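
Note: the rule these tests pin down is that a dynamic instructions callable must accept exactly (context, agent). A conforming sketch:

from agents import Agent, RunContextWrapper

def dynamic_instructions(context: RunContextWrapper, agent: Agent) -> str:
    # Exactly two parameters; anything else fails signature validation.
    return f"You are {agent.name}. Be helpful."

agent = Agent(name="helper", instructions=dynamic_instructions)
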
30 changes: 30 additions & 0 deletions tests/test_agent_runner_streamed.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import json
 from typing import Any
 
@@ -523,6 +524,35 @@ def guardrail_function(
             pass
 
 
+@pytest.mark.asyncio
+async def test_slow_input_guardrail_still_raises_exception_streamed():
+    async def guardrail_function(
+        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
+    ) -> GuardrailFunctionOutput:
+        # Simulate a slow guardrail that completes after model streaming ends.
+        await asyncio.sleep(0.05)
+        return GuardrailFunctionOutput(
+            output_info=None,
+            tripwire_triggered=True,
+        )
+
+    model = FakeModel()
+    # Ensure the model finishes streaming quickly.
+    model.set_next_output([get_text_message("ok")])
+
+    agent = Agent(
+        name="test",
+        input_guardrails=[InputGuardrail(guardrail_function=guardrail_function)],
+        model=model,
+    )
+
+    # Even though the guardrail is slower than the model stream, the exception should still raise.
+    with pytest.raises(InputGuardrailTripwireTriggered):
+        result = Runner.run_streamed(agent, input="user_message")
+        async for _ in result.stream_events():
+            pass
+
+
 @pytest.mark.asyncio
 async def test_output_guardrail_tripwire_triggered_causes_exception_streamed():
     def guardrail_function(
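
Note: from application code, the behavior the new test locks in looks like this (a hedged sketch; agent construction elided):

from agents import InputGuardrailTripwireTriggered, Runner

async def run_with_guardrails(agent, user_input: str) -> None:
    result = Runner.run_streamed(agent, input=user_input)
    try:
        async for event in result.stream_events():
            ...  # forward events to the UI as they arrive
    except InputGuardrailTripwireTriggered:
        # Raised even when the guardrail finishes after streaming ends.
        print("Input rejected by guardrail.")
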
1 change: 1 addition & 0 deletions tests/test_session.py
@@ -445,6 +445,7 @@ def filter_assistant_messages(history, new_input):
     assert len(model.last_turn_args["input"]) == 2
     assert model.last_turn_args["input"] == expected_model_input
 
+
 @pytest.mark.asyncio
 async def test_sqlite_session_unicode_content():
     """Test that session correctly stores and retrieves unicode/non-ASCII content."""