Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions guardrails/classes/validation/validation_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ class ValidationSummary(IValidationSummary, ArbitraryModel):
def _generate_summaries_from_validator_logs(
validator_logs: List[ValidatorLogs],
) -> Iterator["ValidationSummary"]:
"""
Generate a list of ValidationSummary objects from a list of
ValidatorLogs objects. Using an iterator to allow serializing
the summaries to other formats.
"""Generate a list of ValidationSummary objects from a list of
ValidatorLogs objects.

Using an iterator to allow serializing the summaries to other
formats.
"""
for log in validator_logs:
validation_result = log.validation_result
Expand Down
132 changes: 100 additions & 32 deletions guardrails/run/async_stream_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@
cast,
)


from guardrails.validator_service import AsyncValidatorService
from guardrails.actions.reask import SkeletonReAsk
from guardrails.classes import ValidationOutcome
from guardrails.classes.history import Call, Inputs, Iteration, Outputs
from guardrails.classes.output_type import OutputTypes
from guardrails.constants import pass_status
from guardrails.llm_providers import (
AsyncLiteLLMCallable,
AsyncPromptCallableBase,
Expand All @@ -28,6 +27,11 @@
from guardrails.run.async_runner import AsyncRunner
from guardrails.telemetry import trace_async_stream_step
from guardrails.hub_telemetry.hub_tracing import async_trace_stream
from guardrails.types import OnFailAction
from guardrails.classes.validation.validation_result import (
PassResult,
FailResult,
)


class AsyncStreamRunner(AsyncRunner, StreamRunner):
Expand Down Expand Up @@ -133,49 +137,113 @@ async def async_step(
parsed_fragment, validated_fragment, valid_op = None, None, None
verified = set()
validation_response = ""
validation_progress = {}
refrain_triggered = False
validation_passed = True

if self.output_type == OutputTypes.STRING:
validator_service = AsyncValidatorService(self.disable_tracer)
async for chunk in stream_output:
chunk_text = self.get_chunk_text(chunk, api)
_ = self.is_last_chunk(chunk, api)
fragment += chunk_text

parsed_chunk, move_to_next = self.parse(
chunk_text, output_schema, verified=verified
)
if move_to_next:
continue
validated_fragment = await self.async_validate(
fragment += chunk_text
results = await validator_service.async_partial_validate(
chunk_text,
self.metadata,
self.validation_map,
iteration,
index,
parsed_chunk,
output_schema,
validate_subschema=True,
stream=True,
"$",
"$",
True,
)
# TODO why? how does it happen in the other places we handle streams
if validated_fragment is None:
validated_fragment = ""

if isinstance(validated_fragment, SkeletonReAsk):
raise ValueError(
"Received fragment schema is an invalid sub-schema "
"of the expected output JSON schema."
validators = self.validation_map["$"] or []
# collect the result validated_chunk into validation progress
# per validator
for result in results:
validator_log = result.validator_logs # type: ignore
validator = next(
filter(
lambda x: x.rail_alias == validator_log.registered_name,
validators,
),
None,
)
if (
validator_log.validation_result
and validator_log.validation_result.validated_chunk
):
is_filter = validator.on_fail_descriptor is OnFailAction.FILTER # type: ignore
is_refrain = (
validator.on_fail_descriptor is OnFailAction.REFRAIN # type: ignore
)
if validator_log.validation_result.outcome == "fail":
validation_passed = False
reasks, valid_op = self.introspect(
validator_log.validation_result
)
if reasks:
raise ValueError(
"Reasks are not yet supported with streaming. Please "
"remove reasks from schema or disable streaming."
)

reasks, valid_op = self.introspect(validated_fragment)
if reasks:
raise ValueError(
"Reasks are not yet supported with streaming. Please "
"remove reasks from schema or disable streaming."
if isinstance(validator_log.validation_result, PassResult):
chunk = validator_log.validation_result.validated_chunk
elif isinstance(validator_log.validation_result, FailResult):
if is_filter or is_refrain:
refrain_triggered = True
chunk = ""
else:
chunk = validator_service.perform_correction(
validator_log.validation_result,
validator_log.validation_result.validated_chunk,
validator, # type: ignore
rechecked_value=None,
) # type: ignore

if not hasattr(
validation_progress, validator_log.validator_name
):
validation_progress[validator_log.validator_name] = ""

validation_progress[validator_log.validator_name] += chunk
# if there is an entry for every validator
# run a merge and emit a validation outcome
if len(validation_progress) == len(validators):
if refrain_triggered:
current = ""
else:
merge_chunks = []
for piece in validation_progress:
merge_chunks.append(validation_progress[piece])

current = validator_service.multi_merge(fragment, merge_chunks)

vo = ValidationOutcome(
call_id=call_log.id, # type: ignore
raw_llm_output=fragment,
validated_output=current,
validation_passed=True,
)
validation_response += validated_fragment
passed = call_log.status == pass_status
fragment = ""
validation_progress = {}
refrain_triggered = False

yield vo

# if theres anything left merge and emit a chunk
if len(validation_progress) > 0:
merge_chunks = []
for piece in validation_progress:
merge_chunks.append(validation_progress[piece])

current = validator_service.multi_merge(fragment, merge_chunks)
yield ValidationOutcome(
call_id=call_log.id, # type: ignore
raw_llm_output=chunk_text,
validated_output=validated_fragment,
validation_passed=passed,
raw_llm_output=fragment,
validated_output=current,
validation_passed=validation_passed,
)
else:
async for chunk in stream_output:
Expand Down
32 changes: 32 additions & 0 deletions guardrails/validator_service/async_validator_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,38 @@ async def validate_child(

return value, metadata

async def async_partial_validate(
    self,
    value: Any,
    metadata: dict,
    validator_map: ValidatorMap,
    iteration: Iteration,
    absolute_path: str,
    reference_path: str,
    stream: Optional[bool] = False,
    **kwargs,
) -> list[ValidatorRun]:
    """Run every validator registered at ``reference_path`` against ``value``.

    All validators are launched concurrently and awaited together with
    ``asyncio.gather``, so the returned results are in the same order as
    the validators appear in ``validator_map``.
    """
    coroutines: List[Coroutine[Any, Any, ValidatorRun]] = [
        self.run_validator(
            iteration,
            validator,
            value,
            metadata,
            absolute_path,
            stream=stream,
            **kwargs,
        )
        for validator in validator_map.get(reference_path, [])
    ]

    return await asyncio.gather(*coroutines)

async def async_validate(
self,
value: Any,
Expand Down
9 changes: 0 additions & 9 deletions guardrails/validator_service/sequential_validator_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
StreamValidationResult,
ValidationResult,
)
from guardrails.merge import merge
from guardrails.types import ValidatorMap, OnFailAction
from guardrails.utils.exception_utils import UserFacingException
from guardrails.classes.validation.validator_logs import ValidatorLogs
Expand Down Expand Up @@ -108,14 +107,6 @@ def run_validators_stream(
**kwargs,
)

# requires at least 2 validators
def multi_merge(self, original: str, new_values: list[str]) -> Optional[str]:
current = new_values.pop()
while len(new_values) > 0:
nextval = new_values.pop()
current = merge(current, nextval, original)
return current

def run_validators_stream_fix(
self,
iteration: Iteration,
Expand Down
8 changes: 8 additions & 0 deletions guardrails/validator_service/validator_service_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,14 @@ def run_validator(
) -> ValidatorRun:
raise NotImplementedError

def multi_merge(self, original: str, new_values: list[str]) -> Optional[str]:
    """Fold several validated variants of ``original`` into one string.

    Each step performs a three-way merge (via :func:`merge`) between the
    accumulated result and the next variant, with ``original`` as the
    common ancestor.

    Args:
        original: The raw, pre-validation string the variants derive from.
        new_values: One validated variant per validator. Not mutated.

    Returns:
        The merged string; the lone variant when only one is supplied;
        ``None`` when ``new_values`` is empty.
    """
    # Copy before popping: the previous implementation consumed the
    # caller's list in place and raised IndexError on an empty list.
    remaining = list(new_values)
    if not remaining:
        return None
    current = remaining.pop()
    while remaining:
        nextval = remaining.pop()
        current = merge(current, nextval, original)
    return current

def merge_results(self, original_value: Any, new_values: list[Any]) -> Any:
new_vals = deepcopy(new_values)
current = new_values.pop()
Expand Down
93 changes: 70 additions & 23 deletions tests/integration_tests/test_async_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@
)
from tests.integration_tests.test_assets.validators import LowerCase, MockDetectPII

# Canned LLM stream output: a 4-line poem delivered as chunks that split
# mid-line and mid-sentence, fed to the mocked ``litellm.acompletion`` by
# the streaming tests below.
POETRY_CHUNKS = [
    "John, under ",
    "GOLDEN bridges",
    ", roams,\n",
    "SAN Francisco's ",
    "hills, his HOME.\n",
    "Dreams of",
    " FOG, and salty AIR,\n",
    "In his HEART",
    ", he's always THERE.",
]


@register_validator(name="minsentencelength", data_type=["string", "list"])
class MinSentenceLengthValidator(Validator):
Expand Down Expand Up @@ -131,21 +143,54 @@ async def gen():
self.completion_stream = gen()


POETRY_CHUNKS = [
"John, under ",
"GOLDEN bridges",
", roams,\n",
"SAN Francisco's ",
"hills, his HOME.\n",
"Dreams of",
" FOG, and salty AIR,\n",
"In his HEART",
", he's always THERE.",
]
@pytest.mark.asyncio
async def test_async_streaming_fix_behavior_two_validators(mocker):
    """Two FIX validators on one async stream produce a merged output.

    MockDetectPII substitutes the name/location placeholders and LowerCase
    lowercases the text; the validated stream must show both effects while
    the raw LLM output is passed through unchanged.
    """
    mocker.patch(
        "litellm.acompletion",
        return_value=Response(POETRY_CHUNKS),
    )

    pii_fixer = MockDetectPII(
        on_fail=OnFailAction.FIX,
        pii_entities="pii",
        replace_map={"John": "<PERSON>", "SAN Francisco's": "<LOCATION>"},
    )
    lowercaser = LowerCase(on_fail=OnFailAction.FIX)
    guard = gd.AsyncGuard().use_many(pii_fixer, lowercaser)

    prompt = """Write me a 4 line poem about John in San Francisco.
Make every third word all caps."""
    stream = await guard(
        model="gpt-3.5-turbo",
        max_tokens=10,
        temperature=0,
        stream=True,
        prompt=prompt,
    )

    raw_pieces = []
    validated_pieces = []
    async for outcome in stream:
        raw_pieces.append(outcome.raw_llm_output)
        validated_pieces.append(outcome.validated_output)

    text = "".join(validated_pieces)
    original = "".join(raw_pieces)

    assert (
        text
        == """<PERSON>, under golden bridges, roams,
<LOCATION> hills, his home.
dreams of fog, and salty air,
in his heart, he's always there."""
    )
    assert (
        original
        == """John, under GOLDEN bridges, roams,
SAN Francisco's hills, his HOME.
Dreams of FOG, and salty AIR,
In his HEART, he's always THERE."""
    )


@pytest.mark.asyncio
async def test_filter_behavior(mocker):
async def test_async_streaming_filter_behavior(mocker):
mocker.patch(
"litellm.acompletion",
return_value=Response(POETRY_CHUNKS),
Expand All @@ -169,16 +214,18 @@ async def test_filter_behavior(mocker):
prompt=prompt,
)

text = ""
final_res = None
validated = ""
raw_llm_output = ""

async for res in gen:
final_res = res
text += res.validated_output

assert final_res.raw_llm_output == ", he's always THERE."
# TODO deep dive this
assert text == (
"John, under GOLDEN bridges, roams,\n"
"SAN Francisco's Dreams of FOG, and salty AIR,\n"
"In his HEART"
validated += res.validated_output
raw_llm_output += res.raw_llm_output

assert validated == ""
assert (
raw_llm_output
== """John, under GOLDEN bridges, roams,
SAN Francisco's hills, his HOME.
Dreams of FOG, and salty AIR,
In his HEART, he's always THERE."""
)
Loading