10 changes: 2 additions & 8 deletions pyproject.toml
@@ -86,19 +86,13 @@ select = [
     "D", # pydocstyle
 ]
 isort = { combine-as-imports = true, known-first-party = ["guardrails"] }
-extend-ignore=[
-    "D100", # Missing docstring in public module
-    "D102", # Missing docstring in public method
-    "D103", # Missing docstring in public function
-    "D104", # Missing docstring in public package
-    "D107", # Missing docstring in `__init__`
-]
 
 [tool.ruff.lint.pydocstyle]
 convention = "google"
 
 [tool.ruff.lint.extend-per-file-ignores]
-"tests/**" = ["E501"]
+"tests/**" = ["E501", "D100", "D103", "D104"]
+"examples/**" = ["D100", "D103", "D104"]
 
 [tool.ruff.format]
 docstring-code-format = true
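Note on this config change: with the blanket extend-ignore block gone, D102 (public method) and D107 (`__init__`) docstrings are now enforced across library code, which is why the source files below gain docstrings; D100/D103/D104 remain ignored only under tests/** and examples/**. A minimal illustration of what the enforced google convention expects (the class and names here are hypothetical, not from the repo):

    class Example:
        """One-line summary of the class."""

        def __init__(self, client: object) -> None:
            """Initialize Example.

            Args:
                client: Backing client instance.
            """
            self._client = client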
54 changes: 34 additions & 20 deletions src/guardrails/resources/chat/chat.py
@@ -12,38 +12,56 @@ class Chat:
     """Chat completions with guardrails (sync)."""
 
     def __init__(self, client: GuardrailsBaseClient) -> None:
+        """Initialize Chat resource.
+
+        Args:
+            client: GuardrailsBaseClient instance with configured guardrails.
+        """
         self._client = client
 
     @property
     def completions(self):
+        """Access chat completions API with guardrails.
+
+        Returns:
+            ChatCompletions: Chat completions interface with guardrail support.
+        """
         return ChatCompletions(self._client)
 
 
 class AsyncChat:
     """Chat completions with guardrails (async)."""
 
     def __init__(self, client: GuardrailsBaseClient) -> None:
+        """Initialize AsyncChat resource.
+
+        Args:
+            client: GuardrailsBaseClient instance with configured guardrails.
+        """
         self._client = client
 
     @property
     def completions(self):
+        """Access async chat completions API with guardrails.
+
+        Returns:
+            AsyncChatCompletions: Async chat completions with guardrail support.
+        """
         return AsyncChatCompletions(self._client)
 
 
 class ChatCompletions:
     """Chat completions interface with guardrails (sync)."""
 
     def __init__(self, client: GuardrailsBaseClient) -> None:
+        """Initialize ChatCompletions interface.
+
+        Args:
+            client: GuardrailsBaseClient instance with configured guardrails.
+        """
         self._client = client
 
-    def create(
-        self,
-        messages: list[dict[str, str]],
-        model: str,
-        stream: bool = False,
-        suppress_tripwire: bool = False,
-        **kwargs
-    ):
+    def create(self, messages: list[dict[str, str]], model: str, stream: bool = False, suppress_tripwire: bool = False, **kwargs):
         """Create chat completion with guardrails (synchronous).
 
         Runs preflight first, then executes input guardrails concurrently with the LLM call.
@@ -59,9 +77,7 @@ def create(
         )
 
         # Apply pre-flight modifications (PII masking, etc.)
-        modified_messages = self._client._apply_preflight_modifications(
-            messages, preflight_results
-        )
+        modified_messages = self._client._apply_preflight_modifications(messages, preflight_results)
 
         # Run input guardrails and LLM call concurrently using a thread for the LLM
         with ThreadPoolExecutor(max_workers=1) as executor:
@@ -102,15 +118,15 @@ class AsyncChatCompletions:
     """Async chat completions interface with guardrails."""
 
     def __init__(self, client):
+        """Initialize AsyncChatCompletions interface.
+
+        Args:
+            client: GuardrailsBaseClient instance with configured guardrails.
+        """
         self._client = client
 
     async def create(
-        self,
-        messages: list[dict[str, str]],
-        model: str,
-        stream: bool = False,
-        suppress_tripwire: bool = False,
-        **kwargs
+        self, messages: list[dict[str, str]], model: str, stream: bool = False, suppress_tripwire: bool = False, **kwargs
     ) -> Any | AsyncIterator[Any]:
         """Create chat completion with guardrails."""
         latest_message, _ = self._client._extract_latest_user_message(messages)
@@ -124,9 +140,7 @@ async def create(
         )
 
         # Apply pre-flight modifications (PII masking, etc.)
-        modified_messages = self._client._apply_preflight_modifications(
-            messages, preflight_results
-        )
+        modified_messages = self._client._apply_preflight_modifications(messages, preflight_results)
 
         # Run input guardrails and LLM call concurrently for both streaming and non-streaming
         input_check = self._client._run_stage_guardrails(
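For context on the create path shown above: preflight guardrails run to completion first (so modifications such as PII masking reach the LLM), then input guardrails run concurrently with the LLM call, which is pushed onto a single worker thread. A minimal, self-contained sketch of that ordering; run_preflight, run_input_guardrails, and call_llm are stand-ins for the client's private helpers, not library API:

    from concurrent.futures import ThreadPoolExecutor

    def run_preflight(messages: list[dict[str, str]]) -> list[dict[str, str]]:
        # Stand-in for the pre-flight stage: may rewrite messages (e.g. PII masking).
        return messages

    def run_input_guardrails(text: str) -> None:
        # Stand-in for the input stage: raises if a tripwire fires.
        if "forbidden" in text:
            raise RuntimeError("input guardrail tripwire")

    def call_llm(messages: list[dict[str, str]]) -> str:
        # Stand-in for the underlying chat completions call.
        return "ok"

    def create(messages: list[dict[str, str]]) -> str:
        # 1. Preflight finishes first so its modifications are what the LLM sees.
        modified = run_preflight(messages)
        # 2. The LLM call runs on a worker thread while input guardrails run
        #    here; a tripwire raises before the LLM result is consumed.
        with ThreadPoolExecutor(max_workers=1) as executor:
            llm_future = executor.submit(call_llm, modified)
            run_input_guardrails(modified[-1]["content"])
            return llm_future.result()

    print(create([{"role": "user", "content": "hello"}]))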
77 changes: 31 additions & 46 deletions src/guardrails/resources/responses/responses.py
@@ -14,6 +14,11 @@ class Responses:
     """Responses API with guardrails (sync)."""
 
     def __init__(self, client: GuardrailsBaseClient) -> None:
+        """Initialize Responses resource.
+
+        Args:
+            client: GuardrailsBaseClient instance with configured guardrails.
+        """
         self._client = client
 
     def create(
@@ -23,7 +28,7 @@ def create(
         stream: bool = False,
         tools: list[dict] | None = None,
         suppress_tripwire: bool = False,
-        **kwargs
+        **kwargs,
     ):
         """Create response with guardrails (synchronous).
 
@@ -44,9 +49,7 @@ def create(
         )
 
         # Apply pre-flight modifications (PII masking, etc.)
-        modified_input = self._client._apply_preflight_modifications(
-            input, preflight_results
-        )
+        modified_input = self._client._apply_preflight_modifications(input, preflight_results)
 
         # Input guardrails and LLM call concurrently
         with ThreadPoolExecutor(max_workers=1) as executor:
@@ -83,14 +86,7 @@ def create(
             suppress_tripwire=suppress_tripwire,
         )
 
-    def parse(
-        self,
-        input: list[dict[str, str]],
-        model: str,
-        text_format: type[BaseModel],
-        suppress_tripwire: bool = False,
-        **kwargs
-    ):
+    def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseModel], suppress_tripwire: bool = False, **kwargs):
         """Parse response with structured output and guardrails (synchronous)."""
         latest_message, _ = self._client._extract_latest_user_message(input)
 
@@ -103,9 +99,7 @@ def parse(
         )
 
         # Apply pre-flight modifications (PII masking, etc.)
-        modified_input = self._client._apply_preflight_modifications(
-            input, preflight_results
-        )
+        modified_input = self._client._apply_preflight_modifications(input, preflight_results)
 
         # Input guardrails and LLM call concurrently
         with ThreadPoolExecutor(max_workers=1) as executor:
@@ -135,26 +129,30 @@ def parse(
     def retrieve(self, response_id: str, suppress_tripwire: bool = False, **kwargs):
         """Retrieve response with output guardrail validation (synchronous)."""
         # Get the response using the original OpenAI client
-        response = self._client._resource_client.responses.retrieve(
-            response_id, **kwargs
-        )
+        response = self._client._resource_client.responses.retrieve(response_id, **kwargs)
 
         # Run output guardrails on the retrieved content
         output_text = response.output_text if hasattr(response, "output_text") else ""
-        output_results = self._client._run_stage_guardrails(
-            "output", output_text, suppress_tripwire=suppress_tripwire
-        )
+        output_results = self._client._run_stage_guardrails("output", output_text, suppress_tripwire=suppress_tripwire)
 
         # Return wrapped response with guardrail results
         return self._client._create_guardrails_response(
-            response, [], [], output_results  # preflight # input
+            response,
+            [],
+            [],
+            output_results,  # preflight # input
         )
 
 
 class AsyncResponses:
     """Responses API with guardrails (async)."""
 
     def __init__(self, client):
+        """Initialize AsyncResponses resource.
+
+        Args:
+            client: GuardrailsBaseClient instance with configured guardrails.
+        """
         self._client = client
 
     async def create(
@@ -164,7 +162,7 @@ async def create(
         stream: bool = False,
         tools: list[dict] | None = None,
         suppress_tripwire: bool = False,
-        **kwargs
+        **kwargs,
    ) -> Any | AsyncIterator[Any]:
         """Create response with guardrails."""
         # Determine latest user message text when a list of messages is provided
@@ -182,9 +180,7 @@ async def create(
         )
 
         # Apply pre-flight modifications (PII masking, etc.)
-        modified_input = self._client._apply_preflight_modifications(
-            input, preflight_results
-        )
+        modified_input = self._client._apply_preflight_modifications(input, preflight_results)
 
         # Run input guardrails and LLM call in parallel
         input_check = self._client._run_stage_guardrails(
@@ -220,13 +216,7 @@ async def create(
         )
 
     async def parse(
-        self,
-        input: list[dict[str, str]],
-        model: str,
-        text_format: type[BaseModel],
-        stream: bool = False,
-        suppress_tripwire: bool = False,
-        **kwargs
+        self, input: list[dict[str, str]], model: str, text_format: type[BaseModel], stream: bool = False, suppress_tripwire: bool = False, **kwargs
     ) -> Any | AsyncIterator[Any]:
         """Parse response with structured output and guardrails."""
         latest_message, _ = self._client._extract_latest_user_message(input)
 
@@ -240,9 +230,7 @@ async def parse(
         )
 
         # Apply pre-flight modifications (PII masking, etc.)
-        modified_input = self._client._apply_preflight_modifications(
-            input, preflight_results
-        )
+        modified_input = self._client._apply_preflight_modifications(input, preflight_results)
 
         # Run input guardrails and LLM call in parallel
         input_check = self._client._run_stage_guardrails(
@@ -277,22 +265,19 @@ async def parse(
             suppress_tripwire=suppress_tripwire,
         )
 
-    async def retrieve(
-        self, response_id: str, suppress_tripwire: bool = False, **kwargs
-    ):
+    async def retrieve(self, response_id: str, suppress_tripwire: bool = False, **kwargs):
         """Retrieve response with output guardrail validation."""
         # Get the response using the original OpenAI client
-        response = await self._client._resource_client.responses.retrieve(
-            response_id, **kwargs
-        )
+        response = await self._client._resource_client.responses.retrieve(response_id, **kwargs)
 
         # Run output guardrails on the retrieved content
         output_text = response.output_text if hasattr(response, "output_text") else ""
-        output_results = await self._client._run_stage_guardrails(
-            "output", output_text, suppress_tripwire=suppress_tripwire
-        )
+        output_results = await self._client._run_stage_guardrails("output", output_text, suppress_tripwire=suppress_tripwire)
 
         # Return wrapped response with guardrail results
         return self._client._create_guardrails_response(
-            response, [], [], output_results  # preflight # input
+            response,
+            [],
+            [],
+            output_results,  # preflight # input
         )
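For context on retrieve: unlike create and parse, it runs only the output guardrail stage on the fetched response's output_text, returning empty preflight and input result lists since no new user input is involved. A rough sketch of that flow; GuardrailsResponse, fetch_response, and run_output_guardrails are illustrative stand-ins for the wrapped-response shape implied by the diff, not the library's actual types:

    from dataclasses import dataclass, field

    @dataclass
    class GuardrailsResponse:
        # The raw LLM response plus per-stage guardrail results.
        response: object
        preflight_results: list = field(default_factory=list)
        input_results: list = field(default_factory=list)
        output_results: list = field(default_factory=list)

    def fetch_response(response_id: str) -> object:
        # Stand-in for client.responses.retrieve(response_id).
        return type("Resp", (), {"output_text": "retrieved text"})()

    def run_output_guardrails(text: str, suppress_tripwire: bool = False) -> list:
        # Stand-in for the output stage; returns per-guardrail results.
        return [{"guardrail": "moderation", "passed": True, "checked": text}]

    def retrieve(response_id: str, suppress_tripwire: bool = False) -> GuardrailsResponse:
        response = fetch_response(response_id)
        # Mirrors the hasattr guard above: fall back to "" when absent.
        output_text = getattr(response, "output_text", "")
        results = run_output_guardrails(output_text, suppress_tripwire)
        # Preflight and input stages are skipped for retrieved responses.
        return GuardrailsResponse(response, [], [], results)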