diff --git a/doc/api.rst b/doc/api.rst index 1f9c76c9c..863fdedbc 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -502,6 +502,7 @@ API Reference HuggingFaceEndpointTarget limit_requests_per_minute OpenAICompletionTarget + OpenAICompletionsAudioConfig OpenAIImageTarget OpenAIChatTarget OpenAIResponseTarget diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index 0ea34a0b9..ebcb32bc4 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -23,6 +23,7 @@ from pyrit.prompt_target.http_target.httpx_api_target import HTTPXAPITarget from pyrit.prompt_target.hugging_face.hugging_face_chat_target import HuggingFaceChatTarget from pyrit.prompt_target.hugging_face.hugging_face_endpoint_target import HuggingFaceEndpointTarget +from pyrit.prompt_target.openai.completions_audio_config import OpenAICompletionsAudioConfig from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget from pyrit.prompt_target.openai.openai_completion_target import OpenAICompletionTarget from pyrit.prompt_target.openai.openai_image_target import OpenAIImageTarget @@ -51,8 +52,9 @@ "HuggingFaceEndpointTarget", "limit_requests_per_minute", "OpenAICompletionTarget", - "OpenAIImageTarget", + "OpenAICompletionsAudioConfig", "OpenAIChatTarget", + "OpenAIImageTarget", "OpenAIResponseTarget", "OpenAIVideoTarget", "OpenAITTSTarget", diff --git a/pyrit/prompt_target/openai/completions_audio_config.py b/pyrit/prompt_target/openai/completions_audio_config.py new file mode 100644 index 000000000..454a6aa94 --- /dev/null +++ b/pyrit/prompt_target/openai/completions_audio_config.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from dataclasses import dataclass +from typing import Literal + +# Voices supported by OpenAI Chat Completions API audio output. +# See: https://platform.openai.com/docs/guides/text-to-speech#voice-options +CompletionsAudioVoice = Literal["alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse", "marin", "cedar"] +CompletionsAudioFormat = Literal["wav", "mp3", "flac", "opus", "pcm16"] + + +@dataclass +class OpenAICompletionsAudioConfig: + """ + Configuration for audio output from OpenAI Chat Completions API. + + When provided to OpenAIChatTarget, this enables audio output from models + that support it (e.g., gpt-4o-audio-preview). + + Note: This is specific to the Chat Completions API. The Responses API does not + support audio input or output. For real-time audio, use RealtimeTarget instead. + """ + + # The voice to use for audio output. Supported voices are: + # "alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse", "marin", "cedar". + voice: CompletionsAudioVoice + + # The audio format for the response. Supported formats are: + # "wav", "mp3", "flac", "opus", "pcm16". Defaults to "wav". + audio_format: CompletionsAudioFormat = "wav" + + # If True, historical user messages that contain both audio and text will only send + # the text (transcript) to reduce bandwidth and token usage. The current (last) user + # message will still include audio. Defaults to True. + prefer_transcript_for_history: bool = True + + def to_extra_body_parameters(self) -> dict: + """ + Convert the config to extra_body_parameters format for OpenAI API. + + Returns: + dict: Parameters to include in the request body for audio output. 
+ """ + return { + "modalities": ["text", "audio"], + "audio": {"voice": self.voice, "format": self.audio_format}, + } diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 3b17e40fb..e65652416 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import base64 +import json import logging from typing import Any, Dict, MutableSequence, Optional @@ -12,13 +14,16 @@ ) from pyrit.models import ( ChatMessage, + DataTypeSerializer, Message, MessagePiece, construct_response_from_request, + data_serializer_factory, ) from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p +from pyrit.prompt_target.openai.completions_audio_config import OpenAICompletionsAudioConfig from pyrit.prompt_target.openai.openai_target import OpenAITarget logger = logging.getLogger(__name__) @@ -68,6 +73,7 @@ def __init__( seed: Optional[int] = None, n: Optional[int] = None, is_json_supported: bool = True, + audio_response_config: Optional[OpenAICompletionsAudioConfig] = None, extra_body_parameters: Optional[dict[str, Any]] = None, **kwargs, ): @@ -110,6 +116,8 @@ def __init__( setting the response_format header. Official OpenAI models all support this, but if you are using this target with different models, is_json_supported should be set correctly to avoid issues when using adversarial infrastructure (e.g. Crescendo scorers will set this flag). + audio_response_config (OpenAICompletionsAudioConfig, Optional): Configuration for audio output from models + that support it (e.g., gpt-4o-audio-preview). When provided, enables audio modality in responses. extra_body_parameters (dict, Optional): Additional parameters to be included in the request body. **kwargs: Additional keyword arguments passed to the parent OpenAITarget class. httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the ``httpx.AsyncClient()`` @@ -143,6 +151,16 @@ def __init__( self._presence_penalty = presence_penalty self._seed = seed self._n = n + self._audio_response_config = audio_response_config + + # Merge audio config into extra_body_parameters if provided + if audio_response_config: + audio_params = audio_response_config.to_extra_body_parameters() + if extra_body_parameters: + extra_body_parameters = {**audio_params, **extra_body_parameters} + else: + extra_body_parameters = audio_params + self._extra_body_parameters = extra_body_parameters def _set_openai_env_configuration_vars(self) -> None: @@ -224,7 +242,7 @@ def _validate_response(self, response: Any, request: MessagePiece) -> Optional[M Checks for: - Missing choices - Invalid finish_reason - - Empty content + - At least one valid response type (text content, audio, or tool_calls) Args: response: The ChatCompletion response from OpenAI SDK. 
@@ -245,34 +263,196 @@ def _validate_response(self, response: Any, request: MessagePiece) -> Optional[M finish_reason = choice.finish_reason # Check finish_reason (content_filter is handled by _check_content_filter) - if finish_reason not in ["stop", "length", "content_filter"]: - # finish_reason="stop" means API returned complete message - # "length" means API returned incomplete message due to max_tokens limit + # "tool_calls" is valid when the model invokes functions + valid_finish_reasons = ["stop", "length", "content_filter", "tool_calls"] + if finish_reason not in valid_finish_reasons: raise PyritException( message=f"Unknown finish_reason {finish_reason} from response: {response.model_dump_json()}" ) - # Check for empty content - content = choice.message.content or "" - if not content: - logger.error("The chat returned an empty response.") - raise EmptyResponseException(message="The chat returned an empty response.") + # Check for at least one valid response type + has_content, has_audio, has_tool_calls = self._detect_response_content(choice.message) + + if not (has_content or has_audio or has_tool_calls): + logger.error("The chat returned an empty response (no content, audio, or tool_calls).") + raise EmptyResponseException( + message="The chat returned an empty response (no content, audio, or tool_calls)." + ) return None + def _detect_response_content(self, message: Any) -> tuple[bool, bool, bool]: + """ + Detect what content types are present in a ChatCompletion message. + + Args: + message: The message object from response.choices[0].message. + + Returns: + Tuple of (has_content, has_audio, has_tool_calls) booleans. + """ + has_content = bool(message.content) + has_audio = hasattr(message, "audio") and message.audio is not None + has_tool_calls = hasattr(message, "tool_calls") and message.tool_calls + return has_content, has_audio, has_tool_calls + + def _should_skip_sending_audio( + self, + *, + data_type: str, + role: str, + is_last_message: bool, + has_text_piece: bool, + ) -> bool: + """ + Determine if an audio_path piece should be skipped when building chat messages. + + Args: + data_type: The converted_value_data_type of the message piece. + role: The API role of the message (user, assistant, system). + is_last_message: Whether this is the last (current) message in the conversation. + has_text_piece: Whether the message contains a text piece (e.g., transcript). + + Returns: + True if the audio should be skipped, False if it should be included. + """ + if data_type != "audio_path": + return False + + # Skip audio for assistant messages - OpenAI only allows audio in user messages. + # For assistant responses, the transcript text piece should already be included. + if role == "assistant": + return True + + # Skip historical user audio if prefer_transcript_for_history is enabled and we have a transcript + if ( + role == "user" + and not is_last_message + and has_text_piece + and self._audio_response_config + and self._audio_response_config.prefer_transcript_for_history + ): + return True + + return False + async def _construct_message_from_response(self, response: Any, request: MessagePiece) -> Message: """ Construct a Message from a ChatCompletion response. + Handles multiple response types: + - Text content from message.content + - Audio transcript and audio file from message.audio + - Tool calls serialized as JSON from message.tool_calls + Args: response: The ChatCompletion response from OpenAI SDK. request: The original request MessagePiece. 
Returns: - Message: Constructed message with extracted content. + Message: Constructed message with one or more MessagePiece entries. + + Raises: + EmptyResponseException: If the response contains no content, audio, or tool calls. + """ + message = response.choices[0].message + has_content, has_audio, has_tool_calls = self._detect_response_content(message) + + pieces: list[MessagePiece] = [] + + # Handle text content + if has_content: + text_piece = construct_response_from_request( + request=request, + response_text_pieces=[message.content], + response_type="text", + ).message_pieces[0] + pieces.append(text_piece) + + # Handle audio response (transcript + saved audio file) + if has_audio: + audio_response = message.audio + + # Add transcript as text piece + audio_transcript: Optional[str] = getattr(audio_response, "transcript", None) + if audio_transcript: + transcript_piece = construct_response_from_request( + request=request, + response_text_pieces=[audio_transcript], + response_type="text", + ).message_pieces[0] + pieces.append(transcript_piece) + + # Save audio data and add as audio_path piece + audio_data: Optional[str] = getattr(audio_response, "data", None) + if audio_data: + audio_path = await self._save_audio_response_async(audio_data_base64=audio_data) + audio_piece = construct_response_from_request( + request=request, + response_text_pieces=[audio_path], + response_type="audio_path", + ).message_pieces[0] + pieces.append(audio_piece) + + # Handle tool calls; for completions it is always function at the time of writing + if has_tool_calls: + for tool_call in message.tool_calls: + tool_call_data = { + "type": "function", + "id": tool_call.id, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + } + tool_call_json = json.dumps(tool_call_data) + tool_piece = construct_response_from_request( + request=request, + response_text_pieces=[tool_call_json], + response_type="function_call", + ).message_pieces[0] + pieces.append(tool_piece) + + if not pieces: + raise EmptyResponseException(message="Failed to extract any response content.") + + return Message(message_pieces=pieces) + + async def _save_audio_response_async(self, *, audio_data_base64: str) -> str: """ - extracted_response = response.choices[0].message.content or "" - return construct_response_from_request(request=request, response_text_pieces=[extracted_response]) + Save audio data from an OpenAI audio response to a file. + + Args: + audio_data_base64: Base64-encoded audio data from message.audio.data. + + Returns: + str: The file path where the audio was saved. 
+ """ + audio_bytes = base64.b64decode(audio_data_base64) + + # Determine the format from config, default to wav + audio_format = self._audio_response_config.audio_format if self._audio_response_config else "wav" + extension = f".{audio_format}" if audio_format != "pcm16" else ".wav" + + audio_serializer = data_serializer_factory( + category="prompt-memory-entries", + data_type="audio_path", + extension=extension, + ) + + if audio_format == "pcm16": + # Raw PCM needs WAV headers - OpenAI uses 24kHz mono PCM16 + await audio_serializer.save_formatted_audio( + data=audio_bytes, + num_channels=1, + sample_width=2, + sample_rate=24000, + ) + else: + # wav, mp3, flac, opus are already properly formatted + await audio_serializer.save_data(audio_bytes) + + return audio_serializer.value def is_json_response_supported(self) -> bool: """ @@ -362,13 +542,28 @@ async def _build_chat_messages_for_multi_modal_async(self, conversation: Mutable ValueError: If any message piece has an unsupported data type. """ chat_messages: list[dict] = [] - for message in conversation: + last_message_index = len(conversation) - 1 + + for message_index, message in enumerate(conversation): message_pieces = message.message_pieces + is_last_message = message_index == last_message_index + + # Check if this message has a text piece (transcript) alongside audio + has_text_piece = any(mp.converted_value_data_type == "text" for mp in message_pieces) content = [] role = None for message_piece in message_pieces: role = message_piece.api_role + + if self._should_skip_sending_audio( + data_type=message_piece.converted_value_data_type, + role=role, + is_last_message=is_last_message, + has_text_piece=has_text_piece, + ): + continue + if message_piece.converted_value_data_type == "text": entry = {"type": "text", "text": message_piece.converted_value} content.append(entry) @@ -377,6 +572,24 @@ async def _build_chat_messages_for_multi_modal_async(self, conversation: Mutable image_url_entry = {"url": data_base64_encoded_url} entry = {"type": "image_url", "image_url": image_url_entry} # type: ignore content.append(entry) + elif message_piece.converted_value_data_type == "audio_path": + ext = DataTypeSerializer.get_extension(message_piece.converted_value) + if not ext or ext.lower() not in [".wav", ".mp3"]: + raise ValueError( + f"Unsupported audio format: {ext}. " + "OpenAI Chat Completions API only supports .wav and .mp3 for audio input." + ) + audio_serializer = data_serializer_factory( + category="prompt-memory-entries", + value=message_piece.converted_value, + data_type="audio_path", + extension=ext, + ) + base64_data = await audio_serializer.read_data_base64() + audio_format = ext.lower().lstrip(".") + input_audio_entry = {"data": base64_data, "format": audio_format} + entry = {"type": "input_audio", "input_audio": input_audio_entry} # type: ignore + content.append(entry) else: raise ValueError( f"Multimodal data type {message_piece.converted_value_data_type} is not yet supported." @@ -433,8 +646,10 @@ def _validate_request(self, *, message: Message) -> None: # Some models may not support all of these for prompt_data_type in converted_prompt_data_types: - if prompt_data_type not in ["text", "image_path"]: - raise ValueError(f"This target only supports text and image_path. Received: {prompt_data_type}.") + if prompt_data_type not in ["text", "image_path", "audio_path"]: + raise ValueError( + f"This target only supports text, image_path, and audio_path. Received: {prompt_data_type}." 
+ ) def _build_response_format(self, json_config: _JsonResponseConfig) -> Optional[Dict[str, Any]]: if not json_config.enabled: diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index db955456c..2d7da2f04 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -25,7 +25,10 @@ logger = logging.getLogger(__name__) -RealTimeVoice = Literal["alloy", "echo", "shimmer"] +# Voices supported by the OpenAI Realtime API. +# See: https://platform.openai.com/docs/guides/realtime-conversations#voice-options +# For best quality, OpenAI recommends using "marin" or "cedar". +RealTimeVoice = Literal["alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse", "marin", "cedar"] @dataclass diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index d4458807d..59778c993 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -647,7 +647,14 @@ def _validate_request(self, *, message: Message) -> None: """ # Some models may not support all of these; we accept them at the transport layer # so the Responses API can decide. We include reasoning and function_call_output now. - allowed_types = {"text", "image_path", "function_call", "tool_call", "function_call_output", "reasoning"} + allowed_types = { + "text", + "image_path", + "function_call", + "tool_call", + "function_call_output", + "reasoning", + } for message_piece in message.message_pieces: if message_piece.converted_value_data_type not in allowed_types: raise ValueError(f"Unsupported data type: {message_piece.converted_value_data_type}") diff --git a/tests/integration/targets/test_openai_chat_target_integration.py b/tests/integration/targets/test_openai_chat_target_integration.py new file mode 100644 index 000000000..d83315db3 --- /dev/null +++ b/tests/integration/targets/test_openai_chat_target_integration.py @@ -0,0 +1,201 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Integration tests for OpenAIChatTarget. + +These tests verify: +- Audio input/output functionality using models that support native audio modalities +- Tool calling functionality with function definitions +""" + +import json +import os +import uuid + +import pytest + +from pyrit.common.path import HOME_PATH +from pyrit.models import MessagePiece +from pyrit.prompt_target import OpenAIChatTarget, OpenAICompletionsAudioConfig + +# Path to sample audio file for testing +SAMPLE_AUDIO_FILE = HOME_PATH / "assets" / "converted_audio.wav" + + +@pytest.fixture() +def platform_openai_audio_args(): + """ + Fixture for OpenAI platform audio-capable model. + + Requires: + - PLATFORM_OPENAI_CHAT_ENDPOINT: The OpenAI API endpoint (e.g., https://api.openai.com/v1) + - PLATFORM_OPENAI_CHAT_KEY: The OpenAI API key + """ + endpoint = os.environ.get("PLATFORM_OPENAI_CHAT_ENDPOINT") + api_key = os.environ.get("PLATFORM_OPENAI_CHAT_KEY") + + if not endpoint or not api_key: + pytest.skip("PLATFORM_OPENAI_CHAT_ENDPOINT and PLATFORM_OPENAI_CHAT_KEY must be set") + + return { + "endpoint": endpoint, + "api_key": api_key, + "model_name": "gpt-audio", + } + + +@pytest.fixture() +def platform_openai_chat_args(): + """ + Fixture for OpenAI platform chat model (non-audio). 
+ + Requires: + - PLATFORM_OPENAI_CHAT_ENDPOINT: The OpenAI API endpoint + - PLATFORM_OPENAI_CHAT_KEY: The OpenAI API key + """ + endpoint = os.environ.get("PLATFORM_OPENAI_CHAT_ENDPOINT") + api_key = os.environ.get("PLATFORM_OPENAI_CHAT_KEY") + + if not endpoint or not api_key: + pytest.skip("PLATFORM_OPENAI_CHAT_ENDPOINT and PLATFORM_OPENAI_CHAT_KEY must be set") + + return { + "endpoint": endpoint, + "api_key": api_key, + "model_name": "gpt-4o", + } + + +# ============================================================================ +# Audio Tests +# ============================================================================ + + +@pytest.mark.asyncio +async def test_openai_chat_target_audio_multi_turn(sqlite_instance, platform_openai_audio_args): + """ + Test multi-turn conversation with audio output. + + This test verifies that: + 1. Multiple turns of conversation work with audio output + 2. Conversation history is properly maintained + 3. Audio is generated for each assistant response + """ + audio_config = OpenAICompletionsAudioConfig(voice="alloy", audio_format="wav") + + target = OpenAIChatTarget( + **platform_openai_audio_args, + audio_response_config=audio_config, + ) + + conv_id = str(uuid.uuid4()) + + # First turn + user_piece1 = MessagePiece( + role="user", + original_value="Hello! What's your name?", + original_value_data_type="text", + conversation_id=conv_id, + ) + + result1 = await target.send_prompt_async(message=user_piece1.to_message()) + assert result1 is not None + assert len(result1) >= 1 + + # Verify first response has audio + audio_pieces1 = [p for p in result1[0].message_pieces if p.converted_value_data_type == "audio_path"] + assert len(audio_pieces1) >= 1, "First response should contain audio" + + # Second turn - send audio input + user_piece2 = MessagePiece( + role="user", + original_value=str(SAMPLE_AUDIO_FILE), + original_value_data_type="audio_path", + conversation_id=conv_id, + ) + + result2 = await target.send_prompt_async(message=user_piece2.to_message()) + assert result2 is not None + assert len(result2) >= 1 + + # Verify second response has audio + audio_pieces2 = [p for p in result2[0].message_pieces if p.converted_value_data_type == "audio_path"] + assert len(audio_pieces2) >= 1, "Second response should contain audio" + + +# ============================================================================ +# Tool Calling Tests +# ============================================================================ + + +@pytest.mark.asyncio +async def test_openai_chat_target_tool_calling_multiple_tools(sqlite_instance, platform_openai_chat_args): + """ + Test that OpenAIChatTarget can handle multiple tool definitions. + + This test verifies that: + 1. Multiple tools can be defined + 2. The model selects the appropriate tool based on context + """ + # Define multiple tools + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state"}, + }, + "required": ["location"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_stock_price", + "description": "Get the current stock price for a given ticker symbol", + "parameters": { + "type": "object", + "properties": { + "ticker": {"type": "string", "description": "The stock ticker symbol, e.g. 
AAPL"}, + }, + "required": ["ticker"], + }, + }, + }, + ] + + target = OpenAIChatTarget( + **platform_openai_chat_args, + extra_body_parameters={"tools": tools, "tool_choice": "auto"}, + ) + + conv_id = str(uuid.uuid4()) + + # Send a prompt that should trigger the stock price tool + user_piece = MessagePiece( + role="user", + original_value="What's the current stock price for Microsoft (MSFT)?", + original_value_data_type="text", + conversation_id=conv_id, + ) + + result = await target.send_prompt_async(message=user_piece.to_message()) + assert result is not None + assert len(result) >= 1 + + # Find tool call pieces in the response + tool_call_pieces = [p for p in result[0].message_pieces if p.converted_value_data_type == "function_call"] + + # The model should have returned a tool call for stock price + assert len(tool_call_pieces) >= 1, "Response should contain at least one tool call" + + # Verify it selected the stock price tool + tool_call_data = json.loads(tool_call_pieces[0].converted_value) + assert tool_call_data["function"]["name"] == "get_stock_price" + assert "msft" in tool_call_data["function"]["arguments"].lower() diff --git a/tests/unit/target/test_completions_audio_config.py b/tests/unit/target/test_completions_audio_config.py new file mode 100644 index 000000000..fee1f2c08 --- /dev/null +++ b/tests/unit/target/test_completions_audio_config.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for OpenAICompletionsAudioConfig. + +Tests cover initialization, validation, and conversion to API parameters. +""" + +import pytest + +from pyrit.prompt_target import OpenAICompletionsAudioConfig + + +class TestOpenAICompletionsAudioConfigInit: + """Tests for OpenAICompletionsAudioConfig initialization.""" + + def test_init_with_required_params_only(self): + """Test initialization with only required parameters.""" + config = OpenAICompletionsAudioConfig(voice="alloy") + + assert config.voice == "alloy" + assert config.audio_format == "wav" # Default value + assert config.prefer_transcript_for_history is True # Default value + + def test_init_with_all_params(self): + """Test initialization with all parameters specified.""" + config = OpenAICompletionsAudioConfig( + voice="coral", + audio_format="mp3", + prefer_transcript_for_history=False, + ) + + assert config.voice == "coral" + assert config.audio_format == "mp3" + assert config.prefer_transcript_for_history is False + + @pytest.mark.parametrize( + "voice", + ["alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse", "marin", "cedar"], + ) + def test_init_with_all_valid_voices(self, voice): + """Test that all valid voice options are accepted.""" + config = OpenAICompletionsAudioConfig(voice=voice) + assert config.voice == voice + + @pytest.mark.parametrize("audio_format", ["wav", "mp3", "flac", "opus", "pcm16"]) + def test_init_with_all_valid_formats(self, audio_format): + """Test that all valid audio format options are accepted.""" + config = OpenAICompletionsAudioConfig(voice="alloy", audio_format=audio_format) + assert config.audio_format == audio_format + + +class TestOpenAICompletionsAudioConfigToExtraBodyParameters: + """Tests for the to_extra_body_parameters method.""" + + def test_to_extra_body_parameters_basic(self): + """Test conversion to extra body parameters with defaults.""" + config = OpenAICompletionsAudioConfig(voice="alloy") + + params = config.to_extra_body_parameters() + + assert params == { + "modalities": ["text", "audio"], + "audio": 
{"voice": "alloy", "format": "wav"}, + } + + def test_to_extra_body_parameters_custom_format(self): + """Test conversion with custom format.""" + config = OpenAICompletionsAudioConfig(voice="coral", audio_format="mp3") + + params = config.to_extra_body_parameters() + + assert params == { + "modalities": ["text", "audio"], + "audio": {"voice": "coral", "format": "mp3"}, + } + + def test_to_extra_body_parameters_all_voices(self): + """Test that all voices produce valid parameters.""" + voices = ["alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse", "marin", "cedar"] + + for voice in voices: + config = OpenAICompletionsAudioConfig(voice=voice) # type: ignore[arg-type] + params = config.to_extra_body_parameters() + + assert params["modalities"] == ["text", "audio"] + assert params["audio"]["voice"] == voice + assert params["audio"]["format"] == "wav" + + def test_to_extra_body_parameters_all_formats(self): + """Test that all formats produce valid parameters.""" + formats = ["wav", "mp3", "flac", "opus", "pcm16"] + + for audio_format in formats: + config = OpenAICompletionsAudioConfig(voice="alloy", audio_format=audio_format) # type: ignore[arg-type] + params = config.to_extra_body_parameters() + + assert params["audio"]["format"] == audio_format + + +class TestOpenAICompletionsAudioConfigPreferTranscript: + """Tests for the prefer_transcript_for_history attribute.""" + + def test_prefer_transcript_default_true(self): + """Test that prefer_transcript_for_history defaults to True.""" + config = OpenAICompletionsAudioConfig(voice="alloy") + assert config.prefer_transcript_for_history is True + + def test_prefer_transcript_can_be_false(self): + """Test that prefer_transcript_for_history can be set to False.""" + config = OpenAICompletionsAudioConfig(voice="alloy", prefer_transcript_for_history=False) + assert config.prefer_transcript_for_history is False + + def test_prefer_transcript_not_in_extra_body(self): + """Test that prefer_transcript_for_history is not included in extra_body_parameters.""" + config = OpenAICompletionsAudioConfig(voice="alloy", prefer_transcript_for_history=False) + params = config.to_extra_body_parameters() + + # prefer_transcript_for_history is for PyRIT internal use, not sent to API + assert "prefer_transcript_for_history" not in params + assert "prefer_transcript_for_history" not in params.get("audio", {}) diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py index 349c65361..dd62e790c 100644 --- a/tests/unit/target/test_openai_chat_target.py +++ b/tests/unit/target/test_openai_chat_target.py @@ -10,7 +10,8 @@ import httpx import pytest -from openai import BadRequestError, RateLimitError +from openai import APIStatusError, BadRequestError, ContentFilterFinishReasonError, RateLimitError +from openai.types.chat import ChatCompletion from unit.mocks import ( get_image_message_piece, get_sample_conversations, @@ -25,7 +26,12 @@ from pyrit.memory.memory_interface import MemoryInterface from pyrit.models import Message, MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig -from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget +from pyrit.prompt_target import ( + OpenAIChatTarget, + OpenAICompletionsAudioConfig, + OpenAIResponseTarget, + PromptChatTarget, +) def fake_construct_response_from_request(request, response_text_pieces): @@ -34,12 +40,12 @@ def fake_construct_response_from_request(request, response_text_pieces): def create_mock_completion(content: str = "hi", 
finish_reason: str = "stop"): """Helper to create a mock OpenAI completion response""" - from openai.types.chat import ChatCompletion - mock_completion = MagicMock(spec=ChatCompletion) mock_completion.choices = [MagicMock()] mock_completion.choices[0].finish_reason = finish_reason mock_completion.choices[0].message.content = content + mock_completion.choices[0].message.audio = None + mock_completion.choices[0].message.tool_calls = None mock_completion.model_dump_json.return_value = json.dumps( {"choices": [{"finish_reason": finish_reason, "message": {"content": content}}]} ) @@ -161,13 +167,13 @@ async def test_build_chat_messages_for_multi_modal(target: OpenAIChatTarget): @pytest.mark.asyncio async def test_build_chat_messages_for_multi_modal_with_unsupported_data_types(target: OpenAIChatTarget): - # Like an image_path, the audio_path requires a file, but doesn't validate any contents + # Use video_path which is truly not supported for multimodal chat entry = get_image_message_piece() - entry.converted_value_data_type = "audio_path" + entry.converted_value_data_type = "video_path" with pytest.raises(ValueError) as excinfo: await target._build_chat_messages_for_multi_modal_async([Message(message_pieces=[entry])]) - assert "Multimodal data type audio_path is not yet supported." in str(excinfo.value) + assert "Multimodal data type video_path is not yet supported." in str(excinfo.value) @pytest.mark.asyncio @@ -551,7 +557,7 @@ def test_validate_request_unsupported_data_types(target: OpenAIChatTarget): with pytest.raises(ValueError) as excinfo: target._validate_request(message=message) - assert "This target only supports text and image_path." in str(excinfo.value), ( + assert "This target only supports text, image_path, and audio_path." in str(excinfo.value), ( "Error not raised for unsupported data types" ) @@ -666,9 +672,7 @@ async def test_send_prompt_async_content_filter_400(target: OpenAIChatTarget): @pytest.mark.asyncio -async def test_send_prompt_async_other_http_error(monkeypatch): - from openai import APIStatusError - +async def test_send_prompt_async_other_http_error(patch_central_database): target = OpenAIChatTarget( model_name="gpt-4", endpoint="https://mock.azure.com/", @@ -792,8 +796,6 @@ def test_azure_endpoint_new_format_openai_v1(patch_central_database): def test_azure_responses_endpoint_format(patch_central_database): """Test that Azure responses endpoint format is handled correctly.""" with patch.dict(os.environ, {}, clear=True): - from pyrit.prompt_target import OpenAIResponseTarget - target = OpenAIResponseTarget( model_name="o4-mini", endpoint="https://test.openai.azure.com/openai/responses?api-version=2025-03-01-preview", @@ -807,8 +809,6 @@ def test_azure_responses_endpoint_format(patch_central_database): def test_azure_responses_endpoint_new_format(patch_central_database): """Test that Azure responses endpoint with /openai/v1 format is handled correctly.""" with patch.dict(os.environ, {}, clear=True): - from pyrit.prompt_target import OpenAIResponseTarget - target = OpenAIResponseTarget( model_name="o4-mini", endpoint="https://test.openai.azure.com/openai/v1?api-version=2025-03-01-preview", @@ -862,8 +862,6 @@ async def test_content_filter_finish_reason_error( target: OpenAIChatTarget, sample_conversations: MutableSequence[MessagePiece] ): """Test ContentFilterFinishReasonError from SDK is handled correctly.""" - from openai import ContentFilterFinishReasonError - message_piece = sample_conversations[0] message_piece.conversation_id = "test-conv-id" request = 
Message(message_pieces=[message_piece]) @@ -943,8 +941,6 @@ async def test_bad_request_with_string_content_filter( @pytest.mark.asyncio async def test_api_status_error_429(target: OpenAIChatTarget, sample_conversations: MutableSequence[MessagePiece]): """Test APIStatusError with status 429 raises RateLimitException.""" - from openai import APIStatusError - message_piece = sample_conversations[0] message_piece.conversation_id = "test-conv-id" request = Message(message_pieces=[message_piece]) @@ -967,8 +963,6 @@ async def test_api_status_error_429(target: OpenAIChatTarget, sample_conversatio @pytest.mark.asyncio async def test_api_status_error_non_429(target: OpenAIChatTarget, sample_conversations: MutableSequence[MessagePiece]): """Test APIStatusError with non-429 status is re-raised.""" - from openai import APIStatusError - message_piece = sample_conversations[0] message_piece.conversation_id = "test-conv-id" request = Message(message_pieces=[message_piece]) @@ -1169,3 +1163,472 @@ def test_get_identifier_includes_top_p_when_set(patch_central_database): identifier = target.get_identifier() assert identifier["top_p"] == 0.9 + + +# ============================================================================ +# Audio Response Config Tests +# ============================================================================ + + +def test_init_with_audio_response_config(patch_central_database): + """Test initialization with audio_response_config.""" + audio_config = OpenAICompletionsAudioConfig(voice="alloy", audio_format="wav") + target = OpenAIChatTarget( + model_name="gpt-4o-audio-preview", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + audio_response_config=audio_config, + ) + + assert target._audio_response_config is not None + assert target._audio_response_config.voice == "alloy" + assert target._audio_response_config.audio_format == "wav" + + +def test_init_audio_config_extra_body_params_merged(patch_central_database): + """Test that audio config parameters are merged with extra_body_parameters.""" + audio_config = OpenAICompletionsAudioConfig(voice="coral", audio_format="mp3") + target = OpenAIChatTarget( + model_name="gpt-4o-audio-preview", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + audio_response_config=audio_config, + extra_body_parameters={"custom_param": "value"}, + ) + + # The audio config should add modalities and audio to extra body + assert target._extra_body_parameters.get("modalities") == ["text", "audio"] + assert target._extra_body_parameters.get("audio") == {"voice": "coral", "format": "mp3"} + assert target._extra_body_parameters.get("custom_param") == "value" + + +@pytest.mark.asyncio +async def test_construct_request_body_with_audio_config(patch_central_database, dummy_text_message_piece: MessagePiece): + """Test that request body includes audio modalities when audio config is set.""" + audio_config = OpenAICompletionsAudioConfig(voice="alloy", audio_format="wav") + target = OpenAIChatTarget( + model_name="gpt-4o-audio-preview", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + audio_response_config=audio_config, + ) + + request = Message(message_pieces=[dummy_text_message_piece]) + jrc = _JsonResponseConfig.from_metadata(metadata=None) + + body = await target._construct_request_body(conversation=[request], json_config=jrc) + + assert body.get("modalities") == ["text", "audio"] + assert body.get("audio") == {"voice": "alloy", "format": "wav"} + + +# 
============================================================================ +# Audio History Stripping Tests +# ============================================================================ + + +def test_should_skip_sending_audio_assistant_role(patch_central_database): + """Test that audio is always skipped for assistant messages.""" + audio_config = OpenAICompletionsAudioConfig(voice="alloy", audio_format="wav") + target = OpenAIChatTarget( + model_name="gpt-4o-audio-preview", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + audio_response_config=audio_config, + ) + + # Assistant audio should always be skipped regardless of other conditions + result = target._should_skip_sending_audio( + data_type="audio_path", + role="assistant", + is_last_message=True, + has_text_piece=True, + ) + assert result is True + + # Even when it's the last message with no text piece + result = target._should_skip_sending_audio( + data_type="audio_path", + role="assistant", + is_last_message=True, + has_text_piece=False, + ) + assert result is True + + +def test_should_skip_sending_audio_user_history_with_transcript(patch_central_database): + """Test that historical user audio is skipped when transcript exists and prefer_transcript_for_history is True.""" + audio_config = OpenAICompletionsAudioConfig(voice="alloy", audio_format="wav", prefer_transcript_for_history=True) + target = OpenAIChatTarget( + model_name="gpt-4o-audio-preview", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + audio_response_config=audio_config, + ) + + # Historical user audio with transcript should be skipped + result = target._should_skip_sending_audio( + data_type="audio_path", + role="user", + is_last_message=False, # Historical message + has_text_piece=True, # Has transcript + ) + assert result is True + + +def test_should_skip_sending_audio_user_history_without_transcript(patch_central_database): + """Test that historical user audio is NOT skipped when no transcript exists.""" + audio_config = OpenAICompletionsAudioConfig(voice="alloy", audio_format="wav", prefer_transcript_for_history=True) + target = OpenAIChatTarget( + model_name="gpt-4o-audio-preview", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + audio_response_config=audio_config, + ) + + # Historical user audio without transcript should NOT be skipped + result = target._should_skip_sending_audio( + data_type="audio_path", + role="user", + is_last_message=False, # Historical message + has_text_piece=False, # No transcript + ) + assert result is False + + +def test_should_skip_sending_audio_current_user_message(patch_central_database): + """Test that the current (last) user audio is NOT skipped.""" + audio_config = OpenAICompletionsAudioConfig(voice="alloy", audio_format="wav", prefer_transcript_for_history=True) + target = OpenAIChatTarget( + model_name="gpt-4o-audio-preview", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + audio_response_config=audio_config, + ) + + # Current user audio should NOT be skipped even with transcript + result = target._should_skip_sending_audio( + data_type="audio_path", + role="user", + is_last_message=True, # Current message + has_text_piece=True, + ) + assert result is False + + +def test_should_skip_sending_audio_prefer_transcript_disabled(patch_central_database): + """Test that audio is NOT skipped when prefer_transcript_for_history is False.""" + audio_config = OpenAICompletionsAudioConfig(voice="alloy", audio_format="wav", prefer_transcript_for_history=False) + target = 
OpenAIChatTarget( + model_name="gpt-4o-audio-preview", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + audio_response_config=audio_config, + ) + + # Historical user audio should NOT be skipped when prefer_transcript_for_history is False + result = target._should_skip_sending_audio( + data_type="audio_path", + role="user", + is_last_message=False, + has_text_piece=True, + ) + assert result is False + + +def test_should_skip_sending_audio_no_audio_config(patch_central_database): + """Test that audio is NOT skipped when no audio config is set.""" + target = OpenAIChatTarget( + model_name="gpt-4", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + + # Without audio config, historical audio should NOT be skipped (for user) + result = target._should_skip_sending_audio( + data_type="audio_path", + role="user", + is_last_message=False, + has_text_piece=True, + ) + assert result is False + + +def test_should_skip_sending_audio_non_audio_type(patch_central_database): + """Test that non-audio data types are never skipped by this method.""" + audio_config = OpenAICompletionsAudioConfig(voice="alloy", audio_format="wav") + target = OpenAIChatTarget( + model_name="gpt-4o-audio-preview", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + audio_response_config=audio_config, + ) + + # Text type should not be skipped + result = target._should_skip_sending_audio( + data_type="text", + role="user", + is_last_message=False, + has_text_piece=True, + ) + assert result is False + + # Image type should not be skipped + result = target._should_skip_sending_audio( + data_type="image_path", + role="assistant", + is_last_message=True, + has_text_piece=False, + ) + assert result is False + + +@pytest.mark.asyncio +async def test_build_chat_messages_strips_audio_from_history(patch_central_database): + """Test that audio is stripped from historical messages when building chat messages.""" + audio_config = OpenAICompletionsAudioConfig(voice="alloy", audio_format="wav", prefer_transcript_for_history=True) + target = OpenAIChatTarget( + model_name="gpt-4o-audio-preview", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + audio_response_config=audio_config, + ) + + conv_id = "test-conv-id" + + # Create a historical message with both text (transcript) and audio + historical_message = Message( + message_pieces=[ + MessagePiece( + role="user", + original_value="Hello from audio", + converted_value="Hello from audio", + original_value_data_type="text", + converted_value_data_type="text", + conversation_id=conv_id, + ), + MessagePiece( + role="user", + original_value="/path/to/audio.wav", + converted_value="/path/to/audio.wav", + original_value_data_type="audio_path", + converted_value_data_type="audio_path", + conversation_id=conv_id, + ), + ] + ) + + # Create the current (last) message with text + current_message = Message( + message_pieces=[ + MessagePiece( + role="user", + original_value="Follow up question", + converted_value="Follow up question", + original_value_data_type="text", + converted_value_data_type="text", + conversation_id=conv_id, + ), + ] + ) + + # Build chat messages + messages = await target._build_chat_messages_for_multi_modal_async([historical_message, current_message]) + + # Verify the historical message only has text (audio was stripped) + assert len(messages) == 2 + assert messages[0]["role"] == "user" + # The historical message should only have the text content, not the audio + historical_content = messages[0]["content"] + assert 
len(historical_content) == 1 + assert historical_content[0]["type"] == "text" + assert historical_content[0]["text"] == "Hello from audio" + + # Current message should just have text + assert messages[1]["content"][0]["type"] == "text" + assert messages[1]["content"][0]["text"] == "Follow up question" + + +# ============================================================================ +# Tool Calling Tests +# ============================================================================ + + +def create_mock_completion_with_tool_calls(tool_calls: list, finish_reason: str = "tool_calls"): + """Helper to create a mock OpenAI completion response with tool calls.""" + mock_completion = MagicMock(spec=ChatCompletion) + mock_completion.choices = [MagicMock()] + mock_completion.choices[0].finish_reason = finish_reason + mock_completion.choices[0].message.content = None + mock_completion.choices[0].message.audio = None + mock_completion.choices[0].message.tool_calls = tool_calls + mock_completion.model_dump_json.return_value = json.dumps( + {"choices": [{"finish_reason": finish_reason, "message": {"content": None, "tool_calls": []}}]} + ) + return mock_completion + + +def create_mock_tool_call(call_id: str, function_name: str, arguments: str): + """Helper to create a mock tool call object.""" + mock_tool_call = MagicMock() + mock_tool_call.id = call_id + mock_tool_call.type = "function" + mock_tool_call.function = MagicMock() + mock_tool_call.function.name = function_name + mock_tool_call.function.arguments = arguments + return mock_tool_call + + +def test_validate_response_tool_calls_finish_reason(target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece): + """Test _validate_response accepts tool_calls finish_reason.""" + tool_call = create_mock_tool_call("call_123", "get_weather", '{"location": "NYC"}') + mock_response = create_mock_completion_with_tool_calls([tool_call], finish_reason="tool_calls") + + # Should not raise - tool_calls is a valid finish reason + result = target._validate_response(mock_response, dummy_text_message_piece) + assert result is None + + +def test_detect_response_content_with_tool_calls(target: OpenAIChatTarget): + """Test _detect_response_content correctly identifies tool calls.""" + mock_message = MagicMock() + mock_message.content = None + mock_message.audio = None + + tool_call = create_mock_tool_call("call_123", "get_weather", '{"location": "NYC"}') + mock_message.tool_calls = [tool_call] + + has_content, has_audio, has_tool_calls = target._detect_response_content(mock_message) + + assert not has_content + assert not has_audio + assert has_tool_calls # Truthy - returns the list itself + + +def test_detect_response_content_no_tool_calls(target: OpenAIChatTarget): + """Test _detect_response_content when tool_calls is None.""" + mock_message = MagicMock() + mock_message.content = "Hello" + mock_message.audio = None + mock_message.tool_calls = None + + has_content, has_audio, has_tool_calls = target._detect_response_content(mock_message) + + assert has_content + assert not has_audio + assert not has_tool_calls # Falsy - None + + +def test_detect_response_content_empty_tool_calls(target: OpenAIChatTarget): + """Test _detect_response_content when tool_calls is empty list.""" + mock_message = MagicMock() + mock_message.content = "Hello" + mock_message.audio = None + mock_message.tool_calls = [] + + has_content, has_audio, has_tool_calls = target._detect_response_content(mock_message) + + assert has_content + assert not has_audio + assert not has_tool_calls # Falsy - 
empty list + + +@pytest.mark.asyncio +async def test_construct_message_from_response_with_tool_calls( + target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece +): + """Test _construct_message_from_response extracts tool calls correctly.""" + tool_call = create_mock_tool_call("call_abc123", "get_current_weather", '{"location": "Seattle, WA"}') + mock_response = create_mock_completion_with_tool_calls([tool_call]) + + result = await target._construct_message_from_response(mock_response, dummy_text_message_piece) + + assert isinstance(result, Message) + assert len(result.message_pieces) == 1 + + piece = result.message_pieces[0] + assert piece.converted_value_data_type == "function_call" + + # Verify the serialized tool call data + tool_call_data = json.loads(piece.converted_value) + assert tool_call_data["type"] == "function" + assert tool_call_data["id"] == "call_abc123" + assert tool_call_data["function"]["name"] == "get_current_weather" + assert tool_call_data["function"]["arguments"] == '{"location": "Seattle, WA"}' + + +@pytest.mark.asyncio +async def test_construct_message_from_response_with_multiple_tool_calls( + target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece +): + """Test _construct_message_from_response handles multiple tool calls.""" + tool_call1 = create_mock_tool_call("call_1", "get_weather", '{"location": "NYC"}') + tool_call2 = create_mock_tool_call("call_2", "get_time", '{"timezone": "EST"}') + mock_response = create_mock_completion_with_tool_calls([tool_call1, tool_call2]) + + result = await target._construct_message_from_response(mock_response, dummy_text_message_piece) + + assert isinstance(result, Message) + assert len(result.message_pieces) == 2 + + # Verify both tool calls are present + tool_call_data_1 = json.loads(result.message_pieces[0].converted_value) + tool_call_data_2 = json.loads(result.message_pieces[1].converted_value) + + assert tool_call_data_1["function"]["name"] == "get_weather" + assert tool_call_data_2["function"]["name"] == "get_time" + + +@pytest.mark.asyncio +async def test_send_prompt_with_tool_calls(target: OpenAIChatTarget): + """Test send_prompt_async correctly handles tool call responses.""" + tool_call = create_mock_tool_call("call_test", "search_web", '{"query": "PyRIT documentation"}') + mock_response = create_mock_completion_with_tool_calls([tool_call]) + + target._async_client.chat.completions.create = AsyncMock(return_value=mock_response) # type: ignore[method-assign] + + message = Message( + message_pieces=[ + MessagePiece( + role="user", + conversation_id="test-conv", + original_value="Search for PyRIT documentation", + converted_value="Search for PyRIT documentation", + original_value_data_type="text", + converted_value_data_type="text", + ) + ] + ) + + result = await target.send_prompt_async(message=message) + + assert len(result) == 1 + assert len(result[0].message_pieces) == 1 + assert result[0].message_pieces[0].converted_value_data_type == "function_call" + + tool_call_data = json.loads(result[0].message_pieces[0].converted_value) + assert tool_call_data["function"]["name"] == "search_web" + + +def test_construct_request_body_with_tools(patch_central_database): + """Test that tools are included in request body when specified in extra_body_parameters.""" + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + + target = OpenAIChatTarget( + model_name="gpt-4", + 
endpoint="https://mock.azure.com/", + api_key="mock-api-key", + extra_body_parameters={"tools": tools, "tool_choice": "auto"}, + ) + + assert target._extra_body_parameters.get("tools") == tools + assert target._extra_body_parameters.get("tool_choice") == "auto" diff --git a/tests/unit/target/test_prompt_target.py b/tests/unit/target/test_prompt_target.py index 3034257ed..e258d2e0c 100644 --- a/tests/unit/target/test_prompt_target.py +++ b/tests/unit/target/test_prompt_target.py @@ -91,6 +91,8 @@ async def test_send_prompt_with_system_calls_chat_complete( mock_choice.finish_reason = "stop" mock_message = MagicMock() mock_message.content = "hi" + mock_message.audio = None # Explicitly set to avoid MagicMock auto-creation + mock_message.tool_calls = None mock_choice.message = mock_message mock_response.choices = [mock_choice] @@ -129,6 +131,8 @@ async def test_send_prompt_async_with_delay( mock_choice.finish_reason = "stop" mock_message = MagicMock() mock_message.content = "hi" + mock_message.audio = None # Explicitly set to avoid MagicMock auto-creation + mock_message.tool_calls = None mock_choice.message = mock_message mock_response.choices = [mock_choice]
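
Usage sketch (illustrative only, not part of the patch): a minimal example of wiring the new OpenAICompletionsAudioConfig into OpenAIChatTarget, mirroring the integration test above. The endpoint, API key, and model name are placeholders, and PyRIT's central memory is assumed to be initialized before the target is used.

import asyncio
import uuid

from pyrit.models import MessagePiece
from pyrit.prompt_target import OpenAIChatTarget, OpenAICompletionsAudioConfig

# Request spoken output; the target merges this into the request body as
# {"modalities": ["text", "audio"], "audio": {"voice": "alloy", "format": "wav"}}.
audio_config = OpenAICompletionsAudioConfig(voice="alloy", audio_format="wav")

target = OpenAIChatTarget(
    endpoint="https://api.openai.com/v1",  # placeholder
    api_key="sk-...",  # placeholder
    model_name="gpt-4o-audio-preview",  # any audio-capable chat model
    audio_response_config=audio_config,
)


async def main():
    piece = MessagePiece(
        role="user",
        original_value="Hello! What's your name?",
        original_value_data_type="text",
        conversation_id=str(uuid.uuid4()),
    )
    result = await target.send_prompt_async(message=piece.to_message())
    # The response carries a text piece (the transcript) plus an audio_path piece
    # pointing at the saved audio file.
    for p in result[0].message_pieces:
        print(p.converted_value_data_type, p.converted_value)


asyncio.run(main())

As built in _construct_message_from_response, the assistant reply is returned as a transcript text piece plus an audio_path piece referencing the saved file; pcm16 output is wrapped in a WAV container before saving, while wav, mp3, flac, and opus are written as-is.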