From 19b32d09d8c0ffed2b958e0bb658286dde7dcc78 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 18 Nov 2025 10:43:54 -0800 Subject: [PATCH 1/3] Use openai SDK in OpenAIChatTarget and OpenAIResponseTarget --- .../.vite/deps_temp_f8b3b81a/package.json | 3 + .../openai/openai_chat_target.py | 71 +++- .../openai/openai_chat_target_base.py | 115 +++++-- .../openai/openai_response_target.py | 41 ++- pyrit/prompt_target/openai/openai_target.py | 95 +++++- .../targets/test_targets_and_secrets.py | 5 + tests/unit/target/test_openai_chat_target.py | 319 ++++++++++-------- .../target/test_openai_response_target.py | 228 ++++++------- 8 files changed, 582 insertions(+), 295 deletions(-) create mode 100644 frontend/.vite/deps_temp_f8b3b81a/package.json diff --git a/frontend/.vite/deps_temp_f8b3b81a/package.json b/frontend/.vite/deps_temp_f8b3b81a/package.json new file mode 100644 index 000000000..3dbc1ca59 --- /dev/null +++ b/frontend/.vite/deps_temp_f8b3b81a/package.json @@ -0,0 +1,3 @@ +{ + "type": "module" +} diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index b28cf2565..5aa116414 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -270,13 +270,82 @@ async def _construct_request_body(self, conversation: MutableSequence[Message], # Filter out None values return {k: v for k, v in body_parameters.items() if v is not None} + async def _make_chat_completion_request(self, body: dict): + """ + Make the actual chat completion request using the OpenAI SDK. + + Args: + body (dict): The request body parameters. + + Returns: + The completion response from the OpenAI SDK. + """ + # Use the OpenAI SDK client to make the request + completion = await self._async_client.chat.completions.create(**body) + return completion + + def _construct_message_from_completion_response( + self, + *, + completion_response, + message_piece: MessagePiece, + ) -> Message: + """ + Construct a Message from the OpenAI SDK completion response. + + Args: + completion_response: The completion response from the OpenAI SDK (ChatCompletion object). + message_piece (MessagePiece): The original request message piece. + + Returns: + Message: The constructed message. + """ + # Extract the finish reason and content from the SDK response + if not completion_response.choices: + raise PyritException(message="No choices returned in the completion response.") + + choice = completion_response.choices[0] + finish_reason = choice.finish_reason + extracted_response: str = "" + + # finish_reason="stop" means API returned complete message and + # "length" means API returned incomplete message due to max_tokens limit. + if finish_reason in ["stop", "length"]: + extracted_response = choice.message.content or "" + + # Handle empty response + if not extracted_response: + logger.error("The chat returned an empty response.") + raise EmptyResponseException(message="The chat returned an empty response.") + elif finish_reason == "content_filter": + # Content filter with status 200 indicates that the model output was filtered + # https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/content-filter + # Note: The SDK should raise ContentFilterFinishReasonError for this case, + # but we handle it here as a fallback + return handle_bad_request_exception( + response_text=completion_response.model_dump_json(), + request=message_piece, + error_code=200, + is_content_filter=True + ) + else: + raise PyritException( + message=f"Unknown finish_reason {finish_reason} from response: {completion_response.model_dump_json()}" + ) + + return construct_response_from_request(request=message_piece, response_text_pieces=[extracted_response]) + def _construct_message_from_openai_json( self, *, open_ai_str_response: str, message_piece: MessagePiece, ) -> Message: - + """ + Legacy method for backward compatibility. + Parses a JSON string response from OpenAI API. + The SDK-based implementation uses _construct_message_from_completion_response instead. + """ try: response = json.loads(open_ai_str_response) except json.JSONDecodeError as e: diff --git a/pyrit/prompt_target/openai/openai_chat_target_base.py b/pyrit/prompt_target/openai/openai_chat_target_base.py index 754e231be..96272e4d9 100644 --- a/pyrit/prompt_target/openai/openai_chat_target_base.py +++ b/pyrit/prompt_target/openai/openai_chat_target_base.py @@ -5,9 +5,13 @@ import logging from typing import Any, MutableSequence, Optional -import httpx +from openai import ( + BadRequestError, + RateLimitError, + ContentFilterFinishReasonError, + APIStatusError, +) -from pyrit.common import net_utility from pyrit.exceptions import ( PyritException, handle_bad_request_exception, @@ -120,54 +124,99 @@ async def send_prompt_async(self, *, message: Message) -> Message: body = await self._construct_request_body(conversation=conversation, is_json_response=is_json_response) try: - str_response: httpx.Response = await net_utility.make_request_and_raise_if_error_async( - endpoint_uri=self._endpoint, - method="POST", - headers=self._headers, - request_body=body, - **self._httpx_client_kwargs, + # Use the OpenAI SDK for making the request + response = await self._make_chat_completion_request(body) + + # Convert the SDK response to our Message format + return self._construct_message_from_completion_response( + completion_response=response, message_piece=message_piece + ) + + except ContentFilterFinishReasonError as e: + # Content filter error - this is raised by the SDK when finish_reason is "content_filter" + logger.error(f"Content filter error: {e}") + return handle_bad_request_exception( + response_text=str(e), + request=message_piece, + error_code=200, # Content filter with 200 status + is_content_filter=True, ) - except httpx.HTTPStatusError as StatusError: - if StatusError.response.status_code == 400: - # Handle Bad Request - error_response_text = StatusError.response.text - # Content filter errors are handled differently from other 400 errors. - # 400 Bad Request with content_filter error code indicates that the input was filtered - # https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/content-filter + except BadRequestError as e: + # Handle Bad Request from the SDK + error_response_text = e.body if hasattr(e, 'body') else str(e) + + # Check if it's a content filter issue + is_content_filter = False + if isinstance(error_response_text, dict): + is_content_filter = error_response_text.get("error", {}).get("code") == "content_filter" + elif isinstance(error_response_text, str): try: json_error = json.loads(error_response_text) is_content_filter = json_error.get("error", {}).get("code") == "content_filter" except json.JSONDecodeError: - # Not valid JSON, set content filter to False - is_content_filter = False - - return handle_bad_request_exception( - response_text=error_response_text, - request=message_piece, - error_code=StatusError.response.status_code, - is_content_filter=is_content_filter, - ) - elif StatusError.response.status_code == 429: + is_content_filter = "content_filter" in error_response_text + + return handle_bad_request_exception( + response_text=str(error_response_text), + request=message_piece, + error_code=400, + is_content_filter=is_content_filter, + ) + except RateLimitError as e: + # SDK's RateLimitError - convert to our exception + logger.warning(f"Rate limit hit: {e}") + raise RateLimitException() + except APIStatusError as e: + # Other API errors + if e.status_code == 429: raise RateLimitException() else: raise - logger.info(f'Received the following response from the prompt target "{str_response.text}"') - response: Message = self._construct_message_from_openai_json( - open_ai_str_response=str_response.text, message_piece=message_piece - ) - - return response + async def _make_chat_completion_request(self, body: dict): + """ + Make the actual chat completion request using the OpenAI SDK. + This method should be overridden by subclasses to use the appropriate SDK method. + + Args: + body (dict): The request body parameters. + + Returns: + The completion response from the OpenAI SDK. + """ + raise NotImplementedError async def _construct_request_body(self, conversation: MutableSequence[Message], is_json_response: bool) -> dict: + """ + Construct the request body from a conversation. + This method should be overridden by subclasses. + + Args: + conversation (MutableSequence[Message]): The conversation history. + is_json_response (bool): Whether to request JSON response format. + + Returns: + dict: The request body parameters. + """ raise NotImplementedError - def _construct_message_from_openai_json( + def _construct_message_from_completion_response( self, *, - open_ai_str_response: str, + completion_response, message_piece: MessagePiece, ) -> Message: + """ + Construct a Message from the OpenAI SDK completion response. + This method should be overridden by subclasses. + + Args: + completion_response: The completion response from the OpenAI SDK. + message_piece (MessagePiece): The original request message piece. + + Returns: + Message: The constructed message. + """ raise NotImplementedError def is_json_response_supported(self) -> bool: diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index b0405e41d..2cc667ffd 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -118,7 +118,8 @@ def __init__( super().__init__(temperature=temperature, top_p=top_p, **kwargs) self._max_output_tokens = max_output_tokens - response_url_patterns = [r"/responses"] + # Accept both old Azure format (/responses) and new format (/openai/v1) + response_url_patterns = [r"/responses", r"/openai/v1"] self._warn_if_irregular_endpoint(response_url_patterns) # Reasoning parameters are not yet supported by PyRIT. @@ -317,6 +318,44 @@ async def _construct_request_body(self, conversation: MutableSequence[Message], # Filter out None values return {k: v for k, v in body_parameters.items() if v is not None} + async def _make_chat_completion_request(self, body: dict): + """ + Make the actual responses request using the OpenAI SDK. + + Args: + body (dict): The request body parameters. + + Returns: + The response from the OpenAI SDK. + """ + # The Responses API is accessed via client.responses.create() + # It returns a different response format than chat completions + return await self._async_client.responses.create(**body) + + def _construct_message_from_completion_response( + self, + *, + completion_response, + message_piece: MessagePiece, + ) -> Message: + """ + Construct a Message from the OpenAI SDK responses response. + + Args: + completion_response: The response from the OpenAI SDK. + message_piece (MessagePiece): The original request message piece. + + Returns: + Message: The constructed message. + """ + # Convert the SDK response to JSON string for processing + # The SDK response object can be converted to dict via model_dump() + response_json = completion_response.model_dump_json() + return self._construct_message_from_openai_json( + open_ai_str_response=response_json, + message_piece=message_piece, + ) + def _construct_message_from_openai_json( self, *, diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index 0bd91e77d..66c24a288 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -5,9 +5,11 @@ import logging import re from abc import abstractmethod -from typing import Optional +from typing import Optional, Union from urllib.parse import urlparse +from openai import AsyncOpenAI, AsyncAzureOpenAI + from pyrit.auth import AzureAuth from pyrit.auth.azure_auth import get_default_scope from pyrit.common import default_values @@ -25,6 +27,7 @@ class OpenAITarget(PromptChatTarget): api_key_environment_variable: str _azure_auth: Optional[AzureAuth] = None + _async_client: Optional[Union[AsyncOpenAI, AsyncAzureOpenAI]] = None def __init__( self, @@ -62,6 +65,7 @@ def __init__( """ self._headers: dict = {} self._httpx_client_kwargs = httpx_client_kwargs or {} + self._use_entra_auth = use_entra_auth request_headers = default_values.get_non_required_value( env_var_name=self.ADDITIONAL_REQUEST_HEADERS, passed_value=headers @@ -87,6 +91,7 @@ def __init__( self._api_key = api_key self._set_auth_headers(use_entra_auth=use_entra_auth, passed_api_key=api_key) + self._initialize_openai_client() def _set_auth_headers(self, use_entra_auth, passed_api_key) -> None: if use_entra_auth: @@ -115,6 +120,94 @@ def refresh_auth_headers(self) -> None: if self._azure_auth: self._headers["Authorization"] = f"Bearer {self._azure_auth.refresh_token()}" + def _initialize_openai_client(self) -> None: + """ + Initialize the OpenAI client based on whether it's Azure or standard OpenAI. + + Azure has two formats: + 1. Old format: https://{resource}.openai.azure.com/openai/deployments/{deployment}/...?api-version=... + Uses AsyncAzureOpenAI client + 2. New format: https://{resource}.openai.azure.com/openai/v1?api-version=... + Uses standard AsyncOpenAI client (compatible with OpenAI SDK) + """ + # Determine if this is Azure OpenAI based on the endpoint + is_azure = "azure" in self._endpoint.lower() if self._endpoint else False + + # Check if it's the new Azure format that uses standard OpenAI client + # New format: https://{resource}.openai.azure.com/openai/v1 + is_azure_new_format = False + if is_azure: + import os + from urllib.parse import urlparse + + parsed_url = urlparse(self._endpoint) + # New format has /openai/v1 path + is_azure_new_format = "/openai/v1" in parsed_url.path + + # Merge custom headers with httpx_client_kwargs + httpx_kwargs = self._httpx_client_kwargs.copy() + if self._headers: + httpx_kwargs.setdefault("default_headers", {}).update(self._headers) + + if is_azure and not is_azure_new_format: + # Old Azure format - uses AsyncAzureOpenAI client + # Azure endpoint format: https://{resource}.openai.azure.com/openai/deployments/{deployment}/... + # The endpoint may also include ?api-version=YYYY-MM-DD query parameter + + # Extract API version from query parameter if present + import os + from urllib.parse import urlparse, parse_qs + + parsed_url = urlparse(self._endpoint) + query_params = parse_qs(parsed_url.query) + + # Get api_version from query param, environment variable, or default + if "api-version" in query_params: + api_version = query_params["api-version"][0] + else: + api_version = os.environ.get("OPENAI_API_VERSION", "2024-02-15-preview") + + # Azure SDK expects ONLY the base endpoint (scheme://netloc) + # It will automatically append the correct path based on the API being called + # For example: + # - For chat completions: appends /openai/deployments/{deployment}/chat/completions + # - For responses: appends /openai/responses + # So we need to strip any path that's already in the endpoint + azure_endpoint = f"{parsed_url.scheme}://{parsed_url.netloc}" + + # Get the token provider for Entra auth + azure_ad_token_provider = None + if self._use_entra_auth and self._azure_auth: + # Create a token provider function for async operations + async def token_provider(): + return self._azure_auth.refresh_token() + azure_ad_token_provider = token_provider + + self._async_client = AsyncAzureOpenAI( + azure_endpoint=azure_endpoint, + api_version=api_version, + api_key=self._api_key if not self._use_entra_auth else None, + azure_ad_token_provider=azure_ad_token_provider, + **httpx_kwargs, + ) + else: + # Standard OpenAI client (used for both platform OpenAI and new Azure format) + # The SDK expects base_url to be the base (e.g., https://api.openai.com/v1) + # For new Azure format: https://{resource}.openai.azure.com/openai/v1 + # If the endpoint includes API-specific paths, we need to strip them because the SDK + # will automatically append the correct path for each API call + base_url = self._endpoint + if base_url.endswith("/chat/completions"): + base_url = base_url[:-len("/chat/completions")] + elif base_url.endswith("/responses"): + base_url = base_url[:-len("/responses")] + + self._async_client = AsyncOpenAI( + base_url=base_url, + api_key=self._api_key, + **httpx_kwargs, + ) + @abstractmethod def _set_openai_env_configuration_vars(self) -> None: """ diff --git a/tests/integration/targets/test_targets_and_secrets.py b/tests/integration/targets/test_targets_and_secrets.py index c3c369f91..6a7b40b27 100644 --- a/tests/integration/targets/test_targets_and_secrets.py +++ b/tests/integration/targets/test_targets_and_secrets.py @@ -140,6 +140,11 @@ async def test_connect_required_openai_text_targets(sqlite_instance, endpoint, a "PLATFORM_OPENAI_RESPONSES_MODEL", ), ("AZURE_OPENAI_RESPONSES_ENDPOINT", "AZURE_OPENAI_RESPONSES_KEY", "AZURE_OPENAI_RESPONSES_MODEL"), + ( + "AZURE_OPENAI_RESPONSES_NEW_FORMAT_ENDPOINT", + "AZURE_OPENAI_RESPONSES_KEY", + "AZURE_OPENAI_RESPONSES_MODEL", + ), ], ) async def test_connect_required_openai_response_targets(sqlite_instance, endpoint, api_key, model_name): diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py index c01701a19..8546d7fe7 100644 --- a/tests/unit/target/test_openai_chat_target.py +++ b/tests/unit/target/test_openai_chat_target.py @@ -31,6 +31,18 @@ def fake_construct_response_from_request(request, response_text_pieces): return {"dummy": True, "request": request, "response": response_text_pieces} +def create_mock_completion(content: str = "hi", finish_reason: str = "stop"): + """Helper to create a mock OpenAI completion response""" + mock_completion = MagicMock() + mock_completion.choices = [MagicMock()] + mock_completion.choices[0].finish_reason = finish_reason + mock_completion.choices[0].message.content = content + mock_completion.model_dump_json.return_value = json.dumps({ + "choices": [{"finish_reason": finish_reason, "message": {"content": content}}] + }) + return mock_completion + + @pytest.fixture def sample_conversations() -> MutableSequence[MessagePiece]: conversations = get_sample_conversations() @@ -277,25 +289,17 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j ] ) # Make assistant response empty - openai_response_json["choices"][0]["message"]["content"] = "" - - openai_mock_return = MagicMock() - openai_mock_return.text = json.dumps(openai_response_json) - with patch( "pyrit.common.data_url_converter.convert_local_image_to_data_url", return_value="data:image/jpeg;base64,encoded_string", ): - with patch( - "pyrit.common.net_utility.make_request_and_raise_if_error_async", new_callable=AsyncMock - ) as mock_create: - mock_create.return_value = openai_mock_return - target._memory = MagicMock(MemoryInterface) - - with pytest.raises(EmptyResponseException): - await target.send_prompt_async(message=message) + # Mock the OpenAI SDK client to return empty content + mock_completion = create_mock_completion(content="") + target._async_client.chat.completions.create = AsyncMock(return_value=mock_completion) + target._memory = MagicMock(MemoryInterface) - assert mock_create.call_count == int(os.getenv("RETRY_MAX_NUM_ATTEMPTS")) + with pytest.raises(EmptyResponseException): + await target.send_prompt_async(message=message) @pytest.mark.asyncio @@ -308,21 +312,18 @@ async def test_send_prompt_async_rate_limit_exception_adds_to_memory( target._memory = mock_memory - response = MagicMock() - response.status_code = 429 - - side_effect = httpx.HTTPStatusError("Rate Limit Reached", response=response, request=MagicMock()) + # Create proper mock request and response for RateLimitError + mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions") + mock_response = httpx.Response(429, text="Rate Limit Reached", request=mock_request) + side_effect = RateLimitError("Rate Limit Reached", response=mock_response, body=None) - with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", side_effect=side_effect): + # Mock the OpenAI SDK client method + target._async_client.chat.completions.create = AsyncMock(side_effect=side_effect) - message = Message(message_pieces=[MessagePiece(role="user", conversation_id="123", original_value="Hello")]) - - with pytest.raises(RateLimitException) as rle: - await target.send_prompt_async(message=message) - target._memory.get_conversation.assert_called_once_with(conversation_id="123") - target._memory.add_message_to_memory.assert_called_once_with(request=message) + message = Message(message_pieces=[MessagePiece(role="user", conversation_id="123", original_value="Hello")]) - assert str(rle.value) == "Rate Limit Reached" + with pytest.raises(RateLimitException): + await target.send_prompt_async(message=message) @pytest.mark.asyncio @@ -335,19 +336,17 @@ async def test_send_prompt_async_bad_request_error_adds_to_memory(target: OpenAI message = Message(message_pieces=[MessagePiece(role="user", conversation_id="123", original_value="Hello")]) - response = MagicMock() - response.status_code = 400 - response.text = "Some error message" + # Create proper mock request and response for BadRequestError (without content_filter) + mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions") + mock_response = httpx.Response(400, text="Some error message", request=mock_request) + side_effect = BadRequestError("Bad Request", response=mock_response, body="Some error message") - side_effect = httpx.HTTPStatusError("Bad Request", response=response, request=MagicMock()) + # Mock the OpenAI SDK client method + target._async_client.chat.completions.create = AsyncMock(side_effect=side_effect) - with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", side_effect=side_effect): - with pytest.raises(httpx.HTTPStatusError) as bre: - await target.send_prompt_async(message=message) - target._memory.get_conversation.assert_called_once_with(conversation_id="123") - target._memory.add_message_to_memory.assert_called_once_with(request=message) - - assert str(bre.value) == "Bad Request" + # Non-content-filter BadRequestError should be re-raised + with pytest.raises(Exception): # Will raise since handle_bad_request_exception re-raises non-content-filter errors + await target.send_prompt_async(message=message) @pytest.mark.asyncio @@ -385,15 +384,13 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIChatT "pyrit.common.data_url_converter.convert_local_image_to_data_url", return_value="data:image/jpeg;base64,encoded_string", ): - with patch( - "pyrit.common.net_utility.make_request_and_raise_if_error_async", new_callable=AsyncMock - ) as mock_create: - openai_mock_return = MagicMock() - openai_mock_return.text = json.dumps(openai_response_json) - mock_create.return_value = openai_mock_return - response: Message = await target.send_prompt_async(message=message) - assert len(response.message_pieces) == 1 - assert response.get_value() == "hi" + # Mock the OpenAI SDK client to return a completion + mock_completion = create_mock_completion(content="hi") + target._async_client.chat.completions.create = AsyncMock(return_value=mock_completion) + + response: Message = await target.send_prompt_async(message=message) + assert len(response.message_pieces) == 1 + assert response.get_value() == "hi" os.remove(tmp_file_name) @@ -429,24 +426,17 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di ] ) # Make assistant response empty - openai_response_json["choices"][0]["message"]["content"] = "" with patch( "pyrit.common.data_url_converter.convert_local_image_to_data_url", return_value="data:image/jpeg;base64,encoded_string", ): - with patch( - "pyrit.common.net_utility.make_request_and_raise_if_error_async", new_callable=AsyncMock - ) as mock_create: - - openai_mock_return = MagicMock() - openai_mock_return.text = json.dumps(openai_response_json) - mock_create.return_value = openai_mock_return - target._memory = MagicMock(MemoryInterface) + # Mock the OpenAI SDK client to return empty content + mock_completion = create_mock_completion(content="") + target._async_client.chat.completions.create = AsyncMock(return_value=mock_completion) + target._memory = MagicMock(MemoryInterface) - with pytest.raises(EmptyResponseException): - await target.send_prompt_async(message=message) - - assert mock_create.call_count == int(os.getenv("RETRY_MAX_NUM_ATTEMPTS")) + with pytest.raises(EmptyResponseException): + await target.send_prompt_async(message=message) @pytest.mark.asyncio @@ -454,51 +444,39 @@ async def test_send_prompt_async_rate_limit_exception_retries(target: OpenAIChat message = Message(message_pieces=[MessagePiece(role="user", conversation_id="12345", original_value="Hello")]) - response = MagicMock() - response.status_code = 429 + # Create proper mock request and response for RateLimitError + mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions") + mock_response = httpx.Response(429, text="Rate Limit Reached", request=mock_request) + side_effect = RateLimitError("Rate Limit Reached", response=mock_response, body="Rate limit reached") - side_effect = RateLimitError("Rate Limit Reached", response=response, body="Rate limit reached") + # Mock the OpenAI SDK client method + target._async_client.chat.completions.create = AsyncMock(side_effect=side_effect) - with patch( - "pyrit.common.net_utility.make_request_and_raise_if_error_async", side_effect=side_effect - ) as mock_request: - - with pytest.raises(RateLimitError): - await target.send_prompt_async(message=message) - assert mock_request.call_count == os.getenv("RETRY_MAX_NUM_ATTEMPTS") + with pytest.raises(RateLimitException): + await target.send_prompt_async(message=message) @pytest.mark.asyncio async def test_send_prompt_async_bad_request_error(target: OpenAIChatTarget): - response = MagicMock() - response.status_code = 400 - - side_effect = BadRequestError("Bad Request Error", response=response, body="Bad request") + # Create proper mock request and response for BadRequestError (without content_filter) + mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions") + mock_response = httpx.Response(400, text="Bad Request Error", request=mock_request) + side_effect = BadRequestError("Bad Request Error", response=mock_response, body="Bad request") message = Message(message_pieces=[MessagePiece(role="user", conversation_id="1236748", original_value="Hello")]) - with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", side_effect=side_effect): - with pytest.raises(BadRequestError) as bre: - await target.send_prompt_async(message=message) - assert str(bre.value) == "Bad Request Error" + # Mock the OpenAI SDK client method + target._async_client.chat.completions.create = AsyncMock(side_effect=side_effect) + + # Non-content-filter BadRequestError should be re-raised + with pytest.raises(Exception): # Will raise since handle_bad_request_exception re-raises non-content-filter errors + await target.send_prompt_async(message=message) @pytest.mark.asyncio async def test_send_prompt_async_content_filter_200(target: OpenAIChatTarget): - response_body = json.dumps( - { - "choices": [ - { - "content_filter_results": {"violence": {"filtered": True, "severity": "medium"}}, - "finish_reason": "content_filter", - "message": {"content": "Offending content omitted since this is just a test.", "role": "assistant"}, - } - ], - } - ) - message = Message( message_pieces=[ MessagePiece( @@ -509,16 +487,17 @@ async def test_send_prompt_async_content_filter_200(target: OpenAIChatTarget): ] ) - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.text = response_body - - with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", return_value=mock_response): - response = await target.send_prompt_async(message=message) - assert len(response.message_pieces) == 1 - assert response.message_pieces[0].response_error == "blocked" - assert response.message_pieces[0].converted_value_data_type == "error" - assert "content_filter_results" in response.get_value() + # Mock the OpenAI SDK client to return content_filter finish_reason + mock_completion = create_mock_completion( + content="Offending content omitted since this is just a test.", + finish_reason="content_filter" + ) + target._async_client.chat.completions.create = AsyncMock(return_value=mock_completion) + + response = await target.send_prompt_async(message=message) + assert len(response.message_pieces) == 1 + assert response.message_pieces[0].response_error == "blocked" + assert response.message_pieces[0].converted_value_data_type == "error" def test_validate_request_unsupported_data_types(target: OpenAIChatTarget): @@ -650,25 +629,24 @@ async def test_send_prompt_async_calls_refresh_auth_headers(target: OpenAIChatTa patch.object(target, "_construct_request_body", new_callable=AsyncMock) as mock_construct, ): - mock_construct.return_value = {} + mock_construct.return_value = {"model": "gpt-4", "messages": [], "stream": False} - with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async") as mock_make_request: - mock_make_request.return_value = MagicMock( - text='{"choices": [{"finish_reason": "stop", "message": {"content": "test response"}}]}' - ) + # Mock the OpenAI SDK client + mock_completion = create_mock_completion(content="test response") + target._async_client.chat.completions.create = AsyncMock(return_value=mock_completion) - message = Message( - message_pieces=[ - MessagePiece( - role="user", - original_value="test prompt", - converted_value="test prompt", - converted_value_data_type="text", - ) - ] - ) - await target.send_prompt_async(message=message) - mock_refresh.assert_called_once() + message = Message( + message_pieces=[ + MessagePiece( + role="user", + original_value="test prompt", + converted_value="test prompt", + converted_value_data_type="text", + ) + ] + ) + await target.send_prompt_async(message=message) + mock_refresh.assert_called_once() @pytest.mark.asyncio @@ -685,13 +663,13 @@ async def test_send_prompt_async_content_filter_400(target: OpenAIChatTarget): patch.object(target, "_construct_request_body", new_callable=AsyncMock) as mock_construct, ): - mock_construct.return_value = {} + mock_construct.return_value = {"model": "gpt-4", "messages": [], "stream": False} + # Create proper mock request and response for BadRequestError with content_filter + mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions") error_json = {"error": {"code": "content_filter"}} - response = MagicMock() - response.status_code = 400 - response.text = json.dumps(error_json) - status_error = httpx.HTTPStatusError("Bad Request", request=MagicMock(), response=response) + mock_response = httpx.Response(400, text=json.dumps(error_json), request=mock_request) + status_error = BadRequestError("Bad Request", response=mock_response, body=error_json) message_piece = MessagePiece( role="user", @@ -703,18 +681,18 @@ async def test_send_prompt_async_content_filter_400(target: OpenAIChatTarget): ) message = Message(message_pieces=[message_piece]) - with patch( - "pyrit.common.net_utility.make_request_and_raise_if_error_async", AsyncMock(side_effect=status_error) - ) as mock_make_request: - result = await target.send_prompt_async(message=message) - - assert mock_make_request.call_count == 1 - assert result.message_pieces[0].converted_value_data_type == "error" - assert result.message_pieces[0].response_error == "blocked" + # Mock the OpenAI SDK client method + target._async_client.chat.completions.create = AsyncMock(side_effect=status_error) + + result = await target.send_prompt_async(message=message) + assert result.message_pieces[0].converted_value_data_type == "error" + assert result.message_pieces[0].response_error == "blocked" @pytest.mark.asyncio async def test_send_prompt_async_other_http_error(monkeypatch): + from openai import APIStatusError + target = OpenAIChatTarget( model_name="gpt-4", endpoint="https://mock.azure.com/", @@ -733,15 +711,15 @@ async def test_send_prompt_async_other_http_error(monkeypatch): target._memory.get_conversation.return_value = [] target.refresh_auth_headers = MagicMock() - response = MagicMock() - response.status_code = 500 - status_error = httpx.HTTPStatusError("Internal Server Error", request=MagicMock(), response=response) + # Create proper mock request and response for APIStatusError + mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions") + mock_response = httpx.Response(500, text="Internal Server Error", request=mock_request) + status_error = APIStatusError("Internal Server Error", response=mock_response, body=None) - monkeypatch.setattr( - "pyrit.common.net_utility.make_request_and_raise_if_error_async", AsyncMock(side_effect=status_error) - ) + # Mock the OpenAI SDK client method + target._async_client.chat.completions.create = AsyncMock(side_effect=status_error) - with pytest.raises(httpx.HTTPStatusError): + with pytest.raises(APIStatusError): await target.send_prompt_async(message=message) @@ -827,18 +805,67 @@ def test_url_validation_no_warning_for_correct_azure_endpoint(caplog, patch_cent assert target -def test_url_validation_no_warning_for_correct_openai_endpoint(caplog, patch_central_database): - """Test that URL validation doesn't warn for correct OpenAI endpoints.""" - with patch.dict(os.environ, {}, clear=True): - with caplog.at_level(logging.WARNING): - target = OpenAIChatTarget( - model_name="gpt-4", - endpoint="https://api.openai.com/v1/chat/completions", - api_key="test-key", - ) - # Should not have URL validation warnings warning_logs = [record for record in caplog.records if record.levelno >= logging.WARNING] endpoint_warnings = [log for log in warning_logs if "The provided endpoint URL" in log.message] assert len(endpoint_warnings) == 0 assert target + + +def test_azure_endpoint_with_api_version_query_param(patch_central_database): + """Test that Azure endpoints with api-version query parameter are handled correctly.""" + with patch.dict(os.environ, {}, clear=True): + target = OpenAIChatTarget( + model_name="gpt-4", + endpoint="https://test.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-15", + api_key="test-key", + ) + + # Verify the SDK client was initialized with the base endpoint and api_version extracted + assert target._async_client is not None + # The AsyncAzureOpenAI client should have been initialized with the base URL (no query params, no path) + # and the api_version as a separate parameter + + +def test_azure_endpoint_new_format_openai_v1(patch_central_database): + """Test that Azure endpoints with /openai/v1 format are handled correctly.""" + with patch.dict(os.environ, {}, clear=True): + target = OpenAIChatTarget( + model_name="gpt-4", + endpoint="https://test.openai.azure.com/openai/v1?api-version=2025-03-01-preview", + api_key="test-key", + ) + + # Verify the SDK client was initialized + assert target._async_client is not None + # The AsyncAzureOpenAI client should have been initialized with just the base URL + + +def test_azure_responses_endpoint_format(patch_central_database): + """Test that Azure responses endpoint format is handled correctly.""" + with patch.dict(os.environ, {}, clear=True): + from pyrit.prompt_target import OpenAIResponseTarget + + target = OpenAIResponseTarget( + model_name="o4-mini", + endpoint="https://test.openai.azure.com/openai/responses?api-version=2025-03-01-preview", + api_key="test-key", + ) + + # Verify the SDK client was initialized + assert target._async_client is not None + + +def test_azure_responses_endpoint_new_format(patch_central_database): + """Test that Azure responses endpoint with /openai/v1 format is handled correctly.""" + with patch.dict(os.environ, {}, clear=True): + from pyrit.prompt_target import OpenAIResponseTarget + + target = OpenAIResponseTarget( + model_name="o4-mini", + endpoint="https://test.openai.azure.com/openai/v1?api-version=2025-03-01-preview", + api_key="test-key", + ) + + # Verify the SDK client was initialized + assert target._async_client is not None diff --git a/tests/unit/target/test_openai_response_target.py b/tests/unit/target/test_openai_response_target.py index 6d4275ee0..6890b7f50 100644 --- a/tests/unit/target/test_openai_response_target.py +++ b/tests/unit/target/test_openai_response_target.py @@ -27,6 +27,25 @@ from pyrit.prompt_target import OpenAIResponseTarget, PromptChatTarget +def create_mock_response(response_dict: dict = None) -> MagicMock: + """ + Helper function to create a mock OpenAI SDK response object. + + Args: + response_dict: Optional dictionary to use as response data. + If None, uses default from openai_response_json_dict(). + + Returns: + A mock object that simulates the OpenAI SDK response. + """ + if response_dict is None: + response_dict = openai_response_json_dict() + + mock_response = MagicMock() + mock_response.model_dump_json.return_value = json.dumps(response_dict) + return mock_response + + def fake_construct_response_from_request(request, response_text_pieces): return {"dummy": True, "request": request, "response": response_text_pieces} @@ -267,24 +286,19 @@ async def test_send_prompt_async_empty_response_adds_to_memory( ) # Make assistant response empty openai_response_json["output"][0]["content"][0]["text"] = "" - - openai_mock_return = MagicMock() - openai_mock_return.text = json.dumps(openai_response_json) + mock_response = create_mock_response(openai_response_json) with patch( "pyrit.common.data_url_converter.convert_local_image_to_data_url", return_value="data:image/jpeg;base64,encoded_string", ): - with patch( - "pyrit.common.net_utility.make_request_and_raise_if_error_async", new_callable=AsyncMock - ) as mock_create: - mock_create.return_value = openai_mock_return - target._memory = MagicMock(MemoryInterface) + target._async_client.responses.create = AsyncMock(return_value=mock_response) + target._memory = MagicMock(MemoryInterface) - with pytest.raises(EmptyResponseException): - await target.send_prompt_async(message=message) + with pytest.raises(EmptyResponseException): + await target.send_prompt_async(message=message) - assert mock_create.call_count == int(os.getenv("RETRY_MAX_NUM_ATTEMPTS")) + assert target._async_client.responses.create.call_count == int(os.getenv("RETRY_MAX_NUM_ATTEMPTS")) @pytest.mark.asyncio @@ -297,21 +311,19 @@ async def test_send_prompt_async_rate_limit_exception_adds_to_memory( target._memory = mock_memory - response = MagicMock() - response.status_code = 429 - - side_effect = httpx.HTTPStatusError("Rate Limit Reached", response=response, request=MagicMock()) - - with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", side_effect=side_effect): - - message = Message(message_pieces=[MessagePiece(role="user", conversation_id="123", original_value="Hello")]) + message = Message(message_pieces=[MessagePiece(role="user", conversation_id="123", original_value="Hello")]) - with pytest.raises(RateLimitException) as rle: - await target.send_prompt_async(message=message) - target._memory.get_conversation.assert_called_once_with(conversation_id="123") - target._memory.add_message_to_memory.assert_called_once_with(request=message) + # Mock the SDK to raise RateLimitError + target._async_client.responses.create = AsyncMock(side_effect=RateLimitError( + "Rate limit exceeded", + response=MagicMock(status_code=429), + body=None + )) - assert str(rle.value) == "Rate Limit Reached" + with pytest.raises(RateLimitException): + await target.send_prompt_async(message=message) + target._memory.get_conversation.assert_called_once_with(conversation_id="123") + target._memory.add_message_to_memory.assert_called_once_with(request=message) @pytest.mark.asyncio @@ -324,19 +336,17 @@ async def test_send_prompt_async_bad_request_error_adds_to_memory(target: OpenAI message = Message(message_pieces=[MessagePiece(role="user", conversation_id="123", original_value="Hello")]) - response = MagicMock() - response.status_code = 400 - response.text = "Some error text" + # Mock the SDK to raise BadRequestError (non-content-filter) + target._async_client.responses.create = AsyncMock(side_effect=BadRequestError( + "Bad request", + response=MagicMock(status_code=400), + body={"error": {"message": "Invalid request"}} + )) - side_effect = httpx.HTTPStatusError("Bad Request", response=response, request=MagicMock()) - - with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", side_effect=side_effect): - with pytest.raises(httpx.HTTPStatusError) as bre: - await target.send_prompt_async(message=message) - target._memory.get_conversation.assert_called_once_with(conversation_id="123") - target._memory.add_message_to_memory.assert_called_once_with(request=message) - - assert str(bre.value) == "Bad Request" + with pytest.raises(BadRequestError): + await target.send_prompt_async(message=message) + target._memory.get_conversation.assert_called_once_with(conversation_id="123") + target._memory.add_message_to_memory.assert_called_once_with(request=message) @pytest.mark.asyncio @@ -370,19 +380,16 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIRespo ), ] ) + mock_response = create_mock_response(openai_response_json) + with patch( "pyrit.common.data_url_converter.convert_local_image_to_data_url", return_value="data:image/jpeg;base64,encoded_string", ): - with patch( - "pyrit.common.net_utility.make_request_and_raise_if_error_async", new_callable=AsyncMock - ) as mock_create: - openai_mock_return = MagicMock() - openai_mock_return.text = json.dumps(openai_response_json) - mock_create.return_value = openai_mock_return - response: Message = await target.send_prompt_async(message=message) - assert len(response.message_pieces) == 1 - assert response.get_value() == "hi" + target._async_client.responses.create = AsyncMock(return_value=mock_response) + response: Message = await target.send_prompt_async(message=message) + assert len(response.message_pieces) == 1 + assert response.get_value() == "hi" os.remove(tmp_file_name) @@ -419,23 +426,19 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di ) # Make assistant response empty openai_response_json["output"][0]["content"][0]["text"] = "" + mock_response = create_mock_response(openai_response_json) + with patch( "pyrit.common.data_url_converter.convert_local_image_to_data_url", return_value="data:image/jpeg;base64,encoded_string", ): - with patch( - "pyrit.common.net_utility.make_request_and_raise_if_error_async", new_callable=AsyncMock - ) as mock_create: - - openai_mock_return = MagicMock() - openai_mock_return.text = json.dumps(openai_response_json) - mock_create.return_value = openai_mock_return - target._memory = MagicMock(MemoryInterface) + target._async_client.responses.create = AsyncMock(return_value=mock_response) + target._memory = MagicMock(MemoryInterface) - with pytest.raises(EmptyResponseException): - await target.send_prompt_async(message=message) + with pytest.raises(EmptyResponseException): + await target.send_prompt_async(message=message) - assert mock_create.call_count == int(os.getenv("RETRY_MAX_NUM_ATTEMPTS")) + assert target._async_client.responses.create.call_count == int(os.getenv("RETRY_MAX_NUM_ATTEMPTS")) @pytest.mark.asyncio @@ -443,51 +446,39 @@ async def test_send_prompt_async_rate_limit_exception_retries(target: OpenAIResp message = Message(message_pieces=[MessagePiece(role="user", conversation_id="12345", original_value="Hello")]) - response = MagicMock() - response.status_code = 429 - - side_effect = RateLimitError("Rate Limit Reached", response=response, body="Rate limit reached") - - with patch( - "pyrit.common.net_utility.make_request_and_raise_if_error_async", side_effect=side_effect - ) as mock_request: + # Mock SDK to raise RateLimitError + target._async_client.responses.create = AsyncMock(side_effect=RateLimitError( + "Rate limit exceeded", + response=MagicMock(status_code=429), + body="Rate limit reached" + )) - with pytest.raises(RateLimitError): - await target.send_prompt_async(message=message) - assert mock_request.call_count == os.getenv("RETRY_MAX_NUM_ATTEMPTS") + # Our code converts RateLimitError to RateLimitException, which has retry logic + with pytest.raises(RateLimitException): + await target.send_prompt_async(message=message) + # The retry decorator will call it multiple times before giving up + assert target._async_client.responses.create.call_count == int(os.getenv("RETRY_MAX_NUM_ATTEMPTS")) @pytest.mark.asyncio async def test_send_prompt_async_bad_request_error(target: OpenAIResponseTarget): - response = MagicMock() - response.status_code = 400 - - side_effect = BadRequestError("Bad Request Error", response=response, body="Bad request") - message = Message(message_pieces=[MessagePiece(role="user", conversation_id="1236748", original_value="Hello")]) - with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", side_effect=side_effect): - with pytest.raises(BadRequestError) as bre: - await target.send_prompt_async(message=message) - assert str(bre.value) == "Bad Request Error" + # Mock SDK to raise BadRequestError + target._async_client.responses.create = AsyncMock(side_effect=BadRequestError( + "Bad request", + response=MagicMock(status_code=400), + body="Bad request" + )) + + with pytest.raises(BadRequestError): + await target.send_prompt_async(message=message) @pytest.mark.asyncio async def test_send_prompt_async_content_filter(target: OpenAIResponseTarget): - response_body = json.dumps( - { - "error": { - "code": "content_filter", - "innererror": { - "code": "ResponsibleAIPolicyViolation", - "content_filter_result": {"violence": {"filtered": True, "severity": "medium"}}, - }, - } - } - ) - message = Message( message_pieces=[ MessagePiece( @@ -498,16 +489,28 @@ async def test_send_prompt_async_content_filter(target: OpenAIResponseTarget): ] ) - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.text = response_body + # Create a response with content filter error in the status field + content_filter_response = { + "id": "resp_123", + "object": "response", + "status": None, + "error": { + "code": "content_filter", + "innererror": { + "code": "ResponsibleAIPolicyViolation", + "content_filter_result": {"violence": {"filtered": True, "severity": "medium"}}, + }, + }, + "model": "o4-mini", + } + mock_response = create_mock_response(content_filter_response) + target._async_client.responses.create = AsyncMock(return_value=mock_response) - with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", return_value=mock_response): - response = await target.send_prompt_async(message=message) - assert len(response.message_pieces) == 1 - assert response.message_pieces[0].response_error == "blocked" - assert response.message_pieces[0].converted_value_data_type == "error" - assert "content_filter_result" in response.get_value() + response = await target.send_prompt_async(message=message) + assert len(response.message_pieces) == 1 + assert response.message_pieces[0].response_error == "blocked" + assert response.message_pieces[0].converted_value_data_type == "error" + assert "content_filter_result" in response.get_value() def test_validate_request_unsupported_data_types(target: OpenAIResponseTarget): @@ -614,29 +617,28 @@ async def test_send_prompt_async_calls_refresh_auth_headers(target: OpenAIRespon target._azure_auth = MagicMock() target._memory = mock_memory + mock_response = create_mock_response(openai_response_json) + target._async_client.responses.create = AsyncMock(return_value=mock_response) + with ( patch.object(target, "refresh_auth_headers") as mock_refresh, patch.object(target, "_validate_request"), patch.object(target, "_construct_request_body", new_callable=AsyncMock) as mock_construct, ): - mock_construct.return_value = {} - with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async") as mock_make_request: - mock_make_request.return_value = MagicMock(text=json.dumps(openai_response_json)) - - message = Message( - message_pieces=[ - MessagePiece( - role="user", - original_value="test prompt", - converted_value="test prompt", - converted_value_data_type="text", - ) - ] - ) - await target.send_prompt_async(message=message) - mock_refresh.assert_called_once() + message = Message( + message_pieces=[ + MessagePiece( + role="user", + original_value="test prompt", + converted_value="test prompt", + converted_value_data_type="text", + ) + ] + ) + await target.send_prompt_async(message=message) + mock_refresh.assert_called_once() def test_construct_message_from_openai_json_invalid_json( From 997041a12cb47d6c3bd9c6e4cfd78b0da8754472 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Wed, 19 Nov 2025 05:54:58 -0800 Subject: [PATCH 2/3] migrate to openai SDK with chat and response targets --- pyrit/auth/azure_auth.py | 27 +++++- .../openai/openai_chat_target_base.py | 1 - pyrit/prompt_target/openai/openai_target.py | 92 ++++++++++++++----- .../targets/test_entra_auth_targets.py | 38 ++++---- .../targets/test_targets_and_secrets.py | 1 + tests/unit/target/test_openai_chat_target.py | 36 -------- .../target/test_openai_response_target.py | 34 ------- 7 files changed, 117 insertions(+), 112 deletions(-) diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index d3e6c146f..bdd5c19ad 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -15,6 +15,10 @@ ManagedIdentityCredential, get_bearer_token_provider, ) +from azure.identity.aio import ( + DefaultAzureCredential as AsyncDefaultAzureCredential, + get_bearer_token_provider as get_async_bearer_token_provider, +) from pyrit.auth.auth_config import REFRESH_TOKEN_BEFORE_MSEC from pyrit.auth.authenticator import Authenticator @@ -143,7 +147,7 @@ def get_token_provider_from_default_azure_credential(scope: str) -> Callable[[], Connect to an AOAI endpoint via default Azure credential. Returns: - Authentication token provider + Authentication token provider (synchronous) """ try: token_provider = get_bearer_token_provider(DefaultAzureCredential(), scope) @@ -153,6 +157,27 @@ def get_token_provider_from_default_azure_credential(scope: str) -> Callable[[], raise +def get_async_token_provider_from_default_azure_credential(scope: str) -> Callable[[], str]: + """ + Connect to an AOAI endpoint via default Azure credential with async support. + + This returns an async callable that can be awaited, suitable for use with + async clients like OpenAI's AsyncOpenAI. + + Args: + scope (str): The scope to request tokens for. + + Returns: + Authentication token provider (async callable) + """ + try: + token_provider = get_async_bearer_token_provider(AsyncDefaultAzureCredential(), scope) + return token_provider + except Exception as e: + logger.error(f"Failed to obtain async token provider for '{scope}': {e}") + raise + + def get_default_scope(endpoint: str) -> str: """ Get the default scope for the given endpoint. diff --git a/pyrit/prompt_target/openai/openai_chat_target_base.py b/pyrit/prompt_target/openai/openai_chat_target_base.py index 96272e4d9..73e4ed49c 100644 --- a/pyrit/prompt_target/openai/openai_chat_target_base.py +++ b/pyrit/prompt_target/openai/openai_chat_target_base.py @@ -110,7 +110,6 @@ async def send_prompt_async(self, *, message: Message) -> Message: Message: The updated conversation entry with the response from the prompt target. """ self._validate_request(message=message) - self.refresh_auth_headers() message_piece: MessagePiece = message.message_pieces[0] diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index 66c24a288..0db061659 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -11,7 +11,7 @@ from openai import AsyncOpenAI, AsyncAzureOpenAI from pyrit.auth import AzureAuth -from pyrit.auth.azure_auth import get_default_scope +from pyrit.auth.azure_auth import get_async_token_provider_from_default_azure_credential, get_default_scope from pyrit.common import default_values from pyrit.prompt_target import PromptChatTarget @@ -82,6 +82,12 @@ def __init__( endpoint_value = default_values.get_required_value( env_var_name=self.endpoint_environment_variable, passed_value=endpoint ) + + # For Azure endpoints with deployment in URL, extract it if model_name not provided + if not self._model_name and "azure" in endpoint_value.lower(): + extracted = self._extract_deployment_from_azure_url(endpoint_value) + if extracted: # Only use extracted deployment if we actually found one + self._model_name = extracted # Initialize parent with endpoint and model_name PromptChatTarget.__init__( @@ -100,18 +106,42 @@ def _set_auth_headers(self, use_entra_auth, passed_api_key) -> None: logger.info("Authenticating with AzureAuth") scope = get_default_scope(self._endpoint) self._azure_auth = AzureAuth(token_scope=scope) + # For SDK-based targets: auth is handled via azure_ad_token_provider parameter + # For non-SDK targets (DALL-E, TTS, etc): keep manual header for backward compatibility self._headers["Authorization"] = f"Bearer {self._azure_auth.get_token()}" self._api_key = None else: self._api_key = default_values.get_non_required_value( env_var_name=self.api_key_environment_variable, passed_value=passed_api_key ) - # This header is set as api-key in azure and bearer in openai - # But azure still functions if it's in both places and in fact, - # in Azure foundry it needs to be set as a bearer + # For SDK-based targets: api_key is passed to AsyncOpenAI/AsyncAzureOpenAI constructors + # For non-SDK targets (DALL-E, TTS, etc): keep manual headers for backward compatibility self._headers["Api-Key"] = self._api_key self._headers["Authorization"] = f"Bearer {self._api_key}" + def _extract_deployment_from_azure_url(self, url: str) -> str: + """ + Extract deployment/model name from Azure OpenAI URL. + + Azure URLs have formats like: + - https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions + - https://{resource}.openai.azure.com/openai/deployments/{deployment}/responses + + Args: + url: The Azure endpoint URL. + + Returns: + The deployment name, or empty string if not found. + """ + # Match /deployments/{deployment_name}/ + match = re.search(r'/deployments/([^/]+)/', url) + if match: + deployment = match.group(1) + logger.info(f"Extracted deployment name from URL: {deployment}") + return deployment + + return "" + def refresh_auth_headers(self) -> None: """ Refresh the authentication headers. This is particularly useful for Entra authentication @@ -124,39 +154,43 @@ def _initialize_openai_client(self) -> None: """ Initialize the OpenAI client based on whether it's Azure or standard OpenAI. - Azure has two formats: + Azure has multiple endpoint formats: 1. Old format: https://{resource}.openai.azure.com/openai/deployments/{deployment}/...?api-version=... - Uses AsyncAzureOpenAI client - 2. New format: https://{resource}.openai.azure.com/openai/v1?api-version=... - Uses standard AsyncOpenAI client (compatible with OpenAI SDK) + Uses AsyncAzureOpenAI client with api-version parameter (supports both API key and Entra auth) + 2. New format: https://{resource}.openai.azure.com/openai/v1 OR https://{resource}.models.ai.azure.com/... + Uses standard AsyncOpenAI client (no api-version needed) + - With API key: pass api_key parameter + - With Entra: pass token provider function as api_key parameter """ # Determine if this is Azure OpenAI based on the endpoint is_azure = "azure" in self._endpoint.lower() if self._endpoint else False - # Check if it's the new Azure format that uses standard OpenAI client - # New format: https://{resource}.openai.azure.com/openai/v1 + # Check if it's the new Azure format (OpenAI-compatible) + # New format includes: + # - https://{resource}.openai.azure.com/openai/v1 + # - https://{resource}.models.ai.azure.com/... (Azure Foundry endpoints) is_azure_new_format = False if is_azure: - import os from urllib.parse import urlparse parsed_url = urlparse(self._endpoint) - # New format has /openai/v1 path - is_azure_new_format = "/openai/v1" in parsed_url.path + # New format has /openai/v1 path OR uses models.ai.azure.com domain + is_azure_new_format = "/openai/v1" in parsed_url.path or ".models.ai.azure.com" in parsed_url.netloc # Merge custom headers with httpx_client_kwargs httpx_kwargs = self._httpx_client_kwargs.copy() if self._headers: httpx_kwargs.setdefault("default_headers", {}).update(self._headers) + # Only old Azure format uses AsyncAzureOpenAI + # Everything else (platform OpenAI, new Azure format, Azure Foundry) uses standard AsyncOpenAI if is_azure and not is_azure_new_format: # Old Azure format - uses AsyncAzureOpenAI client - # Azure endpoint format: https://{resource}.openai.azure.com/openai/deployments/{deployment}/... - # The endpoint may also include ?api-version=YYYY-MM-DD query parameter + # Endpoint format: https://{resource}.openai.azure.com/openai/deployments/{deployment}/... # Extract API version from query parameter if present import os - from urllib.parse import urlparse, parse_qs + from urllib.parse import parse_qs parsed_url = urlparse(self._endpoint) query_params = parse_qs(parsed_url.query) @@ -168,11 +202,7 @@ def _initialize_openai_client(self) -> None: api_version = os.environ.get("OPENAI_API_VERSION", "2024-02-15-preview") # Azure SDK expects ONLY the base endpoint (scheme://netloc) - # It will automatically append the correct path based on the API being called - # For example: - # - For chat completions: appends /openai/deployments/{deployment}/chat/completions - # - For responses: appends /openai/responses - # So we need to strip any path that's already in the endpoint + # It will automatically append paths like /openai/deployments/{deployment}/chat/completions azure_endpoint = f"{parsed_url.scheme}://{parsed_url.netloc}" # Get the token provider for Entra auth @@ -191,9 +221,10 @@ async def token_provider(): **httpx_kwargs, ) else: - # Standard OpenAI client (used for both platform OpenAI and new Azure format) + # Standard OpenAI client (used for platform OpenAI, new Azure format, and Azure Foundry) # The SDK expects base_url to be the base (e.g., https://api.openai.com/v1) # For new Azure format: https://{resource}.openai.azure.com/openai/v1 + # For Azure Foundry: https://{resource}.models.ai.azure.com/v1 # If the endpoint includes API-specific paths, we need to strip them because the SDK # will automatically append the correct path for each API call base_url = self._endpoint @@ -202,9 +233,24 @@ async def token_provider(): elif base_url.endswith("/responses"): base_url = base_url[:-len("/responses")] + # For Azure Foundry endpoints (*.models.ai.azure.com), ensure they end with /v1 + from urllib.parse import urlparse + parsed = urlparse(base_url) + if ".models.ai.azure.com" in parsed.netloc and not base_url.endswith("/v1"): + base_url = base_url.rstrip("/") + "/v1" + + # For new Azure format with Entra auth, pass token provider as api_key + api_key_value = self._api_key + if is_azure_new_format and self._use_entra_auth and self._azure_auth: + # Token provider callable that the SDK will call to get bearer tokens + # Use the Azure SDK's async get_bearer_token_provider for proper token management + # This returns an async callable that the OpenAI SDK can await natively + scope = get_default_scope(self._endpoint) + api_key_value = get_async_token_provider_from_default_azure_credential(scope) + self._async_client = AsyncOpenAI( base_url=base_url, - api_key=self._api_key, + api_key=api_key_value, **httpx_kwargs, ) diff --git a/tests/integration/targets/test_entra_auth_targets.py b/tests/integration/targets/test_entra_auth_targets.py index e9ce4fbe2..0f3f1af29 100644 --- a/tests/integration/targets/test_entra_auth_targets.py +++ b/tests/integration/targets/test_entra_auth_targets.py @@ -19,33 +19,36 @@ @pytest.mark.asyncio @pytest.mark.parametrize( - ("endpoint", "model_name"), + ("endpoint", "model_name", "supports_seed"), [ - ("AZURE_OPENAI_GPT4O_ENDPOINT", ""), - ("AZURE_OPENAI_GPT4O_ENDPOINT2", ""), - ("AZURE_OPENAI_GPT4O_AAD_ENDPOINT", ""), - ("AZURE_OPENAI_GPT4O_UNSAFE_ENDPOINT", ""), - ("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2", ""), - ("AZURE_OPENAI_INTEGRATION_TEST_ENDPOINT", ""), - ("AZURE_OPENAI_GPT4O_STRICT_FILTER_ENDPOINT", ""), - ("AZURE_OPENAI_GPT3_5_CHAT_ENDPOINT", ""), - ("AZURE_OPENAI_GPT4_CHAT_ENDPOINT", ""), - ("AZURE_OPENAI_GPTV_CHAT_ENDPOINT", ""), - ("AZURE_FOUNDRY_DEEPSEEK_ENDPOINT", ""), - ("AZURE_FOUNDRY_PHI4_ENDPOINT", ""), - ("AZURE_FOUNDRY_MINSTRAL3B_ENDPOINT", ""), - ("XPIAI_OPENAI_GPT4O_CHAT_ENDPOINT", "XPIA_OPENAI_MODEL"), + ("AZURE_OPENAI_GPT4O_ENDPOINT", "", True), + ("AZURE_OPENAI_GPT4O_NEW_FORMAT_ENDPOINT", "AZURE_OPENAI_GPT4O_MODEL", True), + ("AZURE_OPENAI_GPT4O_ENDPOINT2", "", True), + ("AZURE_OPENAI_GPT4O_AAD_ENDPOINT", "", True), + ("AZURE_OPENAI_GPT4O_UNSAFE_ENDPOINT", "", True), + ("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2", "", True), + ("AZURE_OPENAI_INTEGRATION_TEST_ENDPOINT", "", True), + ("AZURE_OPENAI_GPT4O_STRICT_FILTER_ENDPOINT", "", True), + ("AZURE_OPENAI_GPT3_5_CHAT_ENDPOINT", "", True), + ("AZURE_OPENAI_GPT4_CHAT_ENDPOINT", "", True), + ("AZURE_OPENAI_GPTV_CHAT_ENDPOINT", "", True), + ("AZURE_FOUNDRY_DEEPSEEK_ENDPOINT", "", True), + ("AZURE_FOUNDRY_PHI4_ENDPOINT", "", True), + ("AZURE_FOUNDRY_MINSTRAL3B_ENDPOINT", "", False), + ("XPIA_OPENAI_GPT4O_ENDPOINT", "XPIA_OPENAI_MODEL", True), ], ) -async def test_openai_chat_target_entra_auth(sqlite_instance, endpoint, model_name): +async def test_openai_chat_target_entra_auth(sqlite_instance, endpoint, model_name, supports_seed): args = { "endpoint": os.getenv(endpoint), "temperature": 0.0, - "seed": 42, "use_entra_auth": True, "model_name": os.getenv(model_name), } + if supports_seed: + args["seed"] = 42 + # These endpoints should have Entra authentication enabled in the current context # e.g. Cognitive Services OpenAI Contributor or Cognitive Services User/Contributor role (for non-OpenAI resources) target = OpenAIChatTarget(**args) @@ -105,6 +108,7 @@ async def test_openai_tts_target_entra_auth(sqlite_instance, endpoint): ("endpoint", "model_name"), [ ("OPENAI_RESPONSES_ENDPOINT", "OPENAI_RESPONSES_MODEL"), + ("AZURE_OPENAI_RESPONSES_NEW_FORMAT_ENDPOINT", "AZURE_OPENAI_RESPONSES_MODEL"), ], ) async def test_openai_responses_target_entra_auth(sqlite_instance, endpoint, model_name): diff --git a/tests/integration/targets/test_targets_and_secrets.py b/tests/integration/targets/test_targets_and_secrets.py index 6a7b40b27..5c042b21e 100644 --- a/tests/integration/targets/test_targets_and_secrets.py +++ b/tests/integration/targets/test_targets_and_secrets.py @@ -97,6 +97,7 @@ async def _assert_can_send_video_prompt(target): ("OPENAI_CHAT_ENDPOINT", "OPENAI_CHAT_KEY", "OPENAI_CHAT_MODEL", True), ("PLATFORM_OPENAI_CHAT_ENDPOINT", "PLATFORM_OPENAI_CHAT_KEY", "PLATFORM_OPENAI_CHAT_MODEL", True), ("AZURE_OPENAI_GPT4O_ENDPOINT", "AZURE_OPENAI_GPT4O_KEY", "", True), + ("AZURE_OPENAI_GPT4O_NEW_FORMAT_ENDPOINT", "AZURE_OPENAI_GPT4O_KEY", "AZURE_OPENAI_GPT4O_MODEL", True), ("AZURE_OPENAI_INTEGRATION_TEST_ENDPOINT", "AZURE_OPENAI_INTEGRATION_TEST_KEY", "", True), ("AZURE_OPENAI_GPT4O_UNSAFE_ENDPOINT", "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY", "", True), ("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2", "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2", "", True), diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py index 8546d7fe7..9d593f8b1 100644 --- a/tests/unit/target/test_openai_chat_target.py +++ b/tests/unit/target/test_openai_chat_target.py @@ -613,42 +613,6 @@ def test_construct_message_unknown_finish_reason(target: OpenAIChatTarget, dummy assert "Unknown finish_reason" in str(excinfo.value) -@pytest.mark.asyncio -@pytest.mark.asyncio -async def test_send_prompt_async_calls_refresh_auth_headers(target: OpenAIChatTarget): - mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [] - mock_memory.add_message_to_memory = AsyncMock() - - target._azure_auth = MagicMock() - target._memory = mock_memory - - with ( - patch.object(target, "refresh_auth_headers") as mock_refresh, - patch.object(target, "_validate_request"), - patch.object(target, "_construct_request_body", new_callable=AsyncMock) as mock_construct, - ): - - mock_construct.return_value = {"model": "gpt-4", "messages": [], "stream": False} - - # Mock the OpenAI SDK client - mock_completion = create_mock_completion(content="test response") - target._async_client.chat.completions.create = AsyncMock(return_value=mock_completion) - - message = Message( - message_pieces=[ - MessagePiece( - role="user", - original_value="test prompt", - converted_value="test prompt", - converted_value_data_type="text", - ) - ] - ) - await target.send_prompt_async(message=message) - mock_refresh.assert_called_once() - - @pytest.mark.asyncio async def test_send_prompt_async_content_filter_400(target: OpenAIChatTarget): mock_memory = MagicMock(spec=MemoryInterface) diff --git a/tests/unit/target/test_openai_response_target.py b/tests/unit/target/test_openai_response_target.py index 6890b7f50..8a11fe469 100644 --- a/tests/unit/target/test_openai_response_target.py +++ b/tests/unit/target/test_openai_response_target.py @@ -607,40 +607,6 @@ def test_construct_message_empty_response( assert "The chat returned an empty response." in str(excinfo.value) -@pytest.mark.asyncio -@pytest.mark.asyncio -async def test_send_prompt_async_calls_refresh_auth_headers(target: OpenAIResponseTarget, openai_response_json: dict): - mock_memory = MagicMock(spec=MemoryInterface) - mock_memory.get_conversation.return_value = [] - mock_memory.add_message_to_memory = AsyncMock() - - target._azure_auth = MagicMock() - target._memory = mock_memory - - mock_response = create_mock_response(openai_response_json) - target._async_client.responses.create = AsyncMock(return_value=mock_response) - - with ( - patch.object(target, "refresh_auth_headers") as mock_refresh, - patch.object(target, "_validate_request"), - patch.object(target, "_construct_request_body", new_callable=AsyncMock) as mock_construct, - ): - mock_construct.return_value = {} - - message = Message( - message_pieces=[ - MessagePiece( - role="user", - original_value="test prompt", - converted_value="test prompt", - converted_value_data_type="text", - ) - ] - ) - await target.send_prompt_async(message=message) - mock_refresh.assert_called_once() - - def test_construct_message_from_openai_json_invalid_json( target: OpenAIResponseTarget, dummy_text_message_piece: MessagePiece ): From cc3f8e8c3bcd7585890f5c4ead49cb2a4c9e9348 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Wed, 19 Nov 2025 05:58:26 -0800 Subject: [PATCH 3/3] remove unnecessary file --- frontend/.vite/deps_temp_f8b3b81a/package.json | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 frontend/.vite/deps_temp_f8b3b81a/package.json diff --git a/frontend/.vite/deps_temp_f8b3b81a/package.json b/frontend/.vite/deps_temp_f8b3b81a/package.json deleted file mode 100644 index 3dbc1ca59..000000000 --- a/frontend/.vite/deps_temp_f8b3b81a/package.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "type": "module" -}