diff --git a/guardrails/classes/validation/validation_summary.py b/guardrails/classes/validation/validation_summary.py index a1c9b13d0..c73f5f4d8 100644 --- a/guardrails/classes/validation/validation_summary.py +++ b/guardrails/classes/validation/validation_summary.py @@ -12,10 +12,11 @@ class ValidationSummary(IValidationSummary, ArbitraryModel): def _generate_summaries_from_validator_logs( validator_logs: List[ValidatorLogs], ) -> Iterator["ValidationSummary"]: - """ - Generate a list of ValidationSummary objects from a list of - ValidatorLogs objects. Using an iterator to allow serializing - the summaries to other formats. + """Generate a list of ValidationSummary objects from a list of + ValidatorLogs objects. + + Using an iterator to allow serializing the summaries to other + formats. """ for log in validator_logs: validation_result = log.validation_result diff --git a/guardrails/run/async_stream_runner.py b/guardrails/run/async_stream_runner.py index 29b79fbd0..aa1b50287 100644 --- a/guardrails/run/async_stream_runner.py +++ b/guardrails/run/async_stream_runner.py @@ -8,12 +8,11 @@ cast, ) - +from guardrails.validator_service import AsyncValidatorService from guardrails.actions.reask import SkeletonReAsk from guardrails.classes import ValidationOutcome from guardrails.classes.history import Call, Inputs, Iteration, Outputs from guardrails.classes.output_type import OutputTypes -from guardrails.constants import pass_status from guardrails.llm_providers import ( AsyncLiteLLMCallable, AsyncPromptCallableBase, @@ -28,6 +27,11 @@ from guardrails.run.async_runner import AsyncRunner from guardrails.telemetry import trace_async_stream_step from guardrails.hub_telemetry.hub_tracing import async_trace_stream +from guardrails.types import OnFailAction +from guardrails.classes.validation.validation_result import ( + PassResult, + FailResult, +) class AsyncStreamRunner(AsyncRunner, StreamRunner): @@ -133,49 +137,113 @@ async def async_step( parsed_fragment, validated_fragment, valid_op = None, None, None verified = set() validation_response = "" + validation_progress = {} + refrain_triggered = False + validation_passed = True if self.output_type == OutputTypes.STRING: + validator_service = AsyncValidatorService(self.disable_tracer) async for chunk in stream_output: chunk_text = self.get_chunk_text(chunk, api) _ = self.is_last_chunk(chunk, api) - fragment += chunk_text - parsed_chunk, move_to_next = self.parse( - chunk_text, output_schema, verified=verified - ) - if move_to_next: - continue - validated_fragment = await self.async_validate( + fragment += chunk_text + results = await validator_service.async_partial_validate( + chunk_text, + self.metadata, + self.validation_map, iteration, - index, - parsed_chunk, - output_schema, - validate_subschema=True, - stream=True, + "$", + "$", + True, ) - # TODO why? how does it happen in the other places we handle streams - if validated_fragment is None: - validated_fragment = "" - - if isinstance(validated_fragment, SkeletonReAsk): - raise ValueError( - "Received fragment schema is an invalid sub-schema " - "of the expected output JSON schema." + validators = self.validation_map["$"] or [] + # collect the result validated_chunk into validation progress + # per validator + for result in results: + validator_log = result.validator_logs # type: ignore + validator = next( + filter( + lambda x: x.rail_alias == validator_log.registered_name, + validators, + ), + None, ) + if ( + validator_log.validation_result + and validator_log.validation_result.validated_chunk + ): + is_filter = validator.on_fail_descriptor is OnFailAction.FILTER # type: ignore + is_refrain = ( + validator.on_fail_descriptor is OnFailAction.REFRAIN # type: ignore + ) + if validator_log.validation_result.outcome == "fail": + validation_passed = False + reasks, valid_op = self.introspect( + validator_log.validation_result + ) + if reasks: + raise ValueError( + "Reasks are not yet supported with streaming. Please " + "remove reasks from schema or disable streaming." + ) - reasks, valid_op = self.introspect(validated_fragment) - if reasks: - raise ValueError( - "Reasks are not yet supported with streaming. Please " - "remove reasks from schema or disable streaming." + if isinstance(validator_log.validation_result, PassResult): + chunk = validator_log.validation_result.validated_chunk + elif isinstance(validator_log.validation_result, FailResult): + if is_filter or is_refrain: + refrain_triggered = True + chunk = "" + else: + chunk = validator_service.perform_correction( + validator_log.validation_result, + validator_log.validation_result.validated_chunk, + validator, # type: ignore + rechecked_value=None, + ) # type: ignore + + if not hasattr( + validation_progress, validator_log.validator_name + ): + validation_progress[validator_log.validator_name] = "" + + validation_progress[validator_log.validator_name] += chunk + # if there is an entry for every validator + # run a merge and emit a validation outcome + if len(validation_progress) == len(validators): + if refrain_triggered: + current = "" + else: + merge_chunks = [] + for piece in validation_progress: + merge_chunks.append(validation_progress[piece]) + + current = validator_service.multi_merge(fragment, merge_chunks) + + vo = ValidationOutcome( + call_id=call_log.id, # type: ignore + raw_llm_output=fragment, + validated_output=current, + validation_passed=True, ) - validation_response += validated_fragment - passed = call_log.status == pass_status + fragment = "" + validation_progress = {} + refrain_triggered = False + + yield vo + + # if theres anything left merge and emit a chunk + if len(validation_progress) > 0: + merge_chunks = [] + for piece in validation_progress: + merge_chunks.append(validation_progress[piece]) + + current = validator_service.multi_merge(fragment, merge_chunks) yield ValidationOutcome( call_id=call_log.id, # type: ignore - raw_llm_output=chunk_text, - validated_output=validated_fragment, - validation_passed=passed, + raw_llm_output=fragment, + validated_output=current, + validation_passed=validation_passed, ) else: async for chunk in stream_output: diff --git a/guardrails/validator_service/async_validator_service.py b/guardrails/validator_service/async_validator_service.py index 03ef3d5ba..52a2a530c 100644 --- a/guardrails/validator_service/async_validator_service.py +++ b/guardrails/validator_service/async_validator_service.py @@ -253,6 +253,38 @@ async def validate_child( return value, metadata + async def async_partial_validate( + self, + value: Any, + metadata: dict, + validator_map: ValidatorMap, + iteration: Iteration, + absolute_path: str, + reference_path: str, + stream: Optional[bool] = False, + **kwargs, + ) -> list[ValidatorRun]: + # Then validate the parent value + validators = validator_map.get(reference_path, []) + coroutines: List[Coroutine[Any, Any, ValidatorRun]] = [] + + for validator in validators: + coroutines.append( + self.run_validator( + iteration, + validator, + value, + metadata, + absolute_path, + stream=stream, + **kwargs, + ) + ) + + results = await asyncio.gather(*coroutines) + + return results + async def async_validate( self, value: Any, diff --git a/guardrails/validator_service/sequential_validator_service.py b/guardrails/validator_service/sequential_validator_service.py index f39a984f3..0685ca9a5 100644 --- a/guardrails/validator_service/sequential_validator_service.py +++ b/guardrails/validator_service/sequential_validator_service.py @@ -10,7 +10,6 @@ StreamValidationResult, ValidationResult, ) -from guardrails.merge import merge from guardrails.types import ValidatorMap, OnFailAction from guardrails.utils.exception_utils import UserFacingException from guardrails.classes.validation.validator_logs import ValidatorLogs @@ -108,14 +107,6 @@ def run_validators_stream( **kwargs, ) - # requires at least 2 validators - def multi_merge(self, original: str, new_values: list[str]) -> Optional[str]: - current = new_values.pop() - while len(new_values) > 0: - nextval = new_values.pop() - current = merge(current, nextval, original) - return current - def run_validators_stream_fix( self, iteration: Iteration, diff --git a/guardrails/validator_service/validator_service_base.py b/guardrails/validator_service/validator_service_base.py index 0ea2e120b..3b626c7ce 100644 --- a/guardrails/validator_service/validator_service_base.py +++ b/guardrails/validator_service/validator_service_base.py @@ -167,6 +167,14 @@ def run_validator( ) -> ValidatorRun: raise NotImplementedError + # requires at least 2 validators + def multi_merge(self, original: str, new_values: list[str]) -> Optional[str]: + current = new_values.pop() + while len(new_values) > 0: + nextval = new_values.pop() + current = merge(current, nextval, original) + return current + def merge_results(self, original_value: Any, new_values: list[Any]) -> Any: new_vals = deepcopy(new_values) current = new_values.pop() diff --git a/tests/integration_tests/test_async_streaming.py b/tests/integration_tests/test_async_streaming.py index b0f2ed300..cb6fe18c0 100644 --- a/tests/integration_tests/test_async_streaming.py +++ b/tests/integration_tests/test_async_streaming.py @@ -21,6 +21,18 @@ ) from tests.integration_tests.test_assets.validators import LowerCase, MockDetectPII +POETRY_CHUNKS = [ + "John, under ", + "GOLDEN bridges", + ", roams,\n", + "SAN Francisco's ", + "hills, his HOME.\n", + "Dreams of", + " FOG, and salty AIR,\n", + "In his HEART", + ", he's always THERE.", +] + @register_validator(name="minsentencelength", data_type=["string", "list"]) class MinSentenceLengthValidator(Validator): @@ -131,21 +143,54 @@ async def gen(): self.completion_stream = gen() -POETRY_CHUNKS = [ - "John, under ", - "GOLDEN bridges", - ", roams,\n", - "SAN Francisco's ", - "hills, his HOME.\n", - "Dreams of", - " FOG, and salty AIR,\n", - "In his HEART", - ", he's always THERE.", -] +@pytest.mark.asyncio +async def test_async_streaming_fix_behavior_two_validators(mocker): + mocker.patch( + "litellm.acompletion", + return_value=Response(POETRY_CHUNKS), + ) + + guard = gd.AsyncGuard().use_many( + MockDetectPII( + on_fail=OnFailAction.FIX, + pii_entities="pii", + replace_map={"John": "", "SAN Francisco's": ""}, + ), + LowerCase(on_fail=OnFailAction.FIX), + ) + prompt = """Write me a 4 line poem about John in San Francisco. + Make every third word all caps.""" + gen = await guard( + model="gpt-3.5-turbo", + max_tokens=10, + temperature=0, + stream=True, + prompt=prompt, + ) + text = "" + original = "" + async for res in gen: + original = original + res.raw_llm_output + text = text + res.validated_output + + assert ( + text + == """, under golden bridges, roams, + hills, his home. +dreams of fog, and salty air, +in his heart, he's always there.""" + ) + assert ( + original + == """John, under GOLDEN bridges, roams, +SAN Francisco's hills, his HOME. +Dreams of FOG, and salty AIR, +In his HEART, he's always THERE.""" + ) @pytest.mark.asyncio -async def test_filter_behavior(mocker): +async def test_async_streaming_filter_behavior(mocker): mocker.patch( "litellm.acompletion", return_value=Response(POETRY_CHUNKS), @@ -169,16 +214,18 @@ async def test_filter_behavior(mocker): prompt=prompt, ) - text = "" - final_res = None + validated = "" + raw_llm_output = "" + async for res in gen: - final_res = res - text += res.validated_output - - assert final_res.raw_llm_output == ", he's always THERE." - # TODO deep dive this - assert text == ( - "John, under GOLDEN bridges, roams,\n" - "SAN Francisco's Dreams of FOG, and salty AIR,\n" - "In his HEART" + validated += res.validated_output + raw_llm_output += res.raw_llm_output + + assert validated == "" + assert ( + raw_llm_output + == """John, under GOLDEN bridges, roams, +SAN Francisco's hills, his HOME. +Dreams of FOG, and salty AIR, +In his HEART, he's always THERE.""" )