3 changes: 3 additions & 0 deletions frontend/.vite/deps_temp_f8b3b81a/package.json
@@ -0,0 +1,3 @@
{
  "type": "module"
}
27 changes: 26 additions & 1 deletion pyrit/auth/azure_auth.py
@@ -15,6 +15,10 @@
    ManagedIdentityCredential,
    get_bearer_token_provider,
)
from azure.identity.aio import (
    DefaultAzureCredential as AsyncDefaultAzureCredential,
    get_bearer_token_provider as get_async_bearer_token_provider,
)

from pyrit.auth.auth_config import REFRESH_TOKEN_BEFORE_MSEC
from pyrit.auth.authenticator import Authenticator
@@ -143,7 +147,7 @@ def get_token_provider_from_default_azure_credential(scope: str) -> Callable[[],
    Connect to an AOAI endpoint via default Azure credential.

    Returns:
        Authentication token provider
        Authentication token provider (synchronous)
    """
    try:
        token_provider = get_bearer_token_provider(DefaultAzureCredential(), scope)
@@ -153,6 +157,27 @@ def get_token_provider_from_default_azure_credential(scope: str) -> Callable[[],
        raise


def get_async_token_provider_from_default_azure_credential(scope: str) -> Callable[[], Awaitable[str]]:
    """
    Connect to an AOAI endpoint via default Azure credential with async support.

    This returns an async callable that can be awaited, suitable for use with
    async clients such as OpenAI's AsyncOpenAI.

    Args:
        scope (str): The scope to request tokens for.

    Returns:
        Authentication token provider (async callable)
    """
    try:
        token_provider = get_async_bearer_token_provider(AsyncDefaultAzureCredential(), scope)
        return token_provider
    except Exception as e:
        logger.error(f"Failed to obtain async token provider for '{scope}': {e}")
        raise


def get_default_scope(endpoint: str) -> str:
"""
Get the default scope for the given endpoint.
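A minimal usage sketch for the new async provider (not part of this diff): wiring get_async_token_provider_from_default_azure_credential into an async Azure OpenAI client. The endpoint, API version, and deployment name below are hypothetical placeholders.

import asyncio

from openai import AsyncAzureOpenAI

from pyrit.auth.azure_auth import (
    get_async_token_provider_from_default_azure_credential,
    get_default_scope,
)


async def main() -> None:
    endpoint = "https://example.openai.azure.com"  # placeholder endpoint
    scope = get_default_scope(endpoint)
    token_provider = get_async_token_provider_from_default_azure_credential(scope)

    # AsyncAzureOpenAI accepts an async callable as azure_ad_token_provider and
    # awaits it before each request, so tokens are fetched and refreshed lazily.
    client = AsyncAzureOpenAI(
        azure_endpoint=endpoint,
        azure_ad_token_provider=token_provider,
        api_version="2024-06-01",  # placeholder API version
    )
    completion = await client.chat.completions.create(
        model="gpt-4o",  # placeholder deployment name
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(completion.choices[0].message.content)


asyncio.run(main())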
71 changes: 70 additions & 1 deletion pyrit/prompt_target/openai/openai_chat_target.py
@@ -270,13 +270,82 @@ async def _construct_request_body(self, conversation: MutableSequence[Message],
        # Filter out None values
        return {k: v for k, v in body_parameters.items() if v is not None}

    async def _make_chat_completion_request(self, body: dict):
        """
        Make the actual chat completion request using the OpenAI SDK.

        Args:
            body (dict): The request body parameters.

        Returns:
            The completion response from the OpenAI SDK.
        """
        # Use the OpenAI SDK client to make the request
        completion = await self._async_client.chat.completions.create(**body)
        return completion

    def _construct_message_from_completion_response(
        self,
        *,
        completion_response,
        message_piece: MessagePiece,
    ) -> Message:
        """
        Construct a Message from the OpenAI SDK completion response.

        Args:
            completion_response: The completion response from the OpenAI SDK (a ChatCompletion object).
            message_piece (MessagePiece): The original request message piece.

        Returns:
            Message: The constructed message.
        """
        # Extract the finish reason and content from the SDK response
        if not completion_response.choices:
            raise PyritException(message="No choices returned in the completion response.")

        choice = completion_response.choices[0]
        finish_reason = choice.finish_reason
        extracted_response: str = ""

        # finish_reason="stop" means the API returned a complete message;
        # "length" means the API returned an incomplete message due to the max_tokens limit.
        if finish_reason in ["stop", "length"]:
            extracted_response = choice.message.content or ""

            # Handle empty response
            if not extracted_response:
                logger.error("The chat returned an empty response.")
                raise EmptyResponseException(message="The chat returned an empty response.")
        elif finish_reason == "content_filter":
            # A content filter with status 200 indicates that the model output was filtered.
            # https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/content-filter
            # Note: the SDK should raise ContentFilterFinishReasonError for this case,
            # but we handle it here as a fallback.
            return handle_bad_request_exception(
                response_text=completion_response.model_dump_json(),
                request=message_piece,
                error_code=200,
                is_content_filter=True,
            )
        else:
            raise PyritException(
                message=f"Unknown finish_reason {finish_reason} from response: {completion_response.model_dump_json()}"
            )

        return construct_response_from_request(request=message_piece, response_text_pieces=[extracted_response])

    def _construct_message_from_openai_json(
        self,
        *,
        open_ai_str_response: str,
        message_piece: MessagePiece,
    ) -> Message:
        """
        Legacy method kept for backward compatibility: parses a JSON string
        response from the OpenAI API. The SDK-based implementation uses
        _construct_message_from_completion_response instead.
        """
        try:
            response = json.loads(open_ai_str_response)
        except json.JSONDecodeError as e:
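One way to exercise the new branching without a live endpoint (a sketch, not part of this diff): build a synthetic ChatCompletion via the SDK's pydantic models and feed it to _construct_message_from_completion_response. All field values below are made up.

from openai.types.chat import ChatCompletion

# Synthetic SDK response; every field value here is invented for the test.
fake_completion = ChatCompletion.model_validate(
    {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 0,
        "model": "gpt-4o",
        "choices": [
            {
                "index": 0,
                "finish_reason": "stop",
                "message": {"role": "assistant", "content": "Hello!"},
            }
        ],
    }
)

# finish_reason "stop" should take the happy path above and yield the content;
# switching it to "content_filter" should route through handle_bad_request_exception.
assert fake_completion.choices[0].finish_reason == "stop"
assert fake_completion.choices[0].message.content == "Hello!"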
116 changes: 82 additions & 34 deletions pyrit/prompt_target/openai/openai_chat_target_base.py
@@ -5,9 +5,13 @@
import logging
from typing import Any, MutableSequence, Optional

import httpx
from openai import (
    BadRequestError,
    RateLimitError,
    ContentFilterFinishReasonError,
    APIStatusError,
)

from pyrit.common import net_utility
from pyrit.exceptions import (
    PyritException,
    handle_bad_request_exception,
@@ -106,7 +110,6 @@ async def send_prompt_async(self, *, message: Message) -> Message:
            Message: The updated conversation entry with the response from the prompt target.
        """
        self._validate_request(message=message)
        self.refresh_auth_headers()

        message_piece: MessagePiece = message.message_pieces[0]

@@ -120,54 +123,99 @@
        body = await self._construct_request_body(conversation=conversation, is_json_response=is_json_response)

        try:
            str_response: httpx.Response = await net_utility.make_request_and_raise_if_error_async(
                endpoint_uri=self._endpoint,
                method="POST",
                headers=self._headers,
                request_body=body,
                **self._httpx_client_kwargs,
            # Use the OpenAI SDK for making the request
            response = await self._make_chat_completion_request(body)

            # Convert the SDK response to our Message format
            return self._construct_message_from_completion_response(
                completion_response=response, message_piece=message_piece
            )

        except ContentFilterFinishReasonError as e:
            # Content filter error - raised by the SDK when finish_reason is "content_filter"
            logger.error(f"Content filter error: {e}")
            return handle_bad_request_exception(
                response_text=str(e),
                request=message_piece,
                error_code=200,  # Content filter with 200 status
                is_content_filter=True,
            )
        except httpx.HTTPStatusError as StatusError:
            if StatusError.response.status_code == 400:
                # Handle Bad Request
                error_response_text = StatusError.response.text
                # Content filter errors are handled differently from other 400 errors.
                # A 400 Bad Request with a content_filter error code indicates that the input was filtered.
                # https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/content-filter
        except BadRequestError as e:
            # Handle Bad Request from the SDK
            error_response_text = e.body if hasattr(e, "body") else str(e)

            # Check whether this is a content filter issue
            is_content_filter = False
            if isinstance(error_response_text, dict):
                is_content_filter = error_response_text.get("error", {}).get("code") == "content_filter"
            elif isinstance(error_response_text, str):
                try:
                    json_error = json.loads(error_response_text)
                    is_content_filter = json_error.get("error", {}).get("code") == "content_filter"
                except json.JSONDecodeError:
                    # Not valid JSON, so leave the content filter flag False
                    is_content_filter = False

                return handle_bad_request_exception(
                    response_text=error_response_text,
                    request=message_piece,
                    error_code=StatusError.response.status_code,
                    is_content_filter=is_content_filter,
                )
            elif StatusError.response.status_code == 429:
                is_content_filter = "content_filter" in error_response_text

            return handle_bad_request_exception(
                response_text=str(error_response_text),
                request=message_piece,
                error_code=400,
                is_content_filter=is_content_filter,
            )
        except RateLimitError as e:
            # SDK's RateLimitError - convert to our exception
            logger.warning(f"Rate limit hit: {e}")
            raise RateLimitException()
        except APIStatusError as e:
            # Other API errors
            if e.status_code == 429:
                raise RateLimitException()
            else:
                raise

        logger.info(f'Received the following response from the prompt target "{str_response.text}"')
        response: Message = self._construct_message_from_openai_json(
            open_ai_str_response=str_response.text, message_piece=message_piece
        )

        return response
    async def _make_chat_completion_request(self, body: dict):
        """
        Make the actual chat completion request using the OpenAI SDK.
        This method should be overridden by subclasses to use the appropriate SDK method.

        Args:
            body (dict): The request body parameters.

        Returns:
            The completion response from the OpenAI SDK.
        """
        raise NotImplementedError

    async def _construct_request_body(self, conversation: MutableSequence[Message], is_json_response: bool) -> dict:
        """
        Construct the request body from a conversation.
        This method should be overridden by subclasses.

        Args:
            conversation (MutableSequence[Message]): The conversation history.
            is_json_response (bool): Whether to request a JSON response format.

        Returns:
            dict: The request body parameters.
        """
        raise NotImplementedError

    def _construct_message_from_openai_json(
    def _construct_message_from_completion_response(
        self,
        *,
        open_ai_str_response: str,
        completion_response,
        message_piece: MessagePiece,
    ) -> Message:
        """
        Construct a Message from the OpenAI SDK completion response.
        This method should be overridden by subclasses.

        Args:
            completion_response: The completion response from the OpenAI SDK.
            message_piece (MessagePiece): The original request message piece.

        Returns:
            Message: The constructed message.
        """
        raise NotImplementedError

    def is_json_response_supported(self) -> bool:
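The BadRequestError branch above re-derives the content-filter check inline; factored out, it looks roughly like this (a sketch under the assumption that BadRequestError.body can be a dict, a JSON string, or something else entirely):

import json
from typing import Any


def is_content_filter_error(body: Any) -> bool:
    """Return True when an error body reports code == "content_filter"."""
    if isinstance(body, dict):
        return body.get("error", {}).get("code") == "content_filter"
    if isinstance(body, str):
        try:
            parsed = json.loads(body)
        except json.JSONDecodeError:
            return False  # not JSON, so no recognizable content filter code
        return isinstance(parsed, dict) and parsed.get("error", {}).get("code") == "content_filter"
    return False


assert is_content_filter_error({"error": {"code": "content_filter"}})
assert is_content_filter_error('{"error": {"code": "content_filter"}}')
assert not is_content_filter_error("plain text error")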
41 changes: 40 additions & 1 deletion pyrit/prompt_target/openai/openai_response_target.py
@@ -118,7 +118,8 @@ def __init__(
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)
        self._max_output_tokens = max_output_tokens

        response_url_patterns = [r"/responses"]
        # Accept both the old Azure format (/responses) and the new format (/openai/v1)
        response_url_patterns = [r"/responses", r"/openai/v1"]
        self._warn_if_irregular_endpoint(response_url_patterns)

        # Reasoning parameters are not yet supported by PyRIT.
@@ -317,6 +318,44 @@ async def _construct_request_body(self, conversation: MutableSequence[Message],
        # Filter out None values
        return {k: v for k, v in body_parameters.items() if v is not None}

    async def _make_chat_completion_request(self, body: dict):
        """
        Make the actual Responses API request using the OpenAI SDK.

        Args:
            body (dict): The request body parameters.

        Returns:
            The response from the OpenAI SDK.
        """
        # The Responses API is accessed via client.responses.create().
        # It returns a different response format than chat completions.
        return await self._async_client.responses.create(**body)

    def _construct_message_from_completion_response(
        self,
        *,
        completion_response,
        message_piece: MessagePiece,
    ) -> Message:
        """
        Construct a Message from an OpenAI SDK Responses API response.

        Args:
            completion_response: The response from the OpenAI SDK.
            message_piece (MessagePiece): The original request message piece.

        Returns:
            Message: The constructed message.
        """
        # Convert the SDK response to a JSON string for processing.
        # The SDK response object serializes via model_dump_json().
        response_json = completion_response.model_dump_json()
        return self._construct_message_from_openai_json(
            open_ai_str_response=response_json,
            message_piece=message_piece,
        )

    def _construct_message_from_openai_json(
        self,
        *,
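For context, the Responses API call path the override above relies on, as a standalone sketch (the model name and input are placeholders; assumes OPENAI_API_KEY is set in the environment):

import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    response = await client.responses.create(
        model="gpt-4o",      # placeholder model name
        input="Say hello.",  # the Responses API takes `input`, not `messages`
    )
    # model_dump_json() serializes the pydantic response object, which is what
    # _construct_message_from_completion_response hands to the legacy JSON parser.
    print(response.model_dump_json()[:200])
    print(response.output_text)  # convenience accessor for the generated text


asyncio.run(main())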