2 changes: 2 additions & 0 deletions guardrails/api_client.py
@@ -104,6 +104,8 @@ def stream_validate(
)
if line:
json_output = json.loads(line)
if json_output.get("error"):
raise Exception(json_output.get("error").get("message"))
yield IValidationOutcome.from_dict(json_output)

def get_history(self, guard_name: str, call_id: str):
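The hunk above makes stream_validate surface server-side failures instead of yielding them as validation outcomes. A minimal sketch of that behavior, assuming the server streams JSON lines shaped like {"error": {"message": ...}} (the shape implied by the hunk); the helper name is for illustration only:

    import json

    def parse_stream_line(line: str) -> dict:
        # Mirror of the added check: an "error" key in the streamed JSON
        # aborts the stream with the server's message instead of yielding it.
        json_output = json.loads(line)
        if json_output.get("error"):
            raise Exception(json_output.get("error").get("message"))
        return json_output

    # parse_stream_line('{"error": {"message": "Bad request"}}')
    # -> Exception: Bad request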
16 changes: 9 additions & 7 deletions guardrails/async_guard.py
@@ -574,13 +574,15 @@ async def _stream_server_call(
validated_output=validated_output,
validation_passed=(validation_output.validation_passed is True),
)
if validation_output:
guard_history = self._api_client.get_history(
self.name, validation_output.call_id
)
self.history.extend(
[Call.from_interface(call) for call in guard_history]
)
# TODO re-enable this once we have a way to get history
# from a multi-node server
# if validation_output:
# guard_history = self._api_client.get_history(
# self.name, validation_output.call_id
# )
# self.history.extend(
# [Call.from_interface(call) for call in guard_history]
# )
else:
raise ValueError("AsyncGuard does not have an api client!")

7 changes: 6 additions & 1 deletion guardrails/classes/llm/llm_response.py
@@ -47,7 +47,12 @@ def to_interface(self) -> ILLMResponse:
stream_output = [str(so) for so in copy_2]

async_stream_output = None
if self.async_stream_output:
# don't do this again if it's already aiter-able; we're updating
# ourselves in memory here, so repeating it can cause issues
if self.async_stream_output and not hasattr(
self.async_stream_output, "__aiter__"
):
# tee doesn't work with async iterators
# This may be destructive
async_stream_output = []
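The guard added above keys off __aiter__ to decide whether the async stream output still needs to be materialized. A small illustration (not from the diff) of what that attribute check distinguishes:

    # Async generators expose __aiter__; an already-materialized list copy does not.
    async def chunks():
        yield "hello"

    assert hasattr(chunks(), "__aiter__")
    assert not hasattr(["hello"], "__aiter__")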
2 changes: 1 addition & 1 deletion guardrails/cli/create.py
@@ -118,7 +118,7 @@ def generate_template_config(
guard_instantiations = []

for i, guard in enumerate(template["guards"]):
guard_instantiations.append(f"guard{i} = Guard.from_dict(guards[{i}])")
guard_instantiations.append(f"guard{i} = AsyncGuard.from_dict(guards[{i}])")
guard_instantiations = "\n".join(guard_instantiations)
# Interpolate variables
output_content = template_content.format(
2 changes: 1 addition & 1 deletion guardrails/cli/hub/template_config.py.template
@@ -1,6 +1,6 @@
import json
import os
from guardrails import Guard
from guardrails import AsyncGuard, Guard
from guardrails.hub import {VALIDATOR_IMPORTS}

try:
26 changes: 15 additions & 11 deletions guardrails/guard.py
@@ -1215,10 +1215,12 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[O
error="The response from the server was empty!",
)

guard_history = self._api_client.get_history(
self.name, validation_output.call_id
)
self.history.extend([Call.from_interface(call) for call in guard_history])
# TODO re-enable this when we have history support in
# multi-node server environments
# guard_history = self._api_client.get_history(
# self.name, validation_output.call_id
# )
# self.history.extend([Call.from_interface(call) for call in guard_history])

validation_summaries = []
if self.history.last and self.history.last.iterations.last:
@@ -1281,13 +1283,15 @@ def _stream_server_call(
validated_output=validated_output,
validation_passed=(validation_output.validation_passed is True),
)
if validation_output:
guard_history = self._api_client.get_history(
self.name, validation_output.call_id
)
self.history.extend(
[Call.from_interface(call) for call in guard_history]
)

# TODO re-enable this when the server supports multi-node history
# if validation_output:
# guard_history = self._api_client.get_history(
# self.name, validation_output.call_id
# )
# self.history.extend(
# [Call.from_interface(call) for call in guard_history]
# )
else:
raise ValueError("Guard does not have an api client!")

8 changes: 8 additions & 0 deletions guardrails/llm_providers.py
@@ -498,6 +498,10 @@ def _invoke_llm(
),
)

# these are gr only and should not be getting passed to llms
kwargs.pop("reask_prompt", None)
kwargs.pop("reask_instructions", None)

response = completion(
model=model,
*args,
@@ -1088,6 +1092,10 @@ async def invoke_llm(
),
)

# these are gr only and should not be getting passed to llms
kwargs.pop("reask_prompt", None)
kwargs.pop("reask_instructions", None)

response = await acompletion(
*args,
**kwargs,
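Both hunks strip the Guardrails-only reask kwargs before the underlying litellm completion/acompletion call sees them. A hedged sketch of the idea in isolation (the helper name is made up for illustration):

    GUARDRAILS_ONLY_KWARGS = ("reask_prompt", "reask_instructions")

    def strip_guardrails_kwargs(kwargs: dict) -> dict:
        # pop(..., None) is a no-op when the key is absent, so calls that never
        # set the reask kwargs are unaffected
        for key in GUARDRAILS_ONLY_KWARGS:
            kwargs.pop(key, None)
        return kwargs

    assert strip_guardrails_kwargs(
        {"temperature": 0.0, "reask_prompt": "fix it"}
    ) == {"temperature": 0.0}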
6 changes: 4 additions & 2 deletions guardrails/run/async_stream_runner.py
@@ -148,6 +148,7 @@ async def async_step(
_ = self.is_last_chunk(chunk, api)

fragment += chunk_text

results = await validator_service.async_partial_validate(
chunk_text,
self.metadata,
@@ -157,7 +158,8 @@
"$",
True,
)
validators = self.validation_map["$"] or []
validators = self.validation_map.get("$", [])

# collect the result validated_chunk into validation progress
# per validator
for result in results:
@@ -210,7 +212,7 @@ async def async_step(
validation_progress[validator_log.validator_name] += chunk
# if there is an entry for every validator
# run a merge and emit a validation outcome
if len(validation_progress) == len(validators):
if len(validation_progress) == len(validators) or len(validators) == 0:
if refrain_triggered:
current = ""
else:
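Two small robustness fixes in this file: validation_map.get("$", []) tolerates a missing root entry where validation_map["$"] would raise KeyError, and the emit condition also fires when no validators are registered at all. A quick illustration of the lookup difference, using a toy dict rather than the runner's real state:

    validation_map = {}  # no validators registered for the root path

    validators = validation_map.get("$", [])  # -> [] (the new behavior)
    try:
        validation_map["$"] or []             # the old lookup
    except KeyError:
        validators = []                       # would have crashed the streaming step

    assert validators == []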
20 changes: 15 additions & 5 deletions guardrails/telemetry/guard_tracing.py
@@ -145,9 +145,18 @@ def trace_stream_guard(
res = next(result) # type: ignore
# FIXME: This should only be called once;
# Accumulate the validated output and call at the end
add_guard_attributes(guard_span, history, res)
add_user_attributes(guard_span)
yield res
if not guard_span.is_recording():
# Assuming you have a tracer instance
tracer = get_tracer(__name__)
# Create a new span and link it to the previous span
with tracer.start_as_current_span(
"stream_guard_span", # type: ignore
links=[Link(guard_span.get_span_context())],
) as new_span:
guard_span = new_span
add_guard_attributes(guard_span, history, res)
add_user_attributes(guard_span)
yield res
except StopIteration:
next_exists = False

@@ -180,6 +189,7 @@ def trace_guard_execution(
result, ValidationOutcome
):
return trace_stream_guard(guard_span, result, history)

add_guard_attributes(guard_span, history, result)
add_user_attributes(guard_span)
return result
@@ -204,14 +214,14 @@ async def trace_async_stream_guard(
tracer = get_tracer(__name__)
# Create a new span and link it to the previous span
with tracer.start_as_current_span(
"new_guard_span", # type: ignore
"async_stream_span", # type: ignore
links=[Link(guard_span.get_span_context())],
) as new_span:
guard_span = new_span

add_guard_attributes(guard_span, history, res)
add_user_attributes(guard_span)
yield res
yield res
except StopIteration:
next_exists = False
except StopAsyncIteration:
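Both tracing generators now cope with the original guard span ending while chunks are still streaming: if the span is no longer recording, a fresh span is opened and linked back to it so the remaining chunks stay correlated in traces. A condensed sketch of that OpenTelemetry pattern, with the attribute helpers stubbed out (the function name here is illustrative, not part of the diff):

    from opentelemetry import trace
    from opentelemetry.trace import Link

    tracer = trace.get_tracer(__name__)

    def record_chunk(guard_span, record_attributes):
        # record_attributes stands in for add_guard_attributes/add_user_attributes
        if not guard_span.is_recording():
            # The original span already ended; link a new span to it instead of
            # writing attributes into a closed span.
            with tracer.start_as_current_span(
                "stream_guard_span",
                links=[Link(guard_span.get_span_context())],
            ) as new_span:
                record_attributes(new_span)
                return new_span
        record_attributes(guard_span)
        return guard_span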
5 changes: 5 additions & 0 deletions guardrails/telemetry/runner_tracing.py
@@ -265,6 +265,11 @@ def trace_call_wrapper(*args, **kwargs):
) as call_span:
try:
response = fn(*args, **kwargs)
if isinstance(response, LLMResponse) and (
response.async_stream_output or response.stream_output
):
# TODO: Iterate, add a call attr each time
return response
add_call_attributes(call_span, response, *args, **kwargs)
return response
except Exception as e:
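The early return above skips attribute collection for streaming LLMResponses, since inspecting the stream at this point would consume chunks the runner still needs. A toy example of why draining a generator up front is destructive:

    def fake_stream():
        yield "chunk-1"
        yield "chunk-2"

    stream = fake_stream()
    _ = list(stream)           # reading it to build span attributes drains it...
    assert list(stream) == []  # ...leaving nothing for the downstream consumer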
2 changes: 2 additions & 0 deletions guardrails/validator_service/validator_service_base.py
@@ -169,6 +169,8 @@ def run_validator(

# requires at least 2 validators
def multi_merge(self, original: str, new_values: list[str]) -> Optional[str]:
if len(new_values) == 0:
return original
current = new_values.pop()
while len(new_values) > 0:
nextval = new_values.pop()
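Without the new empty-list guard, multi_merge would call pop() on an empty list and raise IndexError whenever no validator produced a replacement value; returning the original text is the fallback. A tiny reproduction of the failure mode the guard prevents:

    new_values: list[str] = []
    try:
        new_values.pop()          # what the old code did first
    except IndexError:
        merged = "original text"  # the new guard short-circuits to this instead

    assert merged == "original text"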
4 changes: 2 additions & 2 deletions server_ci/config.py
@@ -1,6 +1,6 @@
import json
import os
from guardrails import Guard
from guardrails import AsyncGuard

try:
file_path = os.path.join(os.getcwd(), "guard-template.json")
@@ -11,4 +11,4 @@
SystemExit(1)

# instantiate guards
guard0 = Guard.from_dict(guards[0])
guard0 = AsyncGuard.from_dict(guards[0])
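The server CI config (and the CLI template it is generated from) now instantiates AsyncGuard instead of Guard. Pieced together from the hunks above, the generated config roughly does the following; anything beyond what the diff shows (the try/except body, the exact JSON shape) is assumed:

    import json
    import os
    from guardrails import AsyncGuard

    try:
        file_path = os.path.join(os.getcwd(), "guard-template.json")
        with open(file_path) as f:
            guards = json.load(f)["guards"]
    except FileNotFoundError:
        raise SystemExit(1)

    # instantiate an async-capable guard from the serialized template
    guard0 = AsyncGuard.from_dict(guards[0])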
55 changes: 54 additions & 1 deletion server_ci/tests/test_server.py
@@ -1,7 +1,7 @@
import openai
import os
import pytest
from guardrails import Guard, settings
from guardrails import AsyncGuard, Guard, settings

# OpenAI compatible Guardrails API Guard
openai.base_url = "http://127.0.0.1:8000/guards/test-guard/openai/v1/"
@@ -32,6 +32,59 @@ def test_guard_validation(mock_llm_output, validation_output, validation_passed,
assert validation_outcome.validated_output == validation_output


@pytest.mark.asyncio
async def test_async_guard_validation():
settings.use_server = True
guard = AsyncGuard(name="test-guard")

validation_outcome = await guard(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}],
temperature=0.0,
)

assert validation_outcome.validation_passed is True # noqa: E712
assert validation_outcome.validated_output == "Citrus fruit,"


@pytest.mark.asyncio
async def test_async_streaming_guard_validation():
settings.use_server = True
guard = AsyncGuard(name="test-guard")

async_iterator = await guard(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}],
stream=True,
temperature=0.0,
)

full_output = ""
async for validation_chunk in async_iterator:
full_output += validation_chunk.validated_output

assert full_output == "Citrus fruit,Citrus fruit,"


@pytest.mark.asyncio
async def test_sync_streaming_guard_validation():
settings.use_server = True
guard = Guard(name="test-guard")

iterator = guard(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}],
stream=True,
temperature=0.0,
)

full_output = ""
for validation_chunk in iterator:
full_output += validation_chunk.validated_output

assert full_output == "Citrus fruit,Citrus fruit,"


@pytest.mark.parametrize(
"message_content, output, validation_passed, error",
[