diff --git a/examples/basic/usage_tracking.py b/examples/basic/usage_tracking.py
index 30000b010..fd5a717c2 100644
--- a/examples/basic/usage_tracking.py
+++ b/examples/basic/usage_tracking.py
@@ -43,4 +43,3 @@ async def main() -> None:
 
 if __name__ == "__main__":
     asyncio.run(main())
-
diff --git a/examples/model_providers/litellm_auto.py b/examples/model_providers/litellm_auto.py
index 5e6942713..c9ab359d3 100644
--- a/examples/model_providers/litellm_auto.py
+++ b/examples/model_providers/litellm_auto.py
@@ -15,11 +15,13 @@
 # import logging
 # logging.basicConfig(level=logging.DEBUG)
 
+
 @function_tool
 def get_weather(city: str):
     print(f"[debug] getting weather for {city}")
     return f"The weather in {city} is sunny."
 
+
 class Result(BaseModel):
     output_text: str
     tool_results: list[str]
diff --git a/examples/realtime/app/server.py b/examples/realtime/app/server.py
index d4ff47e80..abdb1dabb 100644
--- a/examples/realtime/app/server.py
+++ b/examples/realtime/app/server.py
@@ -198,9 +198,7 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
                                 {"type": "input_text", "text": prompt_text},
                             ]
                             if prompt_text
-                            else [
-                                {"type": "input_image", "image_url": data_url, "detail": "high"}
-                            ]
+                            else [{"type": "input_image", "image_url": data_url, "detail": "high"}]
                         ),
                     }
                     await manager.send_user_message(session_id, user_msg)
@@ -271,7 +269,11 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str):
                                 "role": "user",
                                 "content": (
                                     [
-                                        {"type": "input_image", "image_url": data_url, "detail": "high"},
+                                        {
+                                            "type": "input_image",
+                                            "image_url": data_url,
+                                            "detail": "high",
+                                        },
                                         {"type": "input_text", "text": prompt_text},
                                     ]
                                     if prompt_text
diff --git a/examples/realtime/cli/demo.py b/examples/realtime/cli/demo.py
index 63bbd2099..4dfaf2c5a 100644
--- a/examples/realtime/cli/demo.py
+++ b/examples/realtime/cli/demo.py
@@ -23,8 +23,8 @@
 FORMAT = np.int16
 CHANNELS = 1
 ENERGY_THRESHOLD = 0.015  # RMS threshold for barge‑in while assistant is speaking
-PREBUFFER_CHUNKS = 3 # initial jitter buffer (~120ms with 40ms chunks)
-FADE_OUT_MS = 12 # short fade to avoid clicks when interrupting
+PREBUFFER_CHUNKS = 3  # initial jitter buffer (~120ms with 40ms chunks)
+FADE_OUT_MS = 12  # short fade to avoid clicks when interrupting
 
 # Set up logging for OpenAI agents SDK
 # logging.basicConfig(
@@ -108,14 +108,18 @@ def _output_callback(self, outdata, frames: int, time, status) -> None:
             samples, item_id, content_index = self.current_audio_chunk
 
             samples_filled = 0
-            while samples_filled < len(outdata) and self.fade_done_samples < self.fade_total_samples:
+            while (
+                samples_filled < len(outdata) and self.fade_done_samples < self.fade_total_samples
+            ):
                 remaining_output = len(outdata) - samples_filled
                 remaining_fade = self.fade_total_samples - self.fade_done_samples
                 n = min(remaining_output, remaining_fade)
 
                 src = samples[self.chunk_position : self.chunk_position + n].astype(np.float32)
                 # Linear ramp from current level down to 0 across remaining fade samples
-                idx = np.arange(self.fade_done_samples, self.fade_done_samples + n, dtype=np.float32)
+                idx = np.arange(
+                    self.fade_done_samples, self.fade_done_samples + n, dtype=np.float32
+                )
                 gain = 1.0 - (idx / float(self.fade_total_samples))
                 ramped = np.clip(src * gain, -32768.0, 32767.0).astype(np.int16)
                 outdata[samples_filled : samples_filled + n, 0] = ramped
@@ -155,7 +159,10 @@ def _output_callback(self, outdata, frames: int, time, status) -> None:
             if self.current_audio_chunk is None:
                 try:
                     # Respect a small jitter buffer before starting playback
-                    if self.prebuffering and self.output_queue.qsize() < self.prebuffer_target_chunks:
+                    if (
+                        self.prebuffering
+                        and self.output_queue.qsize() < self.prebuffer_target_chunks
+                    ):
                         break
                     self.prebuffering = False
                     self.current_audio_chunk = self.output_queue.get_nowait()
diff --git a/src/agents/extensions/memory/__init__.py b/src/agents/extensions/memory/__init__.py
index 8b344fc1b..13fa36e75 100644
--- a/src/agents/extensions/memory/__init__.py
+++ b/src/agents/extensions/memory/__init__.py
@@ -1,4 +1,3 @@
-
 """Session memory backends living in the extensions namespace.
 
 This package contains optional, production-grade session implementations that
@@ -6,6 +5,7 @@ conform to the :class:`agents.memory.session.Session` protocol so they can be
 used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`.
 """
+
 from __future__ import annotations
 
 from .sqlalchemy_session import SQLAlchemySession  # noqa: F401
diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py
index f13793ac9..4369b342b 100644
--- a/src/agents/extensions/models/litellm_model.py
+++ b/src/agents/extensions/models/litellm_model.py
@@ -413,9 +413,9 @@ def convert_message_to_openai(
             else:
                 # Convert object to dict by accessing its attributes
                 block_dict: dict[str, Any] = {}
-                if hasattr(block, '__dict__'):
+                if hasattr(block, "__dict__"):
                     block_dict = dict(block.__dict__.items())
-                elif hasattr(block, 'model_dump'):
+                elif hasattr(block, "model_dump"):
                     block_dict = block.model_dump()
                 else:
                     # Last resort: convert to string representation
diff --git a/src/agents/models/chatcmpl_converter.py b/src/agents/models/chatcmpl_converter.py
index 1c9f826de..77ff22ee0 100644
--- a/src/agents/models/chatcmpl_converter.py
+++ b/src/agents/models/chatcmpl_converter.py
@@ -106,6 +106,7 @@ def message_to_output_items(cls, message: ChatCompletionMessage) -> list[TRespon
                 # Store thinking blocks in the reasoning item's content
                 # Convert thinking blocks to Content objects
                 from openai.types.responses.response_reasoning_item import Content
+
                 reasoning_item.content = [
                     Content(text=str(block.get("thinking", "")), type="reasoning_text")
                     for block in message.thinking_blocks
@@ -282,9 +283,7 @@ def extract_all_content(
                     f"Only file_data is supported for input_file {casted_file_param}"
                 )
             if "filename" not in casted_file_param or not casted_file_param["filename"]:
-                raise UserError(
-                    f"filename must be provided for input_file {casted_file_param}"
-                )
+                raise UserError(f"filename must be provided for input_file {casted_file_param}")
             out.append(
                 File(
                     type="file",
diff --git a/src/agents/realtime/model_events.py b/src/agents/realtime/model_events.py
index a6d0bdecb..7c839aa18 100644
--- a/src/agents/realtime/model_events.py
+++ b/src/agents/realtime/model_events.py
@@ -84,6 +84,7 @@ class RealtimeModelInputAudioTranscriptionCompletedEvent:
 
     type: Literal["input_audio_transcription_completed"] = "input_audio_transcription_completed"
 
+
 @dataclass
 class RealtimeModelInputAudioTimeoutTriggeredEvent:
     """Input audio timeout triggered."""
@@ -94,6 +95,7 @@ class RealtimeModelInputAudioTimeoutTriggeredEvent:
 
     type: Literal["input_audio_timeout_triggered"] = "input_audio_timeout_triggered"
 
+
 @dataclass
 class RealtimeModelTranscriptDeltaEvent:
     """Partial transcript update."""
diff --git a/src/agents/result.py b/src/agents/result.py
index 5cf0e74c8..9d57da13d 100644
--- a/src/agents/result.py
+++ b/src/agents/result.py
@@ -201,7 +201,11 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
                 break
 
             if isinstance(item, QueueCompleteSentinel):
+                # Await input guardrails if they are still running, so late exceptions are captured.
+                await self._await_task_safely(self._input_guardrails_task)
+
                 self._event_queue.task_done()
+
                 # Check for errors, in case the queue was completed due to an exception
                 self._check_errors()
                 break
@@ -274,3 +278,19 @@ def _cleanup_tasks(self):
 
     def __str__(self) -> str:
         return pretty_print_run_result_streaming(self)
+
+    async def _await_task_safely(self, task: asyncio.Task[Any] | None) -> None:
+        """Await a task if present, ignoring cancellation and storing exceptions elsewhere.
+
+        This ensures we do not lose late guardrail exceptions while not surfacing
+        CancelledError to callers of stream_events.
+        """
+        if task and not task.done():
+            try:
+                await task
+            except asyncio.CancelledError:
+                # Task was cancelled (e.g., due to result.cancel()). Nothing to do here.
+                pass
+            except Exception:
+                # The exception will be surfaced via _check_errors() if needed.
+                pass
diff --git a/src/agents/run.py b/src/agents/run.py
index 5056758fb..1027b2355 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -1127,14 +1127,11 @@ async def _run_single_turn_streamed(
 
             # Filter out HandoffCallItem to avoid duplicates (already sent earlier)
             items_to_filter = [
-                item for item in items_to_filter
-                if not isinstance(item, HandoffCallItem)
+                item for item in items_to_filter if not isinstance(item, HandoffCallItem)
             ]
 
             # Create filtered result and send to queue
-            filtered_result = _dc.replace(
-                single_step_result, new_step_items=items_to_filter
-            )
+            filtered_result = _dc.replace(single_step_result, new_step_items=items_to_filter)
             RunImpl.stream_step_result_to_queue(filtered_result, streamed_result._event_queue)
 
         return single_step_result
@@ -1235,8 +1232,7 @@ async def _get_single_step_result_from_response(
         # Send handoff items immediately for streaming, but avoid duplicates
         if event_queue is not None and processed_response.new_items:
             handoff_items = [
-                item for item in processed_response.new_items
-                if isinstance(item, HandoffCallItem)
+                item for item in processed_response.new_items if isinstance(item, HandoffCallItem)
             ]
             if handoff_items:
                 RunImpl.stream_step_items_to_queue(cast(list[RunItem], handoff_items), event_queue)
diff --git a/tests/test_agent_instructions_signature.py b/tests/test_agent_instructions_signature.py
index bd16f9f57..604eb5189 100644
--- a/tests/test_agent_instructions_signature.py
+++ b/tests/test_agent_instructions_signature.py
@@ -16,6 +16,7 @@ def mock_run_context(self):
     @pytest.mark.asyncio
     async def test_valid_async_signature_passes(self, mock_run_context):
         """Test that async function with correct signature works"""
+
         async def valid_instructions(context, agent):
             return "Valid async instructions"
 
@@ -26,6 +27,7 @@ async def valid_instructions(context, agent):
     @pytest.mark.asyncio
     async def test_valid_sync_signature_passes(self, mock_run_context):
         """Test that sync function with correct signature works"""
+
        def valid_instructions(context, agent):
             return "Valid sync instructions"
 
@@ -36,6 +38,7 @@ def valid_instructions(context, agent):
     @pytest.mark.asyncio
     async def test_one_parameter_raises_error(self, mock_run_context):
         """Test that function with only one parameter raises TypeError"""
+
         def invalid_instructions(context):
             return "Should fail"
 
@@ -50,6 +53,7 @@ def invalid_instructions(context):
     @pytest.mark.asyncio
     async def test_three_parameters_raises_error(self, mock_run_context):
         """Test that function with three parameters raises TypeError"""
+
         def invalid_instructions(context, agent, extra):
             return "Should fail"
 
@@ -64,6 +68,7 @@ def invalid_instructions(context, agent, extra):
     @pytest.mark.asyncio
     async def test_zero_parameters_raises_error(self, mock_run_context):
         """Test that function with no parameters raises TypeError"""
+
         def invalid_instructions():
             return "Should fail"
 
@@ -78,6 +83,7 @@ def invalid_instructions():
     @pytest.mark.asyncio
     async def test_function_with_args_kwargs_fails(self, mock_run_context):
         """Test that function with *args/**kwargs fails validation"""
+
         def flexible_instructions(context, agent, *args, **kwargs):
             return "Flexible instructions"
 
diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py
index ff807ca96..90071a3d7 100644
--- a/tests/test_agent_runner_streamed.py
+++ b/tests/test_agent_runner_streamed.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import json
 from typing import Any
 
@@ -523,6 +524,35 @@ def guardrail_function(
         pass
 
 
+@pytest.mark.asyncio
+async def test_slow_input_guardrail_still_raises_exception_streamed():
+    async def guardrail_function(
+        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
+    ) -> GuardrailFunctionOutput:
+        # Simulate a slow guardrail that completes after model streaming ends.
+        await asyncio.sleep(0.05)
+        return GuardrailFunctionOutput(
+            output_info=None,
+            tripwire_triggered=True,
+        )
+
+    model = FakeModel()
+    # Ensure the model finishes streaming quickly.
+    model.set_next_output([get_text_message("ok")])
+
+    agent = Agent(
+        name="test",
+        input_guardrails=[InputGuardrail(guardrail_function=guardrail_function)],
+        model=model,
+    )
+
+    # Even though the guardrail is slower than the model stream, the exception should still raise.
+    with pytest.raises(InputGuardrailTripwireTriggered):
+        result = Runner.run_streamed(agent, input="user_message")
+        async for _ in result.stream_events():
+            pass
+
+
 @pytest.mark.asyncio
 async def test_output_guardrail_tripwire_triggered_causes_exception_streamed():
     def guardrail_function(
diff --git a/tests/test_session.py b/tests/test_session.py
index 5e96d3f25..40c0dc779 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -445,6 +445,7 @@ def filter_assistant_messages(history, new_input):
     assert len(model.last_turn_args["input"]) == 2
     assert model.last_turn_args["input"] == expected_model_input
 
+
 @pytest.mark.asyncio
 async def test_sqlite_session_unicode_content():
     """Test that session correctly stores and retrieves unicode/non-ASCII content."""