Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,6 @@ select = [
"D", # pydocstyle
]
isort = { combine-as-imports = true, known-first-party = ["guardrails"] }
extend-ignore=[
"D100", # Missing docstring in public module
"D102", # Missing docstring in public method
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
"D107", # Missing docstring in `__init__`
]

[tool.ruff.lint.pydocstyle]
convention = "google"
Expand Down
54 changes: 34 additions & 20 deletions src/guardrails/resources/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,56 @@ class Chat:
"""Chat completions with guardrails (sync)."""

def __init__(self, client: GuardrailsBaseClient) -> None:
"""Initialize Chat resource.

Args:
client: GuardrailsBaseClient instance with configured guardrails.
"""
self._client = client

@property
def completions(self):
"""Access chat completions API with guardrails.

Returns:
ChatCompletions: Chat completions interface with guardrail support.
"""
return ChatCompletions(self._client)


class AsyncChat:
"""Chat completions with guardrails (async)."""

def __init__(self, client: GuardrailsBaseClient) -> None:
"""Initialize AsyncChat resource.

Args:
client: GuardrailsBaseClient instance with configured guardrails.
"""
self._client = client

@property
def completions(self):
"""Access async chat completions API with guardrails.

Returns:
AsyncChatCompletions: Async chat completions with guardrail support.
"""
return AsyncChatCompletions(self._client)


class ChatCompletions:
"""Chat completions interface with guardrails (sync)."""

def __init__(self, client: GuardrailsBaseClient) -> None:
"""Initialize ChatCompletions interface.

Args:
client: GuardrailsBaseClient instance with configured guardrails.
"""
self._client = client

def create(
self,
messages: list[dict[str, str]],
model: str,
stream: bool = False,
suppress_tripwire: bool = False,
**kwargs
):
def create(self, messages: list[dict[str, str]], model: str, stream: bool = False, suppress_tripwire: bool = False, **kwargs):
"""Create chat completion with guardrails (synchronous).

Runs preflight first, then executes input guardrails concurrently with the LLM call.
Expand All @@ -59,9 +77,7 @@ def create(
)

# Apply pre-flight modifications (PII masking, etc.)
modified_messages = self._client._apply_preflight_modifications(
messages, preflight_results
)
modified_messages = self._client._apply_preflight_modifications(messages, preflight_results)

# Run input guardrails and LLM call concurrently using a thread for the LLM
with ThreadPoolExecutor(max_workers=1) as executor:
Expand Down Expand Up @@ -102,15 +118,15 @@ class AsyncChatCompletions:
"""Async chat completions interface with guardrails."""

def __init__(self, client):
"""Initialize AsyncChatCompletions interface.

Args:
client: GuardrailsBaseClient instance with configured guardrails.
"""
self._client = client

async def create(
self,
messages: list[dict[str, str]],
model: str,
stream: bool = False,
suppress_tripwire: bool = False,
**kwargs
self, messages: list[dict[str, str]], model: str, stream: bool = False, suppress_tripwire: bool = False, **kwargs
) -> Any | AsyncIterator[Any]:
"""Create chat completion with guardrails."""
latest_message, _ = self._client._extract_latest_user_message(messages)
Expand All @@ -124,9 +140,7 @@ async def create(
)

# Apply pre-flight modifications (PII masking, etc.)
modified_messages = self._client._apply_preflight_modifications(
messages, preflight_results
)
modified_messages = self._client._apply_preflight_modifications(messages, preflight_results)

# Run input guardrails and LLM call concurrently for both streaming and non-streaming
input_check = self._client._run_stage_guardrails(
Expand Down
77 changes: 31 additions & 46 deletions src/guardrails/resources/responses/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ class Responses:
"""Responses API with guardrails (sync)."""

def __init__(self, client: GuardrailsBaseClient) -> None:
"""Initialize Responses resource.

Args:
client: GuardrailsBaseClient instance with configured guardrails.
"""
self._client = client

def create(
Expand All @@ -23,7 +28,7 @@ def create(
stream: bool = False,
tools: list[dict] | None = None,
suppress_tripwire: bool = False,
**kwargs
**kwargs,
):
"""Create response with guardrails (synchronous).

Expand All @@ -44,9 +49,7 @@ def create(
)

# Apply pre-flight modifications (PII masking, etc.)
modified_input = self._client._apply_preflight_modifications(
input, preflight_results
)
modified_input = self._client._apply_preflight_modifications(input, preflight_results)

# Input guardrails and LLM call concurrently
with ThreadPoolExecutor(max_workers=1) as executor:
Expand Down Expand Up @@ -83,14 +86,7 @@ def create(
suppress_tripwire=suppress_tripwire,
)

def parse(
self,
input: list[dict[str, str]],
model: str,
text_format: type[BaseModel],
suppress_tripwire: bool = False,
**kwargs
):
def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseModel], suppress_tripwire: bool = False, **kwargs):
"""Parse response with structured output and guardrails (synchronous)."""
latest_message, _ = self._client._extract_latest_user_message(input)

Expand All @@ -103,9 +99,7 @@ def parse(
)

# Apply pre-flight modifications (PII masking, etc.)
modified_input = self._client._apply_preflight_modifications(
input, preflight_results
)
modified_input = self._client._apply_preflight_modifications(input, preflight_results)

# Input guardrails and LLM call concurrently
with ThreadPoolExecutor(max_workers=1) as executor:
Expand Down Expand Up @@ -135,26 +129,30 @@ def parse(
def retrieve(self, response_id: str, suppress_tripwire: bool = False, **kwargs):
"""Retrieve response with output guardrail validation (synchronous)."""
# Get the response using the original OpenAI client
response = self._client._resource_client.responses.retrieve(
response_id, **kwargs
)
response = self._client._resource_client.responses.retrieve(response_id, **kwargs)

# Run output guardrails on the retrieved content
output_text = response.output_text if hasattr(response, "output_text") else ""
output_results = self._client._run_stage_guardrails(
"output", output_text, suppress_tripwire=suppress_tripwire
)
output_results = self._client._run_stage_guardrails("output", output_text, suppress_tripwire=suppress_tripwire)

# Return wrapped response with guardrail results
return self._client._create_guardrails_response(
response, [], [], output_results # preflight # input
response,
[],
[],
output_results, # preflight # input
)


class AsyncResponses:
"""Responses API with guardrails (async)."""

def __init__(self, client):
"""Initialize AsyncResponses resource.

Args:
client: GuardrailsBaseClient instance with configured guardrails.
"""
self._client = client

async def create(
Expand All @@ -164,7 +162,7 @@ async def create(
stream: bool = False,
tools: list[dict] | None = None,
suppress_tripwire: bool = False,
**kwargs
**kwargs,
) -> Any | AsyncIterator[Any]:
"""Create response with guardrails."""
# Determine latest user message text when a list of messages is provided
Expand All @@ -182,9 +180,7 @@ async def create(
)

# Apply pre-flight modifications (PII masking, etc.)
modified_input = self._client._apply_preflight_modifications(
input, preflight_results
)
modified_input = self._client._apply_preflight_modifications(input, preflight_results)

# Run input guardrails and LLM call in parallel
input_check = self._client._run_stage_guardrails(
Expand Down Expand Up @@ -220,13 +216,7 @@ async def create(
)

async def parse(
self,
input: list[dict[str, str]],
model: str,
text_format: type[BaseModel],
stream: bool = False,
suppress_tripwire: bool = False,
**kwargs
self, input: list[dict[str, str]], model: str, text_format: type[BaseModel], stream: bool = False, suppress_tripwire: bool = False, **kwargs
) -> Any | AsyncIterator[Any]:
"""Parse response with structured output and guardrails."""
latest_message, _ = self._client._extract_latest_user_message(input)
Expand All @@ -240,9 +230,7 @@ async def parse(
)

# Apply pre-flight modifications (PII masking, etc.)
modified_input = self._client._apply_preflight_modifications(
input, preflight_results
)
modified_input = self._client._apply_preflight_modifications(input, preflight_results)

# Run input guardrails and LLM call in parallel
input_check = self._client._run_stage_guardrails(
Expand Down Expand Up @@ -277,22 +265,19 @@ async def parse(
suppress_tripwire=suppress_tripwire,
)

async def retrieve(
self, response_id: str, suppress_tripwire: bool = False, **kwargs
):
async def retrieve(self, response_id: str, suppress_tripwire: bool = False, **kwargs):
"""Retrieve response with output guardrail validation."""
# Get the response using the original OpenAI client
response = await self._client._resource_client.responses.retrieve(
response_id, **kwargs
)
response = await self._client._resource_client.responses.retrieve(response_id, **kwargs)

# Run output guardrails on the retrieved content
output_text = response.output_text if hasattr(response, "output_text") else ""
output_results = await self._client._run_stage_guardrails(
"output", output_text, suppress_tripwire=suppress_tripwire
)
output_results = await self._client._run_stage_guardrails("output", output_text, suppress_tripwire=suppress_tripwire)

# Return wrapped response with guardrail results
return self._client._create_guardrails_response(
response, [], [], output_results # preflight # input
response,
[],
[],
output_results, # preflight # input
)
Loading