Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions guardrails/async_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,11 @@ async def _exec(
output=llm_output,
base_model=self._base_model,
full_schema_reask=full_schema_reask,
disable_tracer=(not self._allow_metrics_collection),
disable_tracer=(
not self._allow_metrics_collection
if isinstance(self._allow_metrics_collection, bool)
else None
),
exec_options=self._exec_opts,
)
# Here we have an async generator
Expand All @@ -391,7 +395,11 @@ async def _exec(
output=llm_output,
base_model=self._base_model,
full_schema_reask=full_schema_reask,
disable_tracer=(not self._allow_metrics_collection),
disable_tracer=(
not self._allow_metrics_collection
if isinstance(self._allow_metrics_collection, bool)
else None
),
exec_options=self._exec_opts,
)
# Why are we using a different method here instead of just overriding?
Expand Down
20 changes: 18 additions & 2 deletions guardrails/cli/hub/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@
from guardrails.cli.version import version_warnings_if_applicable


@hub_command.command()
# Quick note: This is the command for `guardrails hub install`. We change the name of
# the function def to prevent confusion, lest people import it directly and call it
# with a string for package_uris instead of a list, which behaves oddly. If you need to
# call install from a script, please consider importing install from guardrails,
# not guardrails.cli.hub.install.
@hub_command.command(name="install")
@trace(name="guardrails-cli/hub/install")
def install(
def install_cli(
package_uris: List[str] = typer.Argument(
...,
help="URIs to the packages to install. Example: hub://guardrails/regex_match hub://guardrails/toxic_language",
Expand All @@ -33,6 +38,17 @@ def install(
),
):
try:
if isinstance(package_uris, str):
logger.error(
f"`install` in {__file__} was called with a string instead of "
"a list! This can happen if it is invoked directly instead of "
"being run via the CLI. Did you mean to import `from guardrails import "
"install` instead? Recovering..."
)
package_uris = [
package_uris,
]

from guardrails.hub.install import install_multiple

def confirm():
Expand Down
12 changes: 10 additions & 2 deletions guardrails/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,11 @@ def _exec(
output=llm_output,
base_model=self._base_model,
full_schema_reask=full_schema_reask,
disable_tracer=(not self._allow_metrics_collection),
disable_tracer=(
not self._allow_metrics_collection
if isinstance(self._allow_metrics_collection, bool)
else None
),
exec_options=self._exec_opts,
)
return runner(call_log=call_log, prompt_params=prompt_params)
Expand All @@ -884,7 +888,11 @@ def _exec(
output=llm_output,
base_model=self._base_model,
full_schema_reask=full_schema_reask,
disable_tracer=(not self._allow_metrics_collection),
disable_tracer=(
not self._allow_metrics_collection
if isinstance(self._allow_metrics_collection, bool)
else None
),
exec_options=self._exec_opts,
)
call = runner(call_log=call_log, prompt_params=prompt_params)
Expand Down
2 changes: 1 addition & 1 deletion guardrails/hub/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def install(
Examples:
>>> RegexMatch = install("hub://guardrails/regex_match").RegexMatch

>>> install("hub://guardrails/regex_match);
>>> install("hub://guardrails/regex_match")
>>> import guardrails.hub.regex_match as regex_match
"""

Expand Down
7 changes: 4 additions & 3 deletions guardrails/hub_telemetry/hub_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Any,
Dict,
Optional,
AsyncGenerator,
)

from opentelemetry.trace import Span
Expand Down Expand Up @@ -224,7 +225,7 @@ def wrapper(*args, **kwargs):
return decorator


async def _run_async_gen(fn, *args, **kwargs):
async def _run_async_gen(fn, *args, **kwargs) -> AsyncGenerator[Any, None]:
    """Re-yield every item produced by the async iterable that ``fn`` returns.

    Args:
        fn: Callable that, when invoked with ``*args`` and ``**kwargs``,
            returns an object that can be consumed with ``async for``.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Yields:
        Each item of the underlying stream, in the order it was produced.
    """
    async for element in fn(*args, **kwargs):
        yield element
Expand All @@ -238,7 +239,7 @@ def async_trace_stream(
):
def decorator(fn):
@wraps(fn)
async def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs):
hub_telemetry = HubTelemetry()
if hub_telemetry._enabled and hub_telemetry._tracer is not None:
with hub_telemetry._tracer.start_span(
Expand All @@ -252,7 +253,7 @@ async def wrapper(*args, **kwargs):
nonlocal origin
origin = origin if origin is not None else name
add_attributes(span, attrs, name, origin, *args, **kwargs)
return _run_async_gen(fn, *args, **kwargs)
return fn(*args, **kwargs)
else:
return fn(*args, **kwargs)

Expand Down
6 changes: 5 additions & 1 deletion guardrails/run/async_stream_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ async def async_step(
validate_subschema=True,
stream=True,
)
# TODO why? how does it happen in the other places we handle streams
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nichwch any insights?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're asking about the None check on line 157, it might be a remnant from before we implemented generators all the way down. Before, it was possible for validation to return None if the validators hadn't accumulated enough chunks to validate yet. For sync streaming, we've changed this so that the validation logic takes place in a generator and only emits results when enough chunks have been accumulated.

Looking into the async streaming code, it looks like we never changed that to use generators for the validation logic, so validators still emit None values before they have accumulated enough chunks.

if validated_fragment is None:
validated_fragment = ""

if isinstance(validated_fragment, SkeletonReAsk):
raise ValueError(
"Received fragment schema is an invalid sub-schema "
Expand All @@ -165,7 +169,7 @@ async def async_step(
"Reasks are not yet supported with streaming. Please "
"remove reasks from schema or disable streaming."
)
validation_response += cast(str, validated_fragment)
validation_response += validated_fragment
passed = call_log.status == pass_status
yield ValidationOutcome(
call_id=call_log.id, # type: ignore
Expand Down
19 changes: 16 additions & 3 deletions guardrails/telemetry/guard_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)

from opentelemetry import context, trace
from opentelemetry.trace import StatusCode, Tracer, Span
from opentelemetry.trace import StatusCode, Tracer, Span, Link, get_tracer

from guardrails.settings import settings
from guardrails.classes.generic.stack import Stack
Expand All @@ -22,6 +22,10 @@
from guardrails.telemetry.common import add_user_attributes
from guardrails.version import GUARDRAILS_VERSION

import sys

if sys.version_info.minor < 10:
from guardrails.utils.polyfills import anext

# from sentence_transformers import SentenceTransformer
# import numpy as np
Expand Down Expand Up @@ -195,8 +199,17 @@ async def trace_async_stream_guard(
while next_exists:
try:
res = await anext(result) # type: ignore
add_guard_attributes(guard_span, history, res)
add_user_attributes(guard_span)
if not guard_span.is_recording():
# Assuming you have a tracer instance
tracer = get_tracer(__name__)
# Create a new span and link it to the previous span
with tracer.start_as_current_span(
"new_guard_span", links=[Link(guard_span.get_span_context())]
) as new_span:
guard_span = new_span

add_guard_attributes(guard_span, history, res)
add_user_attributes(guard_span)
yield res
except StopIteration:
next_exists = False
Expand Down
4 changes: 4 additions & 0 deletions guardrails/telemetry/runner_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from guardrails.utils.safe_get import safe_get
from guardrails.version import GUARDRAILS_VERSION

import sys

if sys.version_info.minor < 10:
from guardrails.utils.polyfills import anext

#########################################
### START Runner.step Instrumentation ###
Expand Down
1 change: 0 additions & 1 deletion guardrails/utils/hub_telemetry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def initialize_tracer(
"""Initializes a tracer for Guardrails Hub."""
if enabled is None:
enabled = settings.rc.enable_metrics or False

self._enabled = enabled
self._carrier = {}
self._service_name = service_name
Expand Down
179 changes: 179 additions & 0 deletions tests/integration_tests/test_async_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
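# 3 tests
# 1. Test streaming with OpenAICallable (mock openai.Completion.create)
# 2. Test streaming with OpenAIChatCallable (mock openai.ChatCompletion.create)
# 3. Test string schema streaming
# Using the LowerCase Validator, and a custom validator to show new streaming behavior
# NOTE(review): the list above looks carried over from the sync streaming tests; this
# file currently defines one test, test_filter_behavior, which mocks litellm.acompletion.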
from typing import Any, Callable, Dict, List, Optional, Union

import asyncio
import pytest

import guardrails as gd
from guardrails.utils.casting_utils import to_int
from guardrails.validator_base import (
ErrorSpan,
FailResult,
OnFailAction,
PassResult,
ValidationResult,
Validator,
register_validator,
)
from tests.integration_tests.test_assets.validators import LowerCase, MockDetectPII


@register_validator(name="minsentencelength", data_type=["string", "list"])
class MinSentenceLengthValidator(Validator):
    """Flag sentences whose length falls outside the ``[min, max]`` bounds.

    A "sentence" is any run of text terminated by a period (see
    ``sentence_split``). Each offending sentence is reported as an
    ``ErrorSpan`` whose offsets are relative to the concatenated sentences.

    Args:
        min: Lower bound on sentence length in characters. When unset, the
            lower bound is not checked.
        max: Upper bound on sentence length in characters. When unset, the
            upper bound is not checked.
        on_fail: Action or callable forwarded to the base ``Validator``.
    """

    def __init__(
        self,
        min: Optional[int] = None,
        max: Optional[int] = None,
        on_fail: Optional[Callable] = None,
    ):
        super().__init__(
            on_fail=on_fail,
            min=min,
            max=max,
        )
        self._min = to_int(min)
        self._max = to_int(max)

    def sentence_split(self, value):
        """Split ``value`` on periods and return the period-terminated pieces.

        The text after the last period (possibly empty) is dropped, and the
        period is re-appended to every remaining piece.
        """
        return list(map(lambda x: x + ".", value.split(".")[:-1]))

    def validate(self, value: Union[str, List], metadata: Dict) -> ValidationResult:
        """Check every sentence in ``value`` against the configured bounds.

        Args:
            value: Text to validate; it is passed to ``sentence_split``.
            metadata: Unused by this validator.

        Returns:
            ``FailResult`` carrying one ``ErrorSpan`` per violated bound when
            any sentence is out of range, otherwise ``PassResult``.
        """
        sentences = self.sentence_split(value)
        error_spans = []
        index = 0  # running offset of the current sentence
        for sentence in sentences:
            length = len(sentence)
            # Both bounds default to None; comparing an int with None raises
            # TypeError, so an unset bound is simply skipped.
            if self._min is not None and length < self._min:
                error_spans.append(
                    ErrorSpan(
                        start=index,
                        end=index + length,
                        reason=f"Sentence has length less than {self._min}. "
                        f"Please return a longer output, "
                        f"that is shorter than {self._max} characters.",
                    )
                )
            if self._max is not None and length > self._max:
                error_spans.append(
                    ErrorSpan(
                        start=index,
                        end=index + length,
                        reason=f"Sentence has length greater than {self._max}. "
                        f"Please return a shorter output, "
                        f"that is shorter than {self._max} characters.",
                    )
                )
            index += length
        if error_spans:
            # NOTE(review): this summary message names only the minimum-length
            # failure even when a span was flagged for exceeding the maximum;
            # it is kept unchanged in case callers match on the text.
            return FailResult(
                validated_chunk=value,
                error_spans=error_spans,
                error_message=f"Sentence has length less than {self._min}. "
                f"Please return a longer output, "
                f"that is shorter than {self._max} characters.",
            )
        return PassResult(validated_chunk=value)

    def validate_stream(self, chunk: Any, metadata: Dict, **kwargs) -> ValidationResult:
        """Delegate streamed-chunk validation to the base ``Validator``."""
        return super().validate_stream(chunk, metadata, **kwargs)


class Delta:
    """Mock ``delta`` payload carrying the text of a single streamed chunk."""

    content: str

    def __init__(self, content):
        # Stored as given; no validation or copying is performed.
        self.content = content


class Choice:
    """Mock of one entry in a streamed response's ``choices`` list.

    Args:
        text: Chunk text exposed as ``.text``.
        delta: Object exposed as ``.delta``.
        finish_reason: Value exposed as ``.finish_reason``.
        index: Position of the choice; defaults to ``0``.
    """

    text: str
    finish_reason: str
    index: int
    delta: Delta

    def __init__(self, text, delta, finish_reason, index=0):
        # Plain attribute storage; the arguments are kept exactly as passed.
        self.text = text
        self.delta = delta
        self.finish_reason = finish_reason
        self.index = index


class MockOpenAIV1ChunkResponse:
    """Mock of a single streamed response chunk: a list of choices plus a model name."""

    choices: list
    model: str

    def __init__(self, choices, model):
        # Both values are stored as given.
        self.model = model
        self.choices = choices


class Response:
    """Mock streaming completion built from a sequence of text chunks.

    ``completion_stream`` is an async generator that yields one
    ``MockOpenAIV1ChunkResponse`` per entry in ``chunks``.
    """

    def __init__(self, chunks):
        self.chunks = chunks

        async def _stream_chunks():
            # ``self.chunks`` is read when iteration starts, not at construction.
            for piece in self.chunks:
                choice = Choice(
                    text=piece,
                    delta=Delta(content=piece),
                    finish_reason=None,
                )
                yield MockOpenAIV1ChunkResponse(
                    choices=[choice],
                    model="OpenAI model name",
                )
                await asyncio.sleep(0)  # Yield control to the event loop

        self.completion_stream = _stream_chunks()


# Text fragments fed, one per streamed chunk, to the mocked `litellm.acompletion`
# call in `test_filter_behavior`. The text contains "John" and "SAN Francisco's"
# (the keys of that test's `replace_map`) and several upper-case words.
POETRY_CHUNKS = [
'"John, under ',
"GOLDEN bridges",
", roams,\n",
"SAN Francisco's ",
"hills, his HOME.\n",
"Dreams of",
" FOG, and salty AIR,\n",
"In his HEART",
", he's always THERE.",
]


@pytest.mark.asyncio
async def test_filter_behavior(mocker):
    """Stream the mocked poem through an ``AsyncGuard`` with FIX and FILTER validators.

    ``litellm.acompletion`` is patched to return ``Response(POETRY_CHUNKS)``.
    The test expects the concatenated validated output to be empty and the
    last outcome's raw LLM output to be the final poem chunk.
    """
    mocker.patch("litellm.acompletion", return_value=Response(POETRY_CHUNKS))

    base_guard = gd.AsyncGuard()
    validators = (
        MockDetectPII(
            on_fail=OnFailAction.FIX,
            pii_entities="pii",
            replace_map={"John": "<PERSON>", "SAN Francisco's": "<LOCATION>"},
        ),
        LowerCase(on_fail=OnFailAction.FILTER),
    )
    guard = base_guard.use_many(*validators)

    prompt = """Write me a 4 line poem about John in San Francisco.
Make every third word all caps."""
    stream = await guard(
        model="gpt-3.5-turbo",
        max_tokens=10,
        temperature=0,
        stream=True,
        prompt=prompt,
    )

    # Accumulate validated output while remembering the last outcome seen.
    collected = ""
    last_outcome = None
    async for outcome in stream:
        last_outcome = outcome
        collected += outcome.validated_output

    assert last_outcome.raw_llm_output == ", he's always THERE."
    assert collected == ""
Loading