diff --git a/docs/concepts/parallelization.md b/docs/concepts/concurrency.md similarity index 54% rename from docs/concepts/parallelization.md rename to docs/concepts/concurrency.md index 1c4a069e5..bbe23fad8 100644 --- a/docs/concepts/parallelization.md +++ b/docs/concepts/concurrency.md @@ -1,11 +1,11 @@ -# Parallelization +# Concurrency ## And the Orchestration of Guard Executions This document is a description of the current implementation of the Guardrails' validation loop. It attempts to explain the current patterns used with some notes on why those patterns were accepted at the time of implementation and potential future optimizations. It is _not_ meant to be prescriptive as there can, and will, be improvements made in future versions. In general you will find that our approach to performance is two fold: 1. Complete computationally cheaper, static checks first and exit early to avoid spending time and resources on more expensive checks that are unlikely to pass when the former fail. -2. Parallelize processing where possible. +2. Run processes concurrently where possible. ## Background: The Validation Loop When a Guard is executed, that is called via `guard()`, `guard.parse()`, `guard.validate()`, etc., it goes through an internal process that has the following steps: @@ -60,7 +60,7 @@ Besides handling asynchronous calls to the LLM, using an `AsyncGuard` also ensur * An asyncio event loop is available. * The asyncio event loop is not taken/already running. -## Validation Orchestration and Parallelization +## Validation Orchestration and Concurrency ### Structured Data Validation We perform validation with a "deep-first" approach. This has no meaning for unstructured text output since there is only one value, but for structured output it means that the objects are validated from the inside out. @@ -79,7 +79,7 @@ Take the below structure as an example: } ``` -As of versions v0.4.x and v0.5.x of Guardrails, the above object would validated as follows: +As of versions v0.4.x and v0.5.x of Guardrails, the above object would be validated as follows: 1. foo.baz 2. foo.bez @@ -88,11 +88,134 @@ As of versions v0.4.x and v0.5.x of Guardrails, the above object would validated 5. bar.buz 6. bar - > NOTE: The approach currently used, and outlined above, was predicated on the assumption that if child properties fail validation, it is unlikely that the parent property would pass. With the current atomic state of validation, it can be argued that this assumption is false. That is, the types of validations applied to parent properties typically take the form of checking the appropriate format of the container like a length check on a list. These types of checks are generally independent of any requirements the child properties have. This opens up the possibility of running all six paths listed above in parallel at once instead of performing them in steps based on key path. + > NOTE: The approach currently used, and outlined above, was predicated on the assumption that if child properties fail validation, it is unlikely that the parent property would pass. With the current atomic state of validation, it can be argued that this assumption is false. That is, the types of validations applied to parent properties typically take the form of checking the appropriate format of the container like a length check on a list. These types of checks are generally independent of any requirements the child properties have. 
This opens up the possibility of running all six paths listed above concurrently instead of performing them in steps based on key path. When synchronous validation occurs as defined in [Benefits of AsyncGuard](#benefits-of-async-guard), the validators for each property would be run in the order they are defined on the schema. That also means that any on fail actions are applied in that same order. -When asynchronous validation occurs, there are multiple levels of parallelization possible. First, running validation on the child properties (e.g. `foo.baz` and `foo.bez`) will happen in parallel via the asyncio event loop. Second, within the validation for each property, if the validators have `run_in_separate_process` set to `True`, they are run in parallel via multiprocessing. This multiprocessing is capped to the process count specified by the `GUARDRAILS_PROCESS_COUNT` environment variable which defaults to 10. Note that some environments, like AWS Lambda, may not support multiprocessing in which case you would need to set this environment variable to 1. +When asynchronous validation occurs, there are multiple levels of concurrency possible. First, running validation on the child properties (e.g. `foo.baz` and `foo.bez`) will happen concurrently via the asyncio event loop. Second, the validators on any given property are also run concurrently via the event loop. For validators that only define a synchronous `validate` method, calls to this method are run in the event loop's default executor (see the sketch after the notes below for a validator that implements `async_validate` natively). Note that some environments, like AWS Lambda, may not support multiprocessing, in which case you would need to either use a thread-based executor or limit validation to running synchronously by setting `GUARDRAILS_PROCESS_COUNT=1` or `GUARDRAILS_RUN_SYNC=true`. ### Unstructured Data Validation -When validating unstructured data, i.e. text, the LLM output is treated the same as if it were a property on an object. This means that the validators applied to is have the ability to run in parallel utilizing multiprocessing when `run_in_separate_process` is set to `True` on the validators. \ No newline at end of file +When validating unstructured data, i.e. text, the LLM output is treated the same as if it were a property on an object. This means that the validators applied to it can run concurrently on the event loop. + +### Handling Failures During Async Concurrency +The Guardrails validation loop is opinionated about how it handles failures when running validators concurrently so that it spends the least amount of time processing an output that would result in a failure. Its behavior comes down to when and what it returns based on the [corrective action](/how_to_guides/custom_validators#on-fail) specified on a validator. Corrective actions are processed concurrently since they are specific to a given validator on a given property. This means that interruptive corrective actions, namely `EXCEPTION`, will be the first corrective action enforced because the exception is raised as soon as the failure is evaluated. The remaining actions are handled in the following order after all futures are collected from the validation of a specific property: +1. `FILTER` and `REFRAIN` +2. `REASK` +3. `FIX` + + \*_NOTE:_ `NOOP` does not require any special handling because it does not alter the value. + + \*_NOTE:_ `FIX_REASK` will fall into either the `REASK` or `FIX` bucket depending on whether the fixed value passes the second round of validation. 
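As mentioned above, validators that only define a synchronous `validate` method are run through the event loop's default executor. A validator can instead implement `async_validate` itself so its work is awaited directly on the loop. Below is a minimal sketch, not part of this change set: the `custom/mentions` validator name is made up for illustration, and the `asyncio.sleep` call stands in for real awaitable I/O.

```py
import asyncio

from guardrails.validators import (
    Validator,
    register_validator,
    ValidationResult,
    PassResult,
    FailResult
)

@register_validator(name='custom/mentions', data_type='string')
class Mentions(Validator):
    def __init__(self, term: str, **kwargs):
        super().__init__(term=term, **kwargs)
        self.term = term

    def validate(self, value, metadata = {}) -> ValidationResult:
        # Synchronous path: used by a plain Guard, or as the executor fallback.
        if self.term in value:
            return PassResult()
        return FailResult(error_message=f'Value must mention {self.term}')

    async def async_validate(self, value, metadata = {}) -> ValidationResult:
        # Asynchronous path: awaited directly on the event loop by an AsyncGuard,
        # so awaitable work here does not block other validators on the property.
        await asyncio.sleep(0.1)  # stand-in for real async I/O
        return self.validate(value, metadata)
```

With an `AsyncGuard`, the `async_validate` method above is what the validation loop awaits; a validator without it falls back to the executor behavior described earlier.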
+ +This means that if any validator with `on_fail=OnFailAction.EXCEPTION` returns a `FailResult`, then Guardrails will raise a `ValidationError`, interrupting the process. + +If any validator on a specific property which has `on_fail=OnFailAction.FILTER` or `on_fail=OnFailAction.REFRAIN` returns a `FailResult`, whichever of these finishes first will be returned early as the value for that property. + +If any validator on a specific property which has `on_fail=OnFailAction.REASK` returns a `FailResult`, all reasks for that property will be merged and a `FieldReAsk` will be returned early as the value for that property. + +If any validator on a specific property which has `on_fail=OnFailAction.FIX` returns a `FailResult`, all fix values for that property will be merged and the result of that merge will be returned as the value for that property. + +Custom on_fail handlers fall into one of the above actions based on what they return; e.g. if a handler returns an updated value it is treated as a `FIX`, if it returns an instance of `Filter` then as a `FILTER`, and so on. + +Let's look at an example. We'll keep the validation logic simple and write out some assertions to demonstrate the evaluation order discussed above. + +```py +import asyncio +from random import randint +from typing import Optional +from guardrails import AsyncGuard, ValidationOutcome +from guardrails.errors import ValidationError +from guardrails.validators import ( + Validator, + register_validator, + ValidationResult, + PassResult, + FailResult +) + +@register_validator(name='custom/contains', data_type='string') +class Contains(Validator): + def __init__(self, match_value: str, **kwargs): + super().__init__( + match_value=match_value, + **kwargs + ) + self.match_value = match_value + + def validate(self, value, metadata = {}) -> ValidationResult: + if self.match_value in value: + return PassResult() + + fix_value = None + if self.on_fail_descriptor == 'fix': + # Insert the match_value into the value at a random index + insertion = randint(0, len(value)) + fix_value = f"{value[:insertion]}{self.match_value}{value[insertion:]}" + + return FailResult( + error_message=f'Value must contain {self.match_value}', + fix_value=fix_value + ) + +exception_validator = Contains("a", on_fail='exception') +filter_validator = Contains("b", on_fail='filter') +refrain_validator = Contains("c", on_fail='refrain') +reask_validator_1 = Contains("d", on_fail='reask') +reask_validator_2 = Contains("e", on_fail='reask') +fix_validator_1 = Contains("f", on_fail='fix') +fix_validator_2 = Contains("g", on_fail='fix') + +guard = AsyncGuard().use_many( + exception_validator, + filter_validator, + refrain_validator, + reask_validator_1, + reask_validator_2, + fix_validator_1, + fix_validator_2 +) + +### Trigger the exception validator ### +error = None +result: Optional[ValidationOutcome] = None +try: + result = asyncio.run(guard.validate("z", metadata={})) + +except ValidationError as e: + error = e + +assert result is None +assert error is not None +assert str(error) == "Validation failed for field with errors: Value must contain a" + + + +### Trigger the Filter and Refrain validators ### +result = asyncio.run(guard.validate("a", metadata={})) + +assert result.validation_passed is False +# The output was filtered or refrained +assert result.validated_output is None +assert result.reask is None + + + +### Trigger the Reask validator ### +result = asyncio.run(guard.validate("abc", metadata={})) + +assert result.validation_passed is False +# If allowed, a 
ReAsk would have occurred +assert result.reask is not None +error_messages = [f.error_message for f in result.reask.fail_results] +assert error_messages == ["Value must contain d", "Value must contain e"] + + +### Trigger the Fix validator ### +result = asyncio.run(guard.validate("abcde", metadata={})) + +assert result.validation_passed is True +# The fix values have been merged +assert "f" in result.validated_output +assert "g" in result.validated_output +print(result.validated_output) +``` \ No newline at end of file diff --git a/docs/faq.md b/docs/faq.md index 5d92c9a44..72e5e8ff9 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -58,7 +58,7 @@ You can override the `fix` behavior by passing it as a function to the Guard obj ```python from guardrails import Guard -def fix_is_cake(value, metadata): +def fix_is_cake(value, fail_result: FailResult): return "IT IS cake" guard = Guard().use(is_cake, on_fail=fix_is_cake) diff --git a/docs/how_to_guides/custom_validators.md b/docs/how_to_guides/custom_validators.md index ee113ae64..5c5446605 100644 --- a/docs/how_to_guides/custom_validators.md +++ b/docs/how_to_guides/custom_validators.md @@ -75,6 +75,7 @@ Validators ship with several out of the box `on_fail` policies. The `OnFailActio | `OnFailAction.NOOP` | Do nothing. The failure will still be recorded in the logs, but no corrective action will be taken. | | `OnFailAction.EXCEPTION` | Raise an exception when validation fails. | | `OnFailAction.FIX_REASK` | First, fix the generated output deterministically, and then rerun validation with the deterministically fixed output. If validation fails, then perform reasking. | +| `OnFailAction.CUSTOM` | This action is set internally when the validator is passed a custom function to handle failures. The function is called with the value that failed validation and the `FailResult` returned from the validator; i.e. the custom on_fail handler must implement the signature `def on_fail(value: Any, fail_result: FailResult) -> Any`. | In the code below, a `fix_value` will be supplied in the `FailResult`. This value will represent a programmatic fix that can be applied to the output if `on_fail='fix'` is passed during validator initialization. 
```py diff --git a/docusaurus/sidebars.js b/docusaurus/sidebars.js index 1cce74653..ddb5a7034 100644 --- a/docusaurus/sidebars.js +++ b/docusaurus/sidebars.js @@ -66,7 +66,7 @@ const sidebars = { "concepts/streaming_fixes", ], }, - "concepts/parallelization", + "concepts/concurrency", "concepts/logs", "concepts/telemetry", "concepts/error_remediation", diff --git a/guardrails/__init__.py b/guardrails/__init__.py index 7c2f5f255..6dcb563b7 100644 --- a/guardrails/__init__.py +++ b/guardrails/__init__.py @@ -10,6 +10,7 @@ from guardrails.validator_base import Validator, register_validator from guardrails.settings import settings from guardrails.hub.install import install +from guardrails.classes.validation_outcome import ValidationOutcome __all__ = [ "Guard", @@ -25,4 +26,5 @@ "Instructions", "settings", "install", + "ValidationOutcome", ] diff --git a/guardrails/classes/generic/default_json_encoder.py b/guardrails/classes/generic/default_json_encoder.py new file mode 100644 index 000000000..1319c0cd5 --- /dev/null +++ b/guardrails/classes/generic/default_json_encoder.py @@ -0,0 +1,21 @@ +from datetime import datetime +from dataclasses import asdict, is_dataclass +from pydantic import BaseModel +from json import JSONEncoder + + +class DefaultJSONEncoder(JSONEncoder): + def default(self, o): + if hasattr(o, "to_dict"): + return o.to_dict() + elif isinstance(o, BaseModel): + return o.model_dump() + elif is_dataclass(o): + return asdict(o) + elif isinstance(o, set): + return list(o) + elif isinstance(o, datetime): + return o.isoformat() + elif hasattr(o, "__dict__"): + return o.__dict__ + return super().default(o) diff --git a/guardrails/merge.py b/guardrails/merge.py index e25acd3ce..07e3d9487 100644 --- a/guardrails/merge.py +++ b/guardrails/merge.py @@ -1,4 +1,5 @@ # SOURCE: https://github.com/spyder-ide/three-merge/blob/master/three_merge/merge.py +from typing import Optional from diff_match_patch import diff_match_patch # Constants @@ -10,7 +11,12 @@ ADDITION = 1 -def merge(source: str, target: str, base: str) -> str: +def merge( + source: Optional[str], target: Optional[str], base: Optional[str] +) -> Optional[str]: + if source is None or target is None or base is None: + return None + diff1_l = DIFFER.diff_main(base, source) diff2_l = DIFFER.diff_main(base, target) @@ -75,7 +81,7 @@ def merge(source: str, target: str, base: str) -> str: invariant = "" target = (target_status, target_text) # type: ignore if advance: - prev_source_text = source[1] + prev_source_text = source[1] # type: ignore source = next(diff1, None) # type: ignore elif len(source_text) < len(target_text): # Addition performed by source @@ -119,7 +125,7 @@ def merge(source: str, target: str, base: str) -> str: invariant = "" source = (source_status, source_text) # type: ignore if advance: - prev_target_text = target[1] + prev_target_text = target[1] # type: ignore target = next(diff2, None) # type: ignore else: # Source and target are equal diff --git a/guardrails/telemetry/open_inference.py b/guardrails/telemetry/open_inference.py index 2fc21649d..266c36c2c 100644 --- a/guardrails/telemetry/open_inference.py +++ b/guardrails/telemetry/open_inference.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List, Optional -from guardrails.telemetry.common import get_span, serialize, to_dict +from guardrails.telemetry.common import get_span, to_dict +from guardrails.utils.serialization_utils import serialize def trace_operation( diff --git a/guardrails/telemetry/runner_tracing.py b/guardrails/telemetry/runner_tracing.py index 
9bcba353f..1b3bb346f 100644 --- a/guardrails/telemetry/runner_tracing.py +++ b/guardrails/telemetry/runner_tracing.py @@ -17,8 +17,9 @@ from guardrails.classes.output_type import OT from guardrails.classes.validation_outcome import ValidationOutcome from guardrails.stores.context import get_guard_name -from guardrails.telemetry.common import get_tracer, serialize +from guardrails.telemetry.common import get_tracer from guardrails.utils.safe_get import safe_get +from guardrails.utils.serialization_utils import serialize from guardrails.version import GUARDRAILS_VERSION diff --git a/guardrails/telemetry/validator_tracing.py b/guardrails/telemetry/validator_tracing.py index 502e670f3..4f88cd23e 100644 --- a/guardrails/telemetry/validator_tracing.py +++ b/guardrails/telemetry/validator_tracing.py @@ -1,6 +1,7 @@ from functools import wraps from typing import ( Any, + Awaitable, Callable, Dict, Optional, @@ -12,10 +13,11 @@ from guardrails.settings import settings from guardrails.classes.validation.validation_result import ValidationResult -from guardrails.telemetry.common import get_tracer, serialize +from guardrails.telemetry.common import get_tracer from guardrails.telemetry.open_inference import trace_operation from guardrails.utils.casting_utils import to_string from guardrails.utils.safe_get import safe_get +from guardrails.utils.serialization_utils import serialize from guardrails.version import GUARDRAILS_VERSION @@ -138,3 +140,65 @@ def trace_validator_wrapper(*args, **kwargs): return trace_validator_wrapper return trace_validator_decorator + + +def trace_async_validator( + validator_name: str, + obj_id: int, + on_fail_descriptor: Optional[str] = None, + tracer: Optional[Tracer] = None, + *, + validation_session_id: str, + **init_kwargs, +): + def trace_validator_decorator( + fn: Callable[..., Awaitable[Optional[ValidationResult]]], + ): + @wraps(fn) + async def trace_validator_wrapper(*args, **kwargs): + if not settings.disable_tracing: + current_otel_context = context.get_current() + _tracer = get_tracer(tracer) or trace.get_tracer( + "guardrails-ai", GUARDRAILS_VERSION + ) + validator_span_name = f"{validator_name}.validate" + with _tracer.start_as_current_span( + name=validator_span_name, # type: ignore + context=current_otel_context, # type: ignore + ) as validator_span: + try: + resp = await fn(*args, **kwargs) + add_validator_attributes( + *args, + validator_span=validator_span, + validator_name=validator_name, + obj_id=obj_id, + on_fail_descriptor=on_fail_descriptor, + result=resp, + init_kwargs=init_kwargs, + validation_session_id=validation_session_id, + **kwargs, + ) + return resp + except Exception as e: + validator_span.set_status( + status=StatusCode.ERROR, description=str(e) + ) + add_validator_attributes( + *args, + validator_span=validator_span, + validator_name=validator_name, + obj_id=obj_id, + on_fail_descriptor=on_fail_descriptor, + result=None, + init_kwargs=init_kwargs, + validation_session_id=validation_session_id, + **kwargs, + ) + raise e + else: + return await fn(*args, **kwargs) + + return trace_validator_wrapper + + return trace_validator_decorator diff --git a/guardrails/utils/serialization_utils.py b/guardrails/utils/serialization_utils.py new file mode 100644 index 000000000..d124a9069 --- /dev/null +++ b/guardrails/utils/serialization_utils.py @@ -0,0 +1,41 @@ +from datetime import datetime +import json +from typing import Any, Optional +import warnings + +from guardrails.classes.generic.default_json_encoder import DefaultJSONEncoder + + +# TODO: What 
other common cases we should consider? +def serialize(val: Any) -> Optional[str]: + try: + return json.dumps(val, cls=DefaultJSONEncoder) + except Exception as e: + warnings.warn(str(e)) + return None + + +# We want to do the oppisite of what we did in the DefaultJSONEncoder +# TODO: What's a good way to expose a configurable API for this? +# Do we wrap JSONDecoder with an extra layer to supply the original object? +def deserialize(original: Optional[Any], serialized: Optional[str]) -> Any: + try: + if original is None or serialized is None: + return None + + loaded_val = json.loads(serialized) + if isinstance(original, datetime): + return datetime.fromisoformat(loaded_val) + elif isinstance(original, set): + return set(original) + elif hasattr(original, "__class__"): + # TODO: Handle nested classes + # NOTE: nested pydantic classes already work + if isinstance(loaded_val, dict): + return original.__class__(**loaded_val) + elif isinstance(loaded_val, list): + return original.__class__(loaded_val) + return loaded_val + except Exception as e: + warnings.warn(str(e)) + return None diff --git a/guardrails/validator_base.py b/guardrails/validator_base.py index ce7484a18..c773e7c30 100644 --- a/guardrails/validator_base.py +++ b/guardrails/validator_base.py @@ -3,6 +3,8 @@ # - [ ] Maintain validator_base.py for exports but deprecate them # - [ ] Remove validator_base.py in 0.6.x +import asyncio +from functools import partial import inspect import logging from collections import defaultdict @@ -10,6 +12,7 @@ from string import Template from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union from warnings import warn +import warnings import nltk import requests @@ -24,6 +27,7 @@ from guardrails.logger import logger from guardrails.remote_inference import remote_inference from guardrails.types.on_fail import OnFailAction +from guardrails.utils.safe_get import safe_get from guardrails.utils.hub_telemetry_utils import HubTelemetry # See: https://github.com/guardrails-ai/guardrails/issues/829 @@ -76,7 +80,7 @@ class Validator: def __init__( self, - on_fail: Optional[Union[Callable, OnFailAction]] = None, + on_fail: Optional[Union[Callable[[Any, FailResult], Any], OnFailAction]] = None, **kwargs, ): self.creds = Credentials.from_rc_file() @@ -124,7 +128,8 @@ def __init__( ) self.on_fail_method = None else: - self.on_fail_method = on_fail + self.on_fail_descriptor = OnFailAction.CUSTOM + self._set_on_fail_method(on_fail) # Store the kwargs for the validator. self._kwargs = kwargs @@ -133,6 +138,31 @@ def __init__( self.rail_alias in validators_registry ), f"Validator {self.__class__.__name__} is not registered. " + def _set_on_fail_method(self, on_fail: Callable[[Any, FailResult], Any]): + """Set the on_fail method for the validator.""" + on_fail_args = inspect.getfullargspec(on_fail) + second_arg = safe_get(on_fail_args.args, 1) + if second_arg is None: + raise ValueError( + "The on_fail method must take two arguments: " + "the value being validated and the FailResult." + ) + second_arg_type = on_fail_args.annotations.get(second_arg) + if second_arg_type == List[FailResult]: + warnings.warn( + "Specifying a List[FailResult] as the second argument" + " for a custom on_fail handler is deprecated. 
" + "Please use FailResult instead.", + DeprecationWarning, + ) + + def on_fail_wrapper(value: Any, fail_result: FailResult) -> Any: + return on_fail(value, [fail_result]) # type: ignore + + self.on_fail_method = on_fail_wrapper + else: + self.on_fail_method = on_fail + def _validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult: """User implementable function. @@ -174,6 +204,19 @@ def validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult: self._log_telemetry() return validation_result + async def async_validate( + self, value: Any, metadata: Dict[str, Any] + ) -> ValidationResult: + """Use this function if your validation logic requires asyncio. + + Guaranteed to work with AsyncGuard + + May not work with synchronous Guards if they are used within an + async context due to lack of available event loops. + """ + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.validate, value, metadata) + def _inference(self, model_input: Any) -> Any: """Calls either a local or remote inference engine for use in the validation call. @@ -255,6 +298,15 @@ def validate_stream( return validation_result + async def async_validate_stream( + self, chunk: Any, metadata: Dict[str, Any], **kwargs + ) -> Optional[ValidationResult]: + loop = asyncio.get_event_loop() + validate_stream_partial = partial( + self.validate_stream, chunk, metadata, **kwargs + ) + return await loop.run_in_executor(None, validate_stream_partial) + def _hub_inference_request( self, request_body: Union[dict, str], validation_endpoint: str ) -> Any: @@ -336,12 +388,12 @@ def get_args(self): def __call__(self, value): result = self.validate(value, {}) if isinstance(result, FailResult): - from guardrails.validator_service import ValidatorServiceBase + from guardrails.validator_service.validator_service_base import ( + ValidatorServiceBase, + ) validator_service = ValidatorServiceBase() - return validator_service.perform_correction( - [result], value, self, self.on_fail_descriptor - ) + return validator_service.perform_correction(result, value, self) return value def __eq__(self, other): diff --git a/guardrails/validator_service.py b/guardrails/validator_service.py deleted file mode 100644 index cc9f53cbb..000000000 --- a/guardrails/validator_service.py +++ /dev/null @@ -1,1070 +0,0 @@ -import asyncio -import itertools -import os -from concurrent.futures import ProcessPoolExecutor -from datetime import datetime -from typing import Any, Awaitable, Dict, Iterable, List, Optional, Tuple, Union, cast - -from guardrails.actions.filter import Filter, apply_filters -from guardrails.actions.refrain import Refrain, apply_refrain -from guardrails.classes.history import Iteration -from guardrails.classes.output_type import OutputTypes -from guardrails.classes.validation.validation_result import ( - FailResult, - PassResult, - StreamValidationResult, - ValidationResult, -) -from guardrails.errors import ValidationError -from guardrails.merge import merge -from guardrails.types import ValidatorMap, OnFailAction -from guardrails.utils.exception_utils import UserFacingException -from guardrails.utils.hub_telemetry_utils import HubTelemetry -from guardrails.classes.validation.validator_logs import ValidatorLogs -from guardrails.actions.reask import FieldReAsk, ReAsk -from guardrails.telemetry.legacy_validator_tracing import trace_validation_result -from guardrails.telemetry import trace_validator -from guardrails.validator_base import Validator - -ValidatorResult = Optional[Union[ValidationResult, 
Awaitable[ValidationResult]]] - - -def key_not_empty(key: str) -> bool: - return key is not None and len(str(key)) > 0 - - -class ValidatorServiceBase: - """Base class for validator services.""" - - def __init__(self, disable_tracer: Optional[bool] = True): - self._disable_tracer = disable_tracer - - # NOTE: This is avoiding an issue with multiprocessing. - # If we wrap the validate methods at the class level or anytime before - # loop.run_in_executor is called, multiprocessing fails with a Pickling error. - # This is a well known issue without any real solutions. - # Using `fork` instead of `spawn` may alleviate the symptom for POSIX systems, - # but is relatively unsupported on Windows. - def execute_validator( - self, - validator: Validator, - value: Any, - metadata: Optional[Dict], - stream: Optional[bool] = False, - *, - validation_session_id: str, - **kwargs, - ) -> ValidatorResult: - validate_func = validator.validate_stream if stream else validator.validate - traced_validator = trace_validator( - validator_name=validator.rail_alias, - obj_id=id(validator), - on_fail_descriptor=validator.on_fail_descriptor, - validation_session_id=validation_session_id, - **validator._kwargs, - )(validate_func) - if stream: - result = traced_validator(value, metadata, **kwargs) - else: - result = traced_validator(value, metadata) - return result - - def perform_correction( - self, - results: List[FailResult], - value: Any, - validator: Validator, - on_fail_descriptor: Union[OnFailAction, str], - rechecked_value: Optional[ValidationResult] = None, - ): - if on_fail_descriptor == OnFailAction.FIX: - # FIXME: Should we still return fix_value if it is None? - # I think we should warn and return the original value. - return results[0].fix_value - elif on_fail_descriptor == OnFailAction.FIX_REASK: - # FIXME: Same thing here - fixed_value = results[0].fix_value - - if isinstance(rechecked_value, FailResult): - return FieldReAsk( - incorrect_value=fixed_value, - fail_results=results, - ) - - return fixed_value - if on_fail_descriptor == "custom": - if validator.on_fail_method is None: - raise ValueError("on_fail is 'custom' but on_fail_method is None") - return validator.on_fail_method(value, results) - if on_fail_descriptor == OnFailAction.REASK: - return FieldReAsk( - incorrect_value=value, - fail_results=results, - ) - if on_fail_descriptor == OnFailAction.EXCEPTION: - raise ValidationError( - "Validation failed for field with errors: " - + ", ".join([result.error_message for result in results]) - ) - if on_fail_descriptor == OnFailAction.FILTER: - return Filter() - if on_fail_descriptor == OnFailAction.REFRAIN: - return Refrain() - if on_fail_descriptor == OnFailAction.NOOP: - return value - else: - raise ValueError( - f"Invalid on_fail_descriptor {on_fail_descriptor}, " - f"expected 'fix' or 'exception'." - ) - - def before_run_validator( - self, - iteration: Iteration, - validator: Validator, - value: Any, - absolute_property_path: str, - ) -> ValidatorLogs: - validator_class_name = validator.__class__.__name__ - validator_logs = ValidatorLogs( - validator_name=validator_class_name, - value_before_validation=value, - registered_name=validator.rail_alias, - property_path=absolute_property_path, - # If we ever re-use validator instances across multiple properties, - # this will have to change. 
- instance_id=id(validator), - ) - iteration.outputs.validator_logs.append(validator_logs) - - start_time = datetime.now() - validator_logs.start_time = start_time - - return validator_logs - - def after_run_validator( - self, - validator: Validator, - validator_logs: ValidatorLogs, - result: Optional[ValidationResult], - ): - end_time = datetime.now() - validator_logs.validation_result = result - validator_logs.end_time = end_time - - if not self._disable_tracer: - # Get HubTelemetry singleton and create a new span to - # log the validator usage - _hub_telemetry = HubTelemetry() - _hub_telemetry.create_new_span( - span_name="/validator_usage", - attributes=[ - ("validator_name", validator.rail_alias), - ("validator_on_fail", validator.on_fail_descriptor), - ( - "validator_result", - result.outcome - if isinstance(result, ValidationResult) - else None, - ), - ], - is_parent=False, # This span will have no children - has_parent=True, # This span has a parent - ) - - return validator_logs - - def run_validator( - self, - iteration: Iteration, - validator: Validator, - value: Any, - metadata: Dict, - absolute_property_path: str, - stream: Optional[bool] = False, - **kwargs, - ) -> ValidatorLogs: - raise NotImplementedError - - -class SequentialValidatorService(ValidatorServiceBase): - def run_validator_sync( - self, - validator: Validator, - value: Any, - metadata: Dict, - validator_logs: ValidatorLogs, - stream: Optional[bool] = False, - *, - validation_session_id: str, - **kwargs, - ) -> Optional[ValidationResult]: - result = self.execute_validator( - validator, - value, - metadata, - stream, - validation_session_id=validation_session_id, - **kwargs, - ) - if asyncio.iscoroutine(result): - raise UserFacingException( - ValueError( - "Cannot use async validators with a synchronous Guard! " - f"Either use AsyncGuard or remove {validator_logs.validator_name}." 
- ) - ) - if result is None: - return result - return cast(ValidationResult, result) - - def run_validator( - self, - iteration: Iteration, - validator: Validator, - value: Any, - metadata: Dict, - property_path: str, - stream: Optional[bool] = False, - **kwargs, - ) -> ValidatorLogs: - validator_logs = self.before_run_validator( - iteration, validator, value, property_path - ) - - result = self.run_validator_sync( - validator, - value, - metadata, - validator_logs, - stream, - validation_session_id=iteration.id, - **kwargs, - ) - - return self.after_run_validator(validator, validator_logs, result) - - def run_validators_stream( - self, - iteration: Iteration, - validator_map: ValidatorMap, - value_stream: Iterable[Tuple[Any, bool]], - metadata: Dict[str, Any], - absolute_property_path: str, - reference_property_path: str, - **kwargs, - ) -> Iterable[StreamValidationResult]: - validators = validator_map.get(reference_property_path, []) - for validator in validators: - if validator.on_fail_descriptor == OnFailAction.FIX: - return self.run_validators_stream_fix( - iteration, - validator_map, - value_stream, - metadata, - absolute_property_path, - reference_property_path, - **kwargs, - ) - return self.run_validators_stream_noop( - iteration, - validator_map, - value_stream, - metadata, - absolute_property_path, - reference_property_path, - **kwargs, - ) - - # requires at least 2 validators - def multi_merge(self, original: str, new_values: list[str]) -> str: - current = new_values.pop() - print("Fmerging these:", new_values) - while len(new_values) > 0: - nextval = new_values.pop() - current = merge(current, nextval, original) - print("\nFmerge result:", current) - return current - - def run_validators_stream_fix( - self, - iteration: Iteration, - validator_map: ValidatorMap, - value_stream: Iterable[Tuple[Any, bool]], - metadata: Dict[str, Any], - absolute_property_path: str, - reference_property_path: str, - **kwargs, - ) -> Iterable[StreamValidationResult]: - validators = validator_map.get(reference_property_path, []) - acc_output = "" - validator_partial_acc: dict[int, str] = {} - for validator in validators: - validator_partial_acc[id(validator)] = "" - last_chunk = None - last_chunk_validated = False - last_chunk_missing_validators = [] - refrain_triggered = False - for chunk, finished in value_stream: - original_text = chunk - acc_output += chunk - fixed_values = [] - last_chunk = chunk - last_chunk_missing_validators = [] - if refrain_triggered: - break - for validator in validators: - # reset chunk to original text - chunk = original_text - validator_logs = self.run_validator( - iteration, - validator, - chunk, - metadata, - absolute_property_path, - True, - remainder=finished, - **kwargs, - ) - result = validator_logs.validation_result - if result is None: - last_chunk_missing_validators.append(validator) - result = cast(ValidationResult, result) - # if we have a concrete result, log it in the validation map - if isinstance(result, FailResult): - is_filter = validator.on_fail_descriptor is OnFailAction.FILTER - is_refrain = validator.on_fail_descriptor is OnFailAction.REFRAIN - if is_filter or is_refrain: - refrain_triggered = True - break - rechecked_value = None - chunk = self.perform_correction( - [result], - chunk, - validator, - validator.on_fail_descriptor, - rechecked_value=rechecked_value, - ) - fixed_values.append(chunk) - validator_partial_acc[id(validator)] += chunk # type: ignore - elif isinstance(result, PassResult): - if ( - validator.override_value_on_pass - and 
result.value_override is not result.ValueOverrideSentinel - ): - chunk = result.value_override - else: - chunk = result.validated_chunk - fixed_values.append(chunk) - validator_partial_acc[id(validator)] += chunk # type: ignore - validator_logs.value_after_validation = chunk - if result and result.metadata is not None: - metadata = result.metadata - - if refrain_triggered: - # if we have a failresult from a refrain/filter validator, yield empty - yield StreamValidationResult( - chunk="", original_text=acc_output, metadata=metadata - ) - else: - # if every validator has yielded a concrete value, merge and yield - # only merge and yield if all validators have run - # TODO: check if only 1 validator - then skip merging - if len(fixed_values) == len(validators): - last_chunk_validated = True - values_to_merge = [] - for validator in validators: - values_to_merge.append(validator_partial_acc[id(validator)]) - merged_value = self.multi_merge(acc_output, values_to_merge) - # merged_value = self.multi_merge(acc_output, values_to_merge) - # reset validator_partial_acc - for validator in validators: - validator_partial_acc[id(validator)] = "" - yield StreamValidationResult( - chunk=merged_value, original_text=acc_output, metadata=metadata - ) - acc_output = "" - else: - last_chunk_validated = False - # handle case where LLM doesn't yield finished flag - # we need to validate remainder of accumulated chunks - if not last_chunk_validated and not refrain_triggered: - original_text = last_chunk - for validator in last_chunk_missing_validators: - last_log = self.run_validator( - iteration, - validator, - # use empty chunk - # validator has already accumulated the chunk from the first loop - "", - metadata, - absolute_property_path, - True, - remainder=True, - **kwargs, - ) - result = last_log.validation_result - if isinstance(result, FailResult): - rechecked_value = None - last_chunk = self.perform_correction( - [result], - last_chunk, - validator, - validator.on_fail_descriptor, - rechecked_value=rechecked_value, - ) - validator_partial_acc[id(validator)] += last_chunk # type: ignore - elif isinstance(result, PassResult): - if ( - validator.override_value_on_pass - and result.value_override is not result.ValueOverrideSentinel - ): - last_chunk = result.value_override - else: - last_chunk = result.validated_chunk - validator_partial_acc[id(validator)] += last_chunk # type: ignore - last_log.value_after_validation = last_chunk - if result and result.metadata is not None: - metadata = result.metadata - values_to_merge = [] - for validator in validators: - values_to_merge.append(validator_partial_acc[id(validator)]) - merged_value = self.multi_merge(acc_output, values_to_merge) - yield StreamValidationResult( - chunk=merged_value, - original_text=original_text, # type: ignore - metadata=metadata, # type: ignore - ) - # yield merged value - - def run_validators_stream_noop( - self, - iteration: Iteration, - validator_map: ValidatorMap, - value_stream: Iterable[Tuple[Any, bool]], - metadata: Dict[str, Any], - absolute_property_path: str, - reference_property_path: str, - **kwargs, - ) -> Iterable[StreamValidationResult]: - validators = validator_map.get(reference_property_path, []) - # Validate the field - # TODO: Under what conditions do we yield? - # When we have at least one non-None value? - # When we have all non-None values? - # Does this depend on whether we are fix or not? 
- for chunk, finished in value_stream: - original_text = chunk - for validator in validators: - validator_logs = self.run_validator( - iteration, - validator, - chunk, - metadata, - absolute_property_path, - True, - **kwargs, - ) - result = validator_logs.validation_result - result = cast(ValidationResult, result) - - if isinstance(result, FailResult): - rechecked_value = None - chunk = self.perform_correction( - [result], - chunk, - validator, - validator.on_fail_descriptor, - rechecked_value=rechecked_value, - ) - elif isinstance(result, PassResult): - if ( - validator.override_value_on_pass - and result.value_override is not result.ValueOverrideSentinel - ): - chunk = result.value_override - - validator_logs.value_after_validation = chunk - if result and result.metadata is not None: - metadata = result.metadata - # # TODO: Filter is no longer terminal, so we shouldn't yield, right? - # if isinstance(chunk, (Refrain, Filter, ReAsk)): - # yield chunk, metadata - yield StreamValidationResult( - chunk=chunk, original_text=original_text, metadata=metadata - ) - - def run_validators( - self, - iteration: Iteration, - validator_map: ValidatorMap, - value: Any, - metadata: Dict[str, Any], - absolute_property_path: str, - reference_property_path: str, - stream: Optional[bool] = False, - **kwargs, - ) -> Tuple[Any, Dict[str, Any]]: - # Validate the field - validators = validator_map.get(reference_property_path, []) - for validator in validators: - if stream: - if validator.on_fail_descriptor is OnFailAction.REASK: - raise ValueError( - """Reask is not supported for stream validation, - only noop and exception are supported.""" - ) - if validator.on_fail_descriptor is OnFailAction.FIX: - raise ValueError( - """Fix is not supported for stream validation, - only noop and exception are supported.""" - ) - if validator.on_fail_descriptor is OnFailAction.FIX_REASK: - raise ValueError( - """Fix reask is not supported for stream validation, - only noop and exception are supported.""" - ) - if validator.on_fail_descriptor is OnFailAction.FILTER: - raise ValueError( - """Filter is not supported for stream validation, - only noop and exception are supported.""" - ) - if validator.on_fail_descriptor is OnFailAction.REFRAIN: - raise ValueError( - """Refrain is not supported for stream validation, - only noop and exception are supported.""" - ) - validator_logs = self.run_validator( - iteration, - validator, - value, - metadata, - absolute_property_path, - stream, - **kwargs, - ) - result = validator_logs.validation_result - - result = cast(ValidationResult, result) - if isinstance(result, FailResult): - rechecked_value = None - if validator.on_fail_descriptor == OnFailAction.FIX_REASK: - fixed_value = result.fix_value - rechecked_value = self.run_validator_sync( - validator, - fixed_value, - metadata, - validator_logs, - stream, - **kwargs, - ) - value = self.perform_correction( - [result], - value, - validator, - validator.on_fail_descriptor, - rechecked_value=rechecked_value, - ) - elif isinstance(result, PassResult): - if ( - validator.override_value_on_pass - and result.value_override is not result.ValueOverrideSentinel - ): - value = result.value_override - elif not stream: - raise RuntimeError(f"Unexpected result type {type(result)}") - - validator_logs.value_after_validation = value - if result and result.metadata is not None: - metadata = result.metadata - - if isinstance(value, (Refrain, Filter, ReAsk)): - return value, metadata - return value, metadata - - def validate( - self, - value: Any, - 
metadata: dict, - validator_map: ValidatorMap, - iteration: Iteration, - absolute_path: str, - reference_path: str, - stream: Optional[bool] = False, - **kwargs, - ) -> Tuple[Any, dict]: - ### - # NOTE: The way validation can be executed now is fundamentally wide open. - # Since validators are tracked against the JSONPaths for the - # properties they should be applied to, we have the following options: - # 1. Keep performing a Deep-First-Search - # - This is useful for backwards compatibility. - # - Is there something we gain by validating inside out? - # 2. Swith to a Breadth-First-Search - # - Possible, no obvious advantages - # 3. Run un-ordered - # - This would allow for true parallelism - # - Also means we're not unnecessarily iterating down through - # the object if there aren't any validations applied there. - ### - - child_ref_path = reference_path.replace(".*", "") - # Validate children first - if isinstance(value, List): - for index, child in enumerate(value): - abs_child_path = f"{absolute_path}.{index}" - ref_child_path = f"{child_ref_path}.*" - child_value, metadata = self.validate( - child, - metadata, - validator_map, - iteration, - abs_child_path, - ref_child_path, - ) - value[index] = child_value - elif isinstance(value, Dict): - for key in value: - child = value.get(key) - abs_child_path = f"{absolute_path}.{key}" - ref_child_path = f"{child_ref_path}.{key}" - child_value, metadata = self.validate( - child, - metadata, - validator_map, - iteration, - abs_child_path, - ref_child_path, - ) - value[key] = child_value - - # Then validate the parent value - value, metadata = self.run_validators( - iteration, - validator_map, - value, - metadata, - absolute_path, - reference_path, - stream=stream, - **kwargs, - ) - return value, metadata - - def validate_stream( - self, - value_stream: Iterable[Tuple[Any, bool]], - metadata: dict, - validator_map: ValidatorMap, - iteration: Iteration, - absolute_path: str, - reference_path: str, - **kwargs, - ) -> Iterable[StreamValidationResult]: - # I assume validate stream doesn't need validate_dependents - # because right now we're only handling StringSchema - - # Validate the field - gen = self.run_validators_stream( - iteration, - validator_map, - value_stream, - metadata, - absolute_path, - reference_path, - **kwargs, - ) - return gen - - -class MultiprocMixin: - multiprocessing_executor: Optional[ProcessPoolExecutor] = None - process_count = int(os.environ.get("GUARDRAILS_PROCESS_COUNT", 10)) - - def __init__(self): - if MultiprocMixin.multiprocessing_executor is None: - MultiprocMixin.multiprocessing_executor = ProcessPoolExecutor( - max_workers=MultiprocMixin.process_count - ) - - -class AsyncValidatorService(ValidatorServiceBase, MultiprocMixin): - async def run_validator_async( - self, - validator: Validator, - value: Any, - metadata: Dict, - stream: Optional[bool] = False, - *, - validation_session_id: str, - **kwargs, - ) -> ValidationResult: - result: ValidatorResult = self.execute_validator( - validator, - value, - metadata, - stream, - validation_session_id=validation_session_id, - **kwargs, - ) - if asyncio.iscoroutine(result): - result = await result - - if result is None: - result = PassResult() - else: - result = cast(ValidationResult, result) - return result - - async def run_validator( - self, - iteration: Iteration, - validator: Validator, - value: Any, - metadata: Dict, - absolute_property_path: str, - stream: Optional[bool] = False, - **kwargs, - ) -> ValidatorLogs: - validator_logs = self.before_run_validator( - 
iteration, validator, value, absolute_property_path - ) - - result = await self.run_validator_async( - validator, - value, - metadata, - stream, - validation_session_id=iteration.id, - **kwargs, - ) - - return self.after_run_validator(validator, validator_logs, result) - - def group_validators(self, validators: List[Validator]): - groups = itertools.groupby( - validators, key=lambda v: (v.on_fail_descriptor, v.override_value_on_pass) - ) - # NOTE: This isn't ordering anything. - # If we want to yield fix-like valiators first, - # then we need to extract them outside of the loop. - for (on_fail_descriptor, override_on_pass), group in groups: - if override_on_pass or on_fail_descriptor in [ - OnFailAction.FIX, - OnFailAction.FIX_REASK, - "custom", - ]: - for validator in group: - yield on_fail_descriptor, [validator] - else: - yield on_fail_descriptor, list(group) - - async def run_validators( - self, - iteration: Iteration, - validator_map: ValidatorMap, - value: Any, - metadata: Dict, - absolute_property_path: str, - reference_property_path: str, - stream: Optional[bool] = False, - **kwargs, - ): - loop = asyncio.get_running_loop() - validators = validator_map.get(reference_property_path, []) - for on_fail, validator_group in self.group_validators(validators): - parallel_tasks = [] - validators_logs: List[ValidatorLogs] = [] - for validator in validator_group: - if validator.run_in_separate_process: - # queue the validators to run in a separate process - parallel_tasks.append( - loop.run_in_executor( - self.multiprocessing_executor, - self.run_validator, - iteration, - validator, - value, - metadata, - absolute_property_path, - stream, - ) - ) - else: - # run the validators in the current process - result = await self.run_validator( - iteration, - validator, - value, - metadata, - absolute_property_path, - stream=stream, - **kwargs, - ) - validators_logs.append(result) - - # wait for the parallel tasks to finish - if parallel_tasks: - parallel_results = await asyncio.gather(*parallel_tasks) - awaited_results = [] - for res in parallel_results: - if asyncio.iscoroutine(res): - res = await res - awaited_results.append(res) - validators_logs.extend(awaited_results) - - # process the results, handle failures - fails = [ - logs - for logs in validators_logs - if isinstance(logs.validation_result, FailResult) - ] - if fails: - # NOTE: Ignoring type bc we know it's a FailResult - fail_results: List[FailResult] = [ - logs.validation_result # type: ignore - for logs in fails - ] - rechecked_value = None - validator: Validator = validator_group[0] - if validator.on_fail_descriptor == OnFailAction.FIX_REASK: - fixed_value = fail_results[0].fix_value - rechecked_value = await self.run_validator_async( - validator, - fixed_value, - fail_results[0].metadata or {}, - stream, - validation_session_id=iteration.id, - **kwargs, - ) - value = self.perform_correction( - fail_results, - value, - validator_group[0], - on_fail, - rechecked_value=rechecked_value, - ) - - # handle overrides - if ( - len(validator_group) == 1 - and validator_group[0].override_value_on_pass - and isinstance(validators_logs[0].validation_result, PassResult) - and validators_logs[0].validation_result.value_override - is not PassResult.ValueOverrideSentinel - ): - value = validators_logs[0].validation_result.value_override - - for logs in validators_logs: - logs.value_after_validation = value - - # return early if we have a filter, refrain, or reask - if isinstance(value, (Filter, Refrain, FieldReAsk)): - return value, metadata - - 
return value, metadata - - async def validate_children( - self, - value: Any, - metadata: Dict, - validator_map: ValidatorMap, - iteration: Iteration, - abs_parent_path: str, - ref_parent_path: str, - stream: Optional[bool] = False, - **kwargs, - ): - async def validate_child( - child_value: Any, *, key: Optional[str] = None, index: Optional[int] = None - ): - child_key = key or index - abs_child_path = f"{abs_parent_path}.{child_key}" - ref_child_path = ref_parent_path - if key is not None: - ref_child_path = f"{ref_child_path}.{key}" - elif index is not None: - ref_child_path = f"{ref_child_path}.*" - new_child_value, new_metadata = await self.async_validate( - child_value, - metadata, - validator_map, - iteration, - abs_child_path, - ref_child_path, - stream=stream, - **kwargs, - ) - return child_key, new_child_value, new_metadata - - tasks = [] - if isinstance(value, List): - for index, child in enumerate(value): - tasks.append(validate_child(child, index=index)) - elif isinstance(value, Dict): - for key in value: - child = value.get(key) - tasks.append(validate_child(child, key=key)) - - results = await asyncio.gather(*tasks) - - for key, child_value, child_metadata in results: - value[key] = child_value - # TODO address conflicting metadata entries - metadata = {**metadata, **child_metadata} - - return value, metadata - - async def async_validate( - self, - value: Any, - metadata: dict, - validator_map: ValidatorMap, - iteration: Iteration, - absolute_path: str, - reference_path: str, - stream: Optional[bool] = False, - **kwargs, - ) -> Tuple[Any, dict]: - child_ref_path = reference_path.replace(".*", "") - # Validate children first - if isinstance(value, List) or isinstance(value, Dict): - await self.validate_children( - value, - metadata, - validator_map, - iteration, - absolute_path, - child_ref_path, - stream=stream, - **kwargs, - ) - - # Then validate the parent value - value, metadata = await self.run_validators( - iteration, - validator_map, - value, - metadata, - absolute_path, - reference_path, - stream=stream, - **kwargs, - ) - - return value, metadata - - def validate( - self, - value: Any, - metadata: dict, - validator_map: ValidatorMap, - iteration: Iteration, - absolute_path: str, - reference_path: str, - stream: Optional[bool] = False, - **kwargs, - ) -> Tuple[Any, dict]: - # Run validate_async in an async loop - loop = asyncio.get_event_loop() - if loop.is_running(): - raise RuntimeError( - "Async event loop found, please call `validate_async` instead." 
- ) - value, metadata = loop.run_until_complete( - self.async_validate( - value, - metadata, - validator_map, - iteration, - absolute_path, - reference_path, - stream=stream, - **kwargs, - ) - ) - return value, metadata - - -def validate( - value: Any, - metadata: dict, - validator_map: ValidatorMap, - iteration: Iteration, - disable_tracer: Optional[bool] = True, - path: Optional[str] = None, - **kwargs, -): - if path is None: - path = "$" - - process_count = int(os.environ.get("GUARDRAILS_PROCESS_COUNT", 10)) - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = None - - if process_count == 1: - validator_service = SequentialValidatorService(disable_tracer) - elif loop is not None and not loop.is_running(): - validator_service = AsyncValidatorService(disable_tracer) - else: - validator_service = SequentialValidatorService(disable_tracer) - - return validator_service.validate( - value, metadata, validator_map, iteration, path, path, **kwargs - ) - - -def validate_stream( - value_stream: Iterable[Tuple[Any, bool]], - metadata: dict, - validator_map: ValidatorMap, - iteration: Iteration, - disable_tracer: Optional[bool] = True, - path: Optional[str] = None, - **kwargs, -) -> Iterable[StreamValidationResult]: - if path is None: - path = "$" - sequential_validator_service = SequentialValidatorService(disable_tracer) - gen = sequential_validator_service.validate_stream( - value_stream, metadata, validator_map, iteration, path, path, **kwargs - ) - return gen - - -async def async_validate( - value: Any, - metadata: dict, - validator_map: ValidatorMap, - iteration: Iteration, - disable_tracer: Optional[bool] = True, - path: Optional[str] = None, - stream: Optional[bool] = False, - **kwargs, -) -> Tuple[Any, dict]: - if path is None: - path = "$" - validator_service = AsyncValidatorService(disable_tracer) - return await validator_service.async_validate( - value, metadata, validator_map, iteration, path, path, stream, **kwargs - ) - - -def post_process_validation( - validation_response: Any, - attempt_number: int, - iteration: Iteration, - output_type: OutputTypes, -) -> Any: - validated_response = apply_refrain(validation_response, output_type) - - # Remove all keys that have `Filter` values. 
- validated_response = apply_filters(validated_response) - - trace_validation_result( - validation_logs=iteration.validator_logs, attempt_number=attempt_number - ) - - return validated_response diff --git a/guardrails/validator_service/__init__.py b/guardrails/validator_service/__init__.py new file mode 100644 index 000000000..1ea4ef9c7 --- /dev/null +++ b/guardrails/validator_service/__init__.py @@ -0,0 +1,151 @@ +import asyncio +import os +from typing import Any, Iterable, Optional, Tuple +import warnings + +from guardrails.actions.filter import apply_filters +from guardrails.actions.refrain import apply_refrain +from guardrails.classes.history import Iteration +from guardrails.classes.output_type import OutputTypes +from guardrails.classes.validation.validation_result import ( + StreamValidationResult, +) +from guardrails.types import ValidatorMap +from guardrails.telemetry.legacy_validator_tracing import trace_validation_result +from guardrails.validator_service.async_validator_service import AsyncValidatorService +from guardrails.validator_service.sequential_validator_service import ( + SequentialValidatorService, +) + + +try: + import uvloop # type: ignore +except ImportError: + uvloop = None + + +def should_run_sync(): + process_count = os.environ.get("GUARDRAILS_PROCESS_COUNT") + if process_count is not None: + warnings.warn( + "GUARDRAILS_PROCESS_COUNT is deprecated" + " and will be removed in a future release." + " To force synchronous validation, please use GUARDRAILS_RUN_SYNC instead.", + DeprecationWarning, + ) + process_count = int(process_count) + run_sync = os.environ.get("GUARDRAILS_RUN_SYNC", "false") + bool_values = ["true", "false"] + if run_sync.lower() not in bool_values: + warnings.warn( + f"GUARDRAILS_RUN_SYNC must be one of {bool_values}!" + f" Defaulting to 'false'." + ) + return process_count == 1 or run_sync.lower() == "true" + + +def get_loop() -> asyncio.AbstractEventLoop: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + raise RuntimeError("An event loop is already running.") + + if uvloop is not None: + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + return asyncio.get_event_loop() + + +def validate( + value: Any, + metadata: dict, + validator_map: ValidatorMap, + iteration: Iteration, + disable_tracer: Optional[bool] = True, + path: Optional[str] = None, + **kwargs, +): + if path is None: + path = "$" + + loop = None + if should_run_sync(): + validator_service = SequentialValidatorService(disable_tracer) + else: + try: + loop = get_loop() + validator_service = AsyncValidatorService(disable_tracer) + except RuntimeError: + warnings.warn( + "Could not obtain an event loop." + " Falling back to synchronous validation." + ) + validator_service = SequentialValidatorService(disable_tracer) + + return validator_service.validate( + value, + metadata, + validator_map, + iteration, + path, + path, + loop=loop, # type: ignore It exists when we need it to. 
+ **kwargs, + ) + + +def validate_stream( + value_stream: Iterable[Tuple[Any, bool]], + metadata: dict, + validator_map: ValidatorMap, + iteration: Iteration, + disable_tracer: Optional[bool] = True, + path: Optional[str] = None, + **kwargs, +) -> Iterable[StreamValidationResult]: + if path is None: + path = "$" + sequential_validator_service = SequentialValidatorService(disable_tracer) + gen = sequential_validator_service.validate_stream( + value_stream, metadata, validator_map, iteration, path, path, **kwargs + ) + return gen + + +async def async_validate( + value: Any, + metadata: dict, + validator_map: ValidatorMap, + iteration: Iteration, + disable_tracer: Optional[bool] = True, + path: Optional[str] = None, + stream: Optional[bool] = False, + **kwargs, +) -> Tuple[Any, dict]: + if path is None: + path = "$" + validator_service = AsyncValidatorService(disable_tracer) + return await validator_service.async_validate( + value, metadata, validator_map, iteration, path, path, stream, **kwargs + ) + + +def post_process_validation( + validation_response: Any, + attempt_number: int, + iteration: Iteration, + output_type: OutputTypes, +) -> Any: + validated_response = apply_refrain(validation_response, output_type) + + # Remove all keys that have `Filter` values. + validated_response = apply_filters(validated_response) + + trace_validation_result( + validation_logs=iteration.validator_logs, attempt_number=attempt_number + ) + + return validated_response diff --git a/guardrails/validator_service/async_validator_service.py b/guardrails/validator_service/async_validator_service.py new file mode 100644 index 000000000..323e6c410 --- /dev/null +++ b/guardrails/validator_service/async_validator_service.py @@ -0,0 +1,315 @@ +import asyncio +from typing import Any, Awaitable, Coroutine, Dict, List, Optional, Tuple, Union + +from guardrails.actions.filter import Filter +from guardrails.actions.refrain import Refrain +from guardrails.classes.history import Iteration +from guardrails.classes.validation.validation_result import ( + FailResult, + PassResult, + ValidationResult, +) +from guardrails.telemetry.validator_tracing import trace_async_validator +from guardrails.types import ValidatorMap, OnFailAction +from guardrails.classes.validation.validator_logs import ValidatorLogs +from guardrails.actions.reask import FieldReAsk +from guardrails.validator_base import Validator +from guardrails.validator_service.validator_service_base import ( + ValidatorRun, + ValidatorServiceBase, +) + +ValidatorResult = Optional[Union[ValidationResult, Awaitable[ValidationResult]]] + + +class AsyncValidatorService(ValidatorServiceBase): + async def execute_validator( + self, + validator: Validator, + value: Any, + metadata: Optional[Dict], + stream: Optional[bool] = False, + *, + validation_session_id: str, + **kwargs, + ) -> Optional[ValidationResult]: + validate_func = ( + validator.async_validate_stream if stream else validator.async_validate + ) + traced_validator = trace_async_validator( + validator_name=validator.rail_alias, + obj_id=id(validator), + on_fail_descriptor=validator.on_fail_descriptor, + validation_session_id=validation_session_id, + **validator._kwargs, + )(validate_func) + if stream: + result = await traced_validator(value, metadata, **kwargs) + else: + result = await traced_validator(value, metadata) + return result + + async def run_validator_async( + self, + validator: Validator, + value: Any, + metadata: Dict, + stream: Optional[bool] = False, + *, + validation_session_id: str, + **kwargs, + ) -> 
ValidationResult: + result = await self.execute_validator( + validator, + value, + metadata, + stream, + validation_session_id=validation_session_id, + **kwargs, + ) + + if result is None: + result = PassResult() + return result + + async def run_validator( + self, + iteration: Iteration, + validator: Validator, + value: Any, + metadata: Dict, + absolute_property_path: str, + stream: Optional[bool] = False, + **kwargs, + ) -> ValidatorRun: + validator_logs = self.before_run_validator( + iteration, validator, value, absolute_property_path + ) + + result = await self.run_validator_async( + validator, + value, + metadata, + stream, + validation_session_id=iteration.id, + **kwargs, + ) + + validator_logs = self.after_run_validator(validator, validator_logs, result) + + if isinstance(result, FailResult): + rechecked_value = None + if validator.on_fail_descriptor == OnFailAction.FIX_REASK: + fixed_value = result.fix_value + rechecked_value = await self.run_validator_async( + validator, + fixed_value, + result.metadata or {}, + stream, + validation_session_id=iteration.id, + **kwargs, + ) + value = self.perform_correction( + result, + value, + validator, + rechecked_value=rechecked_value, + ) + + # handle overrides + # QUESTION: Should this consider the rechecked_value as well? + elif ( + isinstance(result, PassResult) + and result.value_override is not PassResult.ValueOverrideSentinel + ): + value = result.value_override + + validator_logs.value_after_validation = value + + return ValidatorRun( + value=value, + metadata=metadata, + on_fail_action=validator.on_fail_descriptor, + validator_logs=validator_logs, + ) + + async def run_validators( + self, + iteration: Iteration, + validator_map: ValidatorMap, + value: Any, + metadata: Dict, + absolute_property_path: str, + reference_property_path: str, + stream: Optional[bool] = False, + **kwargs, + ): + validators = validator_map.get(reference_property_path, []) + coroutines: List[Coroutine[Any, Any, ValidatorRun]] = [] + validators_logs: List[ValidatorLogs] = [] + for validator in validators: + coroutines.append( + self.run_validator( + iteration, + validator, + value, + metadata, + absolute_property_path, + stream=stream, + **kwargs, + ) + ) + + results = await asyncio.gather(*coroutines) + reasks: List[FieldReAsk] = [] + for res in results: + validators_logs.append(res.validator_logs) + # QUESTION: Do we still want to do this here or handle it during the merge? 
+ # return early if we have a filter, refrain, or reask + if isinstance(res.value, (Filter, Refrain)): + return res.value, metadata + elif isinstance(res.value, FieldReAsk): + reasks.append(res.value) + + # handle reasks + if len(reasks) > 0: + first_reask = reasks[0] + fail_results = [] + for reask in reasks: + fail_results.extend(reask.fail_results) + first_reask.fail_results = fail_results + return first_reask, metadata + + # merge the results + fix_values = [ + res.value + for res in results + if ( + isinstance(res.validator_logs.validation_result, FailResult) + and ( + res.on_fail_action == OnFailAction.FIX + or res.on_fail_action == OnFailAction.FIX_REASK + or res.on_fail_action == OnFailAction.CUSTOM + ) + ) + ] + if len(fix_values) > 0: + value = self.merge_results(value, fix_values) + + return value, metadata + + async def validate_children( + self, + value: Any, + metadata: Dict, + validator_map: ValidatorMap, + iteration: Iteration, + abs_parent_path: str, + ref_parent_path: str, + stream: Optional[bool] = False, + **kwargs, + ): + async def validate_child( + child_value: Any, *, key: Optional[str] = None, index: Optional[int] = None + ): + child_key = key or index + abs_child_path = f"{abs_parent_path}.{child_key}" + ref_child_path = ref_parent_path + if key is not None: + ref_child_path = f"{ref_child_path}.{key}" + elif index is not None: + ref_child_path = f"{ref_child_path}.*" + new_child_value, new_metadata = await self.async_validate( + child_value, + metadata, + validator_map, + iteration, + abs_child_path, + ref_child_path, + stream=stream, + **kwargs, + ) + return child_key, new_child_value, new_metadata + + coroutines = [] + if isinstance(value, List): + for index, child in enumerate(value): + coroutines.append(validate_child(child, index=index)) + elif isinstance(value, Dict): + for key in value: + child = value.get(key) + coroutines.append(validate_child(child, key=key)) + + results = await asyncio.gather(*coroutines) + + for key, child_value, child_metadata in results: + value[key] = child_value + # TODO address conflicting metadata entries + metadata = {**metadata, **child_metadata} + + return value, metadata + + async def async_validate( + self, + value: Any, + metadata: dict, + validator_map: ValidatorMap, + iteration: Iteration, + absolute_path: str, + reference_path: str, + stream: Optional[bool] = False, + **kwargs, + ) -> Tuple[Any, dict]: + child_ref_path = reference_path.replace(".*", "") + # Validate children first + if isinstance(value, List) or isinstance(value, Dict): + await self.validate_children( + value, + metadata, + validator_map, + iteration, + absolute_path, + child_ref_path, + stream=stream, + **kwargs, + ) + + # Then validate the parent value + value, metadata = await self.run_validators( + iteration, + validator_map, + value, + metadata, + absolute_path, + reference_path, + stream=stream, + **kwargs, + ) + + return value, metadata + + def validate( + self, + value: Any, + metadata: dict, + validator_map: ValidatorMap, + iteration: Iteration, + absolute_path: str, + reference_path: str, + loop: asyncio.AbstractEventLoop, + stream: Optional[bool] = False, + **kwargs, + ) -> Tuple[Any, dict]: + value, metadata = loop.run_until_complete( + self.async_validate( + value, + metadata, + validator_map, + iteration, + absolute_path, + reference_path, + stream=stream, + **kwargs, + ) + ) + return value, metadata diff --git a/guardrails/validator_service/sequential_validator_service.py b/guardrails/validator_service/sequential_validator_service.py new 
file mode 100644 index 000000000..e86598277 --- /dev/null +++ b/guardrails/validator_service/sequential_validator_service.py @@ -0,0 +1,505 @@ +import asyncio +from typing import Any, Dict, Iterable, List, Optional, Tuple, cast + +from guardrails.actions.filter import Filter +from guardrails.actions.refrain import Refrain +from guardrails.classes.history import Iteration +from guardrails.classes.validation.validation_result import ( + FailResult, + PassResult, + StreamValidationResult, + ValidationResult, +) +from guardrails.merge import merge +from guardrails.types import ValidatorMap, OnFailAction +from guardrails.utils.exception_utils import UserFacingException +from guardrails.classes.validation.validator_logs import ValidatorLogs +from guardrails.actions.reask import ReAsk +from guardrails.validator_base import Validator +from guardrails.validator_service.validator_service_base import ValidatorServiceBase + + +class SequentialValidatorService(ValidatorServiceBase): + def run_validator_sync( + self, + validator: Validator, + value: Any, + metadata: Dict, + validator_logs: ValidatorLogs, + stream: Optional[bool] = False, + *, + validation_session_id: str, + **kwargs, + ) -> Optional[ValidationResult]: + result = self.execute_validator( + validator, + value, + metadata, + stream, + validation_session_id=validation_session_id, + **kwargs, + ) + if asyncio.iscoroutine(result): + raise UserFacingException( + ValueError( + "Cannot use async validators with a synchronous Guard! " + f"Either use AsyncGuard or remove {validator_logs.validator_name}." + ) + ) + if result is None: + return result + return cast(ValidationResult, result) + + def run_validator( + self, + iteration: Iteration, + validator: Validator, + value: Any, + metadata: Dict, + property_path: str, + stream: Optional[bool] = False, + **kwargs, + ) -> ValidatorLogs: + validator_logs = self.before_run_validator( + iteration, validator, value, property_path + ) + + result = self.run_validator_sync( + validator, + value, + metadata, + validator_logs, + stream, + validation_session_id=iteration.id, + **kwargs, + ) + + return self.after_run_validator(validator, validator_logs, result) + + def run_validators_stream( + self, + iteration: Iteration, + validator_map: ValidatorMap, + value_stream: Iterable[Tuple[Any, bool]], + metadata: Dict[str, Any], + absolute_property_path: str, + reference_property_path: str, + **kwargs, + ) -> Iterable[StreamValidationResult]: + validators = validator_map.get(reference_property_path, []) + for validator in validators: + if validator.on_fail_descriptor == OnFailAction.FIX: + return self.run_validators_stream_fix( + iteration, + validator_map, + value_stream, + metadata, + absolute_property_path, + reference_property_path, + **kwargs, + ) + return self.run_validators_stream_noop( + iteration, + validator_map, + value_stream, + metadata, + absolute_property_path, + reference_property_path, + **kwargs, + ) + + # requires at least 2 validators + def multi_merge(self, original: str, new_values: list[str]) -> Optional[str]: + current = new_values.pop() + print("Fmerging these:", new_values) + while len(new_values) > 0: + nextval = new_values.pop() + current = merge(current, nextval, original) + print("\nFmerge result:", current) + return current + + def run_validators_stream_fix( + self, + iteration: Iteration, + validator_map: ValidatorMap, + value_stream: Iterable[Tuple[Any, bool]], + metadata: Dict[str, Any], + absolute_property_path: str, + reference_property_path: str, + **kwargs, + ) -> 
Iterable[StreamValidationResult]: + validators = validator_map.get(reference_property_path, []) + acc_output = "" + validator_partial_acc: dict[int, str] = {} + for validator in validators: + validator_partial_acc[id(validator)] = "" + last_chunk = None + last_chunk_validated = False + last_chunk_missing_validators = [] + refrain_triggered = False + for chunk, finished in value_stream: + original_text = chunk + acc_output += chunk + fixed_values = [] + last_chunk = chunk + last_chunk_missing_validators = [] + if refrain_triggered: + break + for validator in validators: + # reset chunk to original text + chunk = original_text + validator_logs = self.run_validator( + iteration, + validator, + chunk, + metadata, + absolute_property_path, + True, + remainder=finished, + **kwargs, + ) + result = validator_logs.validation_result + if result is None: + last_chunk_missing_validators.append(validator) + result = cast(ValidationResult, result) + # if we have a concrete result, log it in the validation map + if isinstance(result, FailResult): + is_filter = validator.on_fail_descriptor is OnFailAction.FILTER + is_refrain = validator.on_fail_descriptor is OnFailAction.REFRAIN + if is_filter or is_refrain: + refrain_triggered = True + break + rechecked_value = None + chunk = self.perform_correction( + result, + chunk, + validator, + rechecked_value=rechecked_value, + ) + fixed_values.append(chunk) + validator_partial_acc[id(validator)] += chunk # type: ignore + elif isinstance(result, PassResult): + if ( + validator.override_value_on_pass + and result.value_override is not result.ValueOverrideSentinel + ): + chunk = result.value_override + else: + chunk = result.validated_chunk + fixed_values.append(chunk) + validator_partial_acc[id(validator)] += chunk # type: ignore + validator_logs.value_after_validation = chunk + if result and result.metadata is not None: + metadata = result.metadata + + if refrain_triggered: + # if we have a failresult from a refrain/filter validator, yield empty + yield StreamValidationResult( + chunk="", original_text=acc_output, metadata=metadata + ) + else: + # if every validator has yielded a concrete value, merge and yield + # only merge and yield if all validators have run + # TODO: check if only 1 validator - then skip merging + if len(fixed_values) == len(validators): + last_chunk_validated = True + values_to_merge = [] + for validator in validators: + values_to_merge.append(validator_partial_acc[id(validator)]) + merged_value = self.multi_merge(acc_output, values_to_merge) + # merged_value = self.multi_merge(acc_output, values_to_merge) + # reset validator_partial_acc + for validator in validators: + validator_partial_acc[id(validator)] = "" + yield StreamValidationResult( + chunk=merged_value, original_text=acc_output, metadata=metadata + ) + acc_output = "" + else: + last_chunk_validated = False + # handle case where LLM doesn't yield finished flag + # we need to validate remainder of accumulated chunks + if not last_chunk_validated and not refrain_triggered: + original_text = last_chunk + for validator in last_chunk_missing_validators: + last_log = self.run_validator( + iteration, + validator, + # use empty chunk + # validator has already accumulated the chunk from the first loop + "", + metadata, + absolute_property_path, + True, + remainder=True, + **kwargs, + ) + result = last_log.validation_result + if isinstance(result, FailResult): + rechecked_value = None + last_chunk = self.perform_correction( + result, + last_chunk, + validator, + 
rechecked_value=rechecked_value, + ) + validator_partial_acc[id(validator)] += last_chunk # type: ignore + elif isinstance(result, PassResult): + if ( + validator.override_value_on_pass + and result.value_override is not result.ValueOverrideSentinel + ): + last_chunk = result.value_override + else: + last_chunk = result.validated_chunk + validator_partial_acc[id(validator)] += last_chunk # type: ignore + last_log.value_after_validation = last_chunk + if result and result.metadata is not None: + metadata = result.metadata + values_to_merge = [] + for validator in validators: + values_to_merge.append(validator_partial_acc[id(validator)]) + merged_value = self.multi_merge(acc_output, values_to_merge) + yield StreamValidationResult( + chunk=merged_value, + original_text=original_text, # type: ignore + metadata=metadata, # type: ignore + ) + # yield merged value + + def run_validators_stream_noop( + self, + iteration: Iteration, + validator_map: ValidatorMap, + value_stream: Iterable[Tuple[Any, bool]], + metadata: Dict[str, Any], + absolute_property_path: str, + reference_property_path: str, + **kwargs, + ) -> Iterable[StreamValidationResult]: + validators = validator_map.get(reference_property_path, []) + # Validate the field + # TODO: Under what conditions do we yield? + # When we have at least one non-None value? + # When we have all non-None values? + # Does this depend on whether we are fix or not? + for chunk, finished in value_stream: + original_text = chunk + for validator in validators: + validator_logs = self.run_validator( + iteration, + validator, + chunk, + metadata, + absolute_property_path, + True, + **kwargs, + ) + result = validator_logs.validation_result + result = cast(ValidationResult, result) + + if isinstance(result, FailResult): + rechecked_value = None + chunk = self.perform_correction( + result, + chunk, + validator, + rechecked_value=rechecked_value, + ) + elif isinstance(result, PassResult): + if ( + validator.override_value_on_pass + and result.value_override is not result.ValueOverrideSentinel + ): + chunk = result.value_override + + validator_logs.value_after_validation = chunk + if result and result.metadata is not None: + metadata = result.metadata + # # TODO: Filter is no longer terminal, so we shouldn't yield, right? 
+ # if isinstance(chunk, (Refrain, Filter, ReAsk)): + # yield chunk, metadata + yield StreamValidationResult( + chunk=chunk, original_text=original_text, metadata=metadata + ) + + def run_validators( + self, + iteration: Iteration, + validator_map: ValidatorMap, + value: Any, + metadata: Dict[str, Any], + absolute_property_path: str, + reference_property_path: str, + stream: Optional[bool] = False, + **kwargs, + ) -> Tuple[Any, Dict[str, Any]]: + # Validate the field + validators = validator_map.get(reference_property_path, []) + for validator in validators: + if stream: + if validator.on_fail_descriptor is OnFailAction.REASK: + raise ValueError( + """Reask is not supported for stream validation, + only noop and exception are supported.""" + ) + if validator.on_fail_descriptor is OnFailAction.FIX: + raise ValueError( + """Fix is not supported for stream validation, + only noop and exception are supported.""" + ) + if validator.on_fail_descriptor is OnFailAction.FIX_REASK: + raise ValueError( + """Fix reask is not supported for stream validation, + only noop and exception are supported.""" + ) + if validator.on_fail_descriptor is OnFailAction.FILTER: + raise ValueError( + """Filter is not supported for stream validation, + only noop and exception are supported.""" + ) + if validator.on_fail_descriptor is OnFailAction.REFRAIN: + raise ValueError( + """Refrain is not supported for stream validation, + only noop and exception are supported.""" + ) + validator_logs = self.run_validator( + iteration, + validator, + value, + metadata, + absolute_property_path, + stream, + **kwargs, + ) + result = validator_logs.validation_result + + result = cast(ValidationResult, result) + if isinstance(result, FailResult): + rechecked_value = None + if validator.on_fail_descriptor == OnFailAction.FIX_REASK: + fixed_value = result.fix_value + rechecked_value = self.run_validator_sync( + validator, + fixed_value, + metadata, + validator_logs, + stream, + **kwargs, + ) + value = self.perform_correction( + result, + value, + validator, + rechecked_value=rechecked_value, + ) + elif isinstance(result, PassResult): + if ( + validator.override_value_on_pass + and result.value_override is not result.ValueOverrideSentinel + ): + value = result.value_override + elif not stream: + raise RuntimeError(f"Unexpected result type {type(result)}") + + validator_logs.value_after_validation = value + if result and result.metadata is not None: + metadata = result.metadata + + if isinstance(value, (Refrain, Filter, ReAsk)): + return value, metadata + return value, metadata + + def validate( + self, + value: Any, + metadata: dict, + validator_map: ValidatorMap, + iteration: Iteration, + absolute_path: str, + reference_path: str, + stream: Optional[bool] = False, + **kwargs, + ) -> Tuple[Any, dict]: + ### + # NOTE: The way validation can be executed now is fundamentally wide open. + # Since validators are tracked against the JSONPaths for the + # properties they should be applied to, we have the following options: + # 1. Keep performing a Deep-First-Search + # - This is useful for backwards compatibility. + # - Is there something we gain by validating inside out? + # 2. Swith to a Breadth-First-Search + # - Possible, no obvious advantages + # 3. Run un-ordered + # - This would allow for true parallelism + # - Also means we're not unnecessarily iterating down through + # the object if there aren't any validations applied there. 
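For readers skimming the note above, here is a minimal, self-contained sketch of the depth-first, inside-out traversal it describes. This is an editorial illustration only: the helper name `deep_first_validate` and the callable-based `ValidatorMap` are hypothetical and simplified (for example, it does not distinguish absolute paths from `.*` reference paths), and it is not the guardrails implementation.

```python
# Illustrative sketch of deep-first (inside-out) validation over JSONPath-keyed validators.
from typing import Any, Callable, Dict, List

ValidatorMap = Dict[str, List[Callable[[Any], Any]]]


def deep_first_validate(value: Any, validator_map: ValidatorMap, path: str = "$") -> Any:
    # Children are validated before their parent container.
    if isinstance(value, dict):
        for key, child in value.items():
            value[key] = deep_first_validate(child, validator_map, f"{path}.{key}")
    elif isinstance(value, list):
        for index, child in enumerate(value):
            value[index] = deep_first_validate(child, validator_map, f"{path}.{index}")
    # Then the validators registered for this path run on the (possibly fixed) value.
    for validator in validator_map.get(path, []):
        value = validator(value)
    return value


if __name__ == "__main__":
    order: List[str] = []
    vmap: ValidatorMap = {
        "$": [lambda v: (order.append("$"), v)[1]],
        "$.foo": [lambda v: (order.append("$.foo"), v)[1]],
        "$.foo.baz": [lambda v: (order.append("$.foo.baz"), v)[1]],
    }
    deep_first_validate({"foo": {"baz": 1}}, vmap)
    print(order)
```

Running the sketch prints `['$.foo.baz', '$.foo', '$']`, i.e. child properties are validated before their parent container, which is the "deep-first" behavior the options in the note are weighed against.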
+ ### + + child_ref_path = reference_path.replace(".*", "") + # Validate children first + if isinstance(value, List): + for index, child in enumerate(value): + abs_child_path = f"{absolute_path}.{index}" + ref_child_path = f"{child_ref_path}.*" + child_value, metadata = self.validate( + child, + metadata, + validator_map, + iteration, + abs_child_path, + ref_child_path, + ) + value[index] = child_value + elif isinstance(value, Dict): + for key in value: + child = value.get(key) + abs_child_path = f"{absolute_path}.{key}" + ref_child_path = f"{child_ref_path}.{key}" + child_value, metadata = self.validate( + child, + metadata, + validator_map, + iteration, + abs_child_path, + ref_child_path, + ) + value[key] = child_value + + # Then validate the parent value + value, metadata = self.run_validators( + iteration, + validator_map, + value, + metadata, + absolute_path, + reference_path, + stream=stream, + **kwargs, + ) + return value, metadata + + def validate_stream( + self, + value_stream: Iterable[Tuple[Any, bool]], + metadata: dict, + validator_map: ValidatorMap, + iteration: Iteration, + absolute_path: str, + reference_path: str, + **kwargs, + ) -> Iterable[StreamValidationResult]: + # I assume validate stream doesn't need validate_dependents + # because right now we're only handling StringSchema + + # Validate the field + gen = self.run_validators_stream( + iteration, + validator_map, + value_stream, + metadata, + absolute_path, + reference_path, + **kwargs, + ) + return gen diff --git a/guardrails/validator_service/validator_service_base.py b/guardrails/validator_service/validator_service_base.py new file mode 100644 index 000000000..d580a9ff0 --- /dev/null +++ b/guardrails/validator_service/validator_service_base.py @@ -0,0 +1,207 @@ +from copy import deepcopy +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Awaitable, Dict, Optional, Union + +from guardrails.actions.filter import Filter +from guardrails.actions.refrain import Refrain +from guardrails.classes.history import Iteration +from guardrails.classes.validation.validation_result import ( + FailResult, + ValidationResult, +) +from guardrails.errors import ValidationError +from guardrails.merge import merge +from guardrails.types import OnFailAction +from guardrails.utils.hub_telemetry_utils import HubTelemetry +from guardrails.classes.validation.validator_logs import ValidatorLogs +from guardrails.actions.reask import FieldReAsk +from guardrails.telemetry import trace_validator +from guardrails.utils.serialization_utils import deserialize, serialize +from guardrails.validator_base import Validator + +ValidatorResult = Optional[Union[ValidationResult, Awaitable[ValidationResult]]] + + +@dataclass +class ValidatorRun: + value: Any + metadata: Dict + on_fail_action: Union[str, OnFailAction] + validator_logs: ValidatorLogs + + +class ValidatorServiceBase: + """Base class for validator services.""" + + def __init__(self, disable_tracer: Optional[bool] = True): + self._disable_tracer = disable_tracer + + # NOTE: This is avoiding an issue with multiprocessing. + # If we wrap the validate methods at the class level or anytime before + # loop.run_in_executor is called, multiprocessing fails with a Pickling error. + # This is a well known issue without any real solutions. + # Using `fork` instead of `spawn` may alleviate the symptom for POSIX systems, + # but is relatively unsupported on Windows. 
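The pickling note above concerns dispatching a blocking, synchronous `validate`-style call through an event-loop executor. Below is a hedged, standalone sketch of that pattern, not the guardrails implementation; `slow_sync_validate` and `run_in_default_executor` are made-up names used only for illustration.

```python
# Illustrative only: pushing a blocking call onto an event-loop executor.
import asyncio


def slow_sync_validate(value: str) -> str:
    # Stand-in for a blocking, synchronous validator's validate() method.
    return value.lower()


async def run_in_default_executor(value: str) -> str:
    loop = asyncio.get_running_loop()
    # Passing None uses the loop's default ThreadPoolExecutor. A
    # ProcessPoolExecutor could be supplied instead, but then the callable
    # must be picklable -- which is why wrapping/decorating it ahead of time
    # (as the note above explains) can break under multiprocessing.
    return await loop.run_in_executor(None, slow_sync_validate, value)


if __name__ == "__main__":
    print(asyncio.run(run_in_default_executor("Hello World")))  # -> "hello world"
```

Thread-based executors sidestep the pickling constraint entirely, at the cost of the GIL for CPU-bound work, which is one reason deferring any wrapping until call time keeps both executor choices viable.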
+ def execute_validator( + self, + validator: Validator, + value: Any, + metadata: Optional[Dict], + stream: Optional[bool] = False, + *, + validation_session_id: str, + **kwargs, + # TODO: Make this just Optional[ValidationResult] + # Also maybe move to SequentialValidatorService + ) -> ValidatorResult: + validate_func = validator.validate_stream if stream else validator.validate + traced_validator = trace_validator( + validator_name=validator.rail_alias, + obj_id=id(validator), + on_fail_descriptor=validator.on_fail_descriptor, + validation_session_id=validation_session_id, + **validator._kwargs, + )(validate_func) + if stream: + result = traced_validator(value, metadata, **kwargs) + else: + result = traced_validator(value, metadata) + return result + + def perform_correction( + self, + result: FailResult, + value: Any, + validator: Validator, + rechecked_value: Optional[ValidationResult] = None, + ): + on_fail_descriptor = validator.on_fail_descriptor + if on_fail_descriptor == OnFailAction.FIX: + # FIXME: Should we still return fix_value if it is None? + # I think we should warn and return the original value. + return result.fix_value + elif on_fail_descriptor == OnFailAction.FIX_REASK: + # FIXME: Same thing here + fixed_value = result.fix_value + + if isinstance(rechecked_value, FailResult): + return FieldReAsk( + incorrect_value=fixed_value, + fail_results=[result], + ) + + return fixed_value + if on_fail_descriptor == OnFailAction.CUSTOM: + if validator.on_fail_method is None: + raise ValueError("on_fail is 'custom' but on_fail_method is None") + return validator.on_fail_method(value, result) + if on_fail_descriptor == OnFailAction.REASK: + return FieldReAsk( + incorrect_value=value, + fail_results=[result], + ) + if on_fail_descriptor == OnFailAction.EXCEPTION: + raise ValidationError( + "Validation failed for field with errors: " + + ", ".join([result.error_message]) + ) + if on_fail_descriptor == OnFailAction.FILTER: + return Filter() + if on_fail_descriptor == OnFailAction.REFRAIN: + return Refrain() + if on_fail_descriptor == OnFailAction.NOOP: + return value + else: + raise ValueError( + f"Invalid on_fail_descriptor {on_fail_descriptor}, " + f"expected 'fix' or 'exception'." + ) + + def before_run_validator( + self, + iteration: Iteration, + validator: Validator, + value: Any, + absolute_property_path: str, + ) -> ValidatorLogs: + validator_class_name = validator.__class__.__name__ + validator_logs = ValidatorLogs( + validator_name=validator_class_name, + value_before_validation=value, + registered_name=validator.rail_alias, + property_path=absolute_property_path, + # If we ever re-use validator instances across multiple properties, + # this will have to change. 
+ instance_id=id(validator), + ) + iteration.outputs.validator_logs.append(validator_logs) + + start_time = datetime.now() + validator_logs.start_time = start_time + + return validator_logs + + def after_run_validator( + self, + validator: Validator, + validator_logs: ValidatorLogs, + result: Optional[ValidationResult], + ) -> ValidatorLogs: + end_time = datetime.now() + validator_logs.validation_result = result + validator_logs.end_time = end_time + + if not self._disable_tracer: + # Get HubTelemetry singleton and create a new span to + # log the validator usage + _hub_telemetry = HubTelemetry() + _hub_telemetry.create_new_span( + span_name="/validator_usage", + attributes=[ + ("validator_name", validator.rail_alias), + ("validator_on_fail", validator.on_fail_descriptor), + ( + "validator_result", + result.outcome + if isinstance(result, ValidationResult) + else None, + ), + ], + is_parent=False, # This span will have no children + has_parent=True, # This span has a parent + ) + + return validator_logs + + def run_validator( + self, + iteration: Iteration, + validator: Validator, + value: Any, + metadata: Dict, + absolute_property_path: str, + stream: Optional[bool] = False, + **kwargs, + ) -> ValidatorRun: + raise NotImplementedError + + def merge_results(self, original_value: Any, new_values: list[Any]) -> Any: + new_vals = deepcopy(new_values) + current = new_values.pop() + while len(new_values) > 0: + nextval = new_values.pop() + current = merge( + serialize(current), serialize(nextval), serialize(original_value) + ) + current = deserialize(original_value, current) + if current is None and original_value is not None: + # QUESTION: How do we escape hatch + # for when deserializing the merged value fails? + + # Should we return the original value? + # return original_value + + # Or just pick one of the new values? 
+ return new_vals[0] + return current diff --git a/poetry.lock b/poetry.lock index e176aef6f..24b48dec6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7871,6 +7871,50 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "uvloop" +version = "0.20.0" +description = "Fast implementation of asyncio event loop on top of libuv" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "uvloop-0.20.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:9ebafa0b96c62881d5cafa02d9da2e44c23f9f0cd829f3a32a6aff771449c996"}, + {file = "uvloop-0.20.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:35968fc697b0527a06e134999eef859b4034b37aebca537daeb598b9d45a137b"}, + {file = "uvloop-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b16696f10e59d7580979b420eedf6650010a4a9c3bd8113f24a103dfdb770b10"}, + {file = "uvloop-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b04d96188d365151d1af41fa2d23257b674e7ead68cfd61c725a422764062ae"}, + {file = "uvloop-0.20.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:94707205efbe809dfa3a0d09c08bef1352f5d3d6612a506f10a319933757c006"}, + {file = "uvloop-0.20.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:89e8d33bb88d7263f74dc57d69f0063e06b5a5ce50bb9a6b32f5fcbe655f9e73"}, + {file = "uvloop-0.20.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e50289c101495e0d1bb0bfcb4a60adde56e32f4449a67216a1ab2750aa84f037"}, + {file = "uvloop-0.20.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e237f9c1e8a00e7d9ddaa288e535dc337a39bcbf679f290aee9d26df9e72bce9"}, + {file = "uvloop-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:746242cd703dc2b37f9d8b9f173749c15e9a918ddb021575a0205ec29a38d31e"}, + {file = "uvloop-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82edbfd3df39fb3d108fc079ebc461330f7c2e33dbd002d146bf7c445ba6e756"}, + {file = "uvloop-0.20.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:80dc1b139516be2077b3e57ce1cb65bfed09149e1d175e0478e7a987863b68f0"}, + {file = "uvloop-0.20.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4f44af67bf39af25db4c1ac27e82e9665717f9c26af2369c404be865c8818dcf"}, + {file = "uvloop-0.20.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4b75f2950ddb6feed85336412b9a0c310a2edbcf4cf931aa5cfe29034829676d"}, + {file = "uvloop-0.20.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:77fbc69c287596880ecec2d4c7a62346bef08b6209749bf6ce8c22bbaca0239e"}, + {file = "uvloop-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6462c95f48e2d8d4c993a2950cd3d31ab061864d1c226bbf0ee2f1a8f36674b9"}, + {file = "uvloop-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:649c33034979273fa71aa25d0fe120ad1777c551d8c4cd2c0c9851d88fcb13ab"}, + {file = "uvloop-0.20.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3a609780e942d43a275a617c0839d85f95c334bad29c4c0918252085113285b5"}, + {file = "uvloop-0.20.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aea15c78e0d9ad6555ed201344ae36db5c63d428818b4b2a42842b3870127c00"}, + {file = "uvloop-0.20.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f0e94b221295b5e69de57a1bd4aeb0b3a29f61be6e1b478bb8a69a73377db7ba"}, + {file = "uvloop-0.20.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fee6044b64c965c425b65a4e17719953b96e065c5b7e09b599ff332bb2744bdf"}, + {file = 
"uvloop-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:265a99a2ff41a0fd56c19c3838b29bf54d1d177964c300dad388b27e84fd7847"}, + {file = "uvloop-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b10c2956efcecb981bf9cfb8184d27d5d64b9033f917115a960b83f11bfa0d6b"}, + {file = "uvloop-0.20.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e7d61fe8e8d9335fac1bf8d5d82820b4808dd7a43020c149b63a1ada953d48a6"}, + {file = "uvloop-0.20.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2beee18efd33fa6fdb0976e18475a4042cd31c7433c866e8a09ab604c7c22ff2"}, + {file = "uvloop-0.20.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d8c36fdf3e02cec92aed2d44f63565ad1522a499c654f07935c8f9d04db69e95"}, + {file = "uvloop-0.20.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a0fac7be202596c7126146660725157d4813aa29a4cc990fe51346f75ff8fde7"}, + {file = "uvloop-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d0fba61846f294bce41eb44d60d58136090ea2b5b99efd21cbdf4e21927c56a"}, + {file = "uvloop-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95720bae002ac357202e0d866128eb1ac82545bcf0b549b9abe91b5178d9b541"}, + {file = "uvloop-0.20.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:36c530d8fa03bfa7085af54a48f2ca16ab74df3ec7108a46ba82fd8b411a2315"}, + {file = "uvloop-0.20.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e97152983442b499d7a71e44f29baa75b3b02e65d9c44ba53b10338e98dedb66"}, + {file = "uvloop-0.20.0.tar.gz", hash = "sha256:4603ca714a754fc8d9b197e325db25b2ea045385e8a3ad05d3463de725fdf469"}, +] + +[package.extras] +docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] +test = ["Cython (>=0.29.36,<0.30.0)", "aiohttp (==3.9.0b0)", "aiohttp (>=3.8.1)", "flake8 (>=5.0,<6.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=23.0.0,<23.1.0)", "pycodestyle (>=2.9.0,<2.10.0)"] + [[package]] name = "virtualenv" version = "20.26.2" @@ -8369,9 +8413,10 @@ docs-build = ["docspec_python", "nbdoc", "pydoc-markdown"] huggingface = ["jsonformer", "torch", "transformers"] manifest = ["manifest-ml"] sql = ["sqlalchemy", "sqlglot", "sqlvalidator"] +uv = ["uvloop"] vectordb = ["faiss-cpu", "numpy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "aec41326aef66af046ce16d49c036fec48698032995f3f49df634b9da411caf7" +content-hash = "6253610141bb5686330057ae658550f9257aabe83ee7b279b783a7f4418a26a6" diff --git a/pyproject.toml b/pyproject.toml index 3020192d9..c10cf2a66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ guardrails-api-client = ">=0.3.8" diff-match-patch = "^20230430" guardrails-api = ">=0.0.1" mlflow = {version = ">=2.0.1", optional = true} +uvloop = {version = "^0.20.0", optional = true} [tool.poetry.extras] sql = ["sqlvalidator", "sqlalchemy", "sqlglot"] @@ -70,6 +71,7 @@ docs-build = ["nbdoc", "docspec_python", "pydoc-markdown"] huggingface = ["transformers", "torch", "jsonformer"] api = ["guardrails-api"] databricks = ["mlflow"] +uv = ["uvloop"] [tool.poetry.group.dev.dependencies] @@ -105,6 +107,7 @@ pillow = "^10.1.0" cairosvg = "^2.7.1" mkdocs-glightbox = "^0.3.4" + [[tool.poetry.source]] name = "PyPI" diff --git a/tests/conftest.py b/tests/conftest.py index fb98d40f2..5db54f65b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,7 @@ def mock_span(): def mock_guard_hub_telemetry(): with patch("guardrails.guard.HubTelemetry") as MockHubTelemetry: 
MockHubTelemetry.return_value = MagicMock() + MockHubTelemetry.return_value.to_dict = None yield MockHubTelemetry @@ -31,13 +32,17 @@ def mock_guard_hub_telemetry(): def mock_validator_base_hub_telemetry(): with patch("guardrails.validator_base.HubTelemetry") as MockHubTelemetry: MockHubTelemetry.return_value = MagicMock() + MockHubTelemetry.return_value.to_dict = None yield MockHubTelemetry @pytest.fixture(autouse=True) def mock_validator_service_hub_telemetry(): - with patch("guardrails.validator_service.HubTelemetry") as MockHubTelemetry: + with patch( + "guardrails.validator_service.validator_service_base.HubTelemetry" + ) as MockHubTelemetry: MockHubTelemetry.return_value = MagicMock() + MockHubTelemetry.return_value.to_dict = None yield MockHubTelemetry @@ -45,6 +50,7 @@ def mock_validator_service_hub_telemetry(): def mock_runner_hub_telemetry(): with patch("guardrails.run.runner.HubTelemetry") as MockHubTelemetry: MockHubTelemetry.return_value = MagicMock() + MockHubTelemetry.return_value.to_dict = None yield MockHubTelemetry diff --git a/tests/integration_tests/test_assets/python_rail/validator_parallelism_prompt_2.txt b/tests/integration_tests/test_assets/python_rail/validator_parallelism_prompt_2.txt index a91f6ecbf..63fb66348 100644 --- a/tests/integration_tests/test_assets/python_rail/validator_parallelism_prompt_2.txt +++ b/tests/integration_tests/test_assets/python_rail/validator_parallelism_prompt_2.txt @@ -9,6 +9,7 @@ Generate a new response that corrects your old response such that the following - must be exactly two words - Value Hello a you and me is not lower case. +- Value has length greater than 10. Please return a shorter output, that is shorter than 10 characters. diff --git a/tests/integration_tests/test_assets/python_rail/validator_parallelism_reask_1.py b/tests/integration_tests/test_assets/python_rail/validator_parallelism_reask_1.py index 5c17e7188..463fe95c7 100644 --- a/tests/integration_tests/test_assets/python_rail/validator_parallelism_reask_1.py +++ b/tests/integration_tests/test_assets/python_rail/validator_parallelism_reask_1.py @@ -14,5 +14,10 @@ error_message="Value Hello a you\nand me is not lower case.", fix_value="hello a you\nand me", ), + FailResult( + outcome="fail", + error_message="Value has length greater than 10. 
Please return a shorter output, that is shorter than 10 characters.", # noqa: E501 + fix_value="Hello a yo", + ), ], ) diff --git a/tests/integration_tests/validator_service/test_async_validator_service_it.py b/tests/integration_tests/validator_service/test_async_validator_service_it.py new file mode 100644 index 000000000..643105e8b --- /dev/null +++ b/tests/integration_tests/validator_service/test_async_validator_service_it.py @@ -0,0 +1,275 @@ +import asyncio +import pytest +from time import sleep +from guardrails.validator_base import Validator, register_validator +from guardrails.classes.validation.validation_result import PassResult + + +@register_validator(name="test/validator1", data_type="string") +class Validator1(Validator): + def validate(self, value, metadata): + # This seems more realistic but is unreliable + # counter = 0 + # for i in range(100000000): + # counter += 1 + # This seems suspicious, but is consistent + sleep(0.3) + metadata["order"].append("test/validator1") + return PassResult() + + +@register_validator(name="test/validator2", data_type="string") +class Validator2(Validator): + def validate(self, value, metadata): + # counter = 0 + # for i in range(1): + # counter += 1 + sleep(0.1) + metadata["order"].append("test/validator2") + return PassResult() + + +@register_validator(name="test/validator3", data_type="string") +class Validator3(Validator): + def validate(self, value, metadata): + # counter = 0 + # for i in range(100000): + # counter += 1 + sleep(0.2) + metadata["order"].append("test/validator3") + return PassResult() + + +@register_validator(name="test/async_validator1", data_type="string") +class AsyncValidator1(Validator): + async def async_validate(self, value, metadata): + await asyncio.sleep(0.3) + metadata["order"].append("test/async_validator1") + return PassResult() + + +@register_validator(name="test/async_validator2", data_type="string") +class AsyncValidator2(Validator): + async def async_validate(self, value, metadata): + await asyncio.sleep(0.1) + metadata["order"].append("test/async_validator2") + return PassResult() + + +@register_validator(name="test/async_validator3", data_type="string") +class AsyncValidator3(Validator): + async def async_validate(self, value, metadata): + await asyncio.sleep(0.2) + metadata["order"].append("test/async_validator3") + return PassResult() + + +class TestValidatorConcurrency: + @pytest.mark.asyncio + async def test_async_validate_with_sync_validators(self): + from guardrails.validator_service import AsyncValidatorService + from guardrails.classes.history import Iteration + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + async_validator_service = AsyncValidatorService() + + value, metadata = await async_validator_service.async_validate( + value="value", + metadata={"order": []}, + validator_map={ + "$": [ + # Note the order + Validator1(), + Validator2(), + Validator3(), + ] + }, + iteration=iteration, + absolute_path="$", + reference_path="$", + ) + + assert value == "value" + assert metadata == { + "order": ["test/validator2", "test/validator3", "test/validator1"] + } + + def test_validate_with_sync_validators(self): + from guardrails.validator_service import AsyncValidatorService + from guardrails.classes.history import Iteration + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + async_validator_service = AsyncValidatorService() + + loop = asyncio.get_event_loop() + value, metadata = async_validator_service.validate( + value="value", + metadata={"order": []}, + 
validator_map={ + "$": [ + # Note the order + Validator1(), + Validator2(), + Validator3(), + ] + }, + iteration=iteration, + absolute_path="$", + reference_path="$", + loop=loop, + ) + + assert value == "value" + assert metadata == { + "order": ["test/validator2", "test/validator3", "test/validator1"] + } + + @pytest.mark.asyncio + async def test_async_validate_with_async_validators(self): + from guardrails.validator_service import AsyncValidatorService + from guardrails.classes.history import Iteration + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + async_validator_service = AsyncValidatorService() + + value, metadata = await async_validator_service.async_validate( + value="value", + metadata={"order": []}, + validator_map={ + "$": [ + # Note the order + AsyncValidator1(), + AsyncValidator2(), + AsyncValidator3(), + ] + }, + iteration=iteration, + absolute_path="$", + reference_path="$", + ) + + assert value == "value" + assert metadata == { + "order": [ + "test/async_validator2", + "test/async_validator3", + "test/async_validator1", + ] + } + + def test_validate_with_async_validators(self): + from guardrails.validator_service import AsyncValidatorService + from guardrails.classes.history import Iteration + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + async_validator_service = AsyncValidatorService() + + loop = asyncio.get_event_loop() + value, metadata = async_validator_service.validate( + value="value", + metadata={"order": []}, + validator_map={ + "$": [ + # Note the order + AsyncValidator1(), + AsyncValidator2(), + AsyncValidator3(), + ] + }, + iteration=iteration, + absolute_path="$", + reference_path="$", + loop=loop, + ) + + assert value == "value" + assert metadata == { + "order": [ + "test/async_validator2", + "test/async_validator3", + "test/async_validator1", + ] + } + + @pytest.mark.asyncio + async def test_async_validate_with_mixed_validators(self): + from guardrails.validator_service import AsyncValidatorService + from guardrails.classes.history import Iteration + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + async_validator_service = AsyncValidatorService() + + value, metadata = await async_validator_service.async_validate( + value="value", + metadata={"order": []}, + validator_map={ + "$": [ + # Note the order + Validator1(), + Validator2(), + AsyncValidator3(), + ] + }, + iteration=iteration, + absolute_path="$", + reference_path="$", + ) + + assert value == "value" + assert metadata == { + "order": ["test/validator2", "test/async_validator3", "test/validator1"] + } + + def test_validate_with_mixed_validators(self): + from guardrails.validator_service import AsyncValidatorService + from guardrails.classes.history import Iteration + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + async_validator_service = AsyncValidatorService() + + loop = asyncio.get_event_loop() + value, metadata = async_validator_service.validate( + value="value", + metadata={"order": []}, + validator_map={ + "$": [ + # Note the order + Validator1(), + Validator2(), + AsyncValidator3(), + ] + }, + iteration=iteration, + absolute_path="$", + reference_path="$", + loop=loop, + ) + + assert value == "value" + assert metadata == { + "order": ["test/validator2", "test/async_validator3", "test/validator1"] + } diff --git a/tests/integration_tests/validator_service/test_init.py b/tests/integration_tests/validator_service/test_init.py new file mode 100644 index 000000000..cd08ba774 --- /dev/null +++ 
b/tests/integration_tests/validator_service/test_init.py @@ -0,0 +1,299 @@ +from asyncio import get_event_loop +from asyncio.unix_events import _UnixSelectorEventLoop +import os +import pytest + +from guardrails.validator_service import should_run_sync, get_loop +from guardrails.classes.history import Iteration + + +try: + import uvloop +except ImportError: + uvloop = None + + +class TestShouldRunSync: + def test_process_count_is_one(self): + GUARDRAILS_PROCESS_COUNT_bak = os.environ.get("GUARDRAILS_PROCESS_COUNT") + GUARDRAILS_RUN_SYNC_bak = os.environ.get("GUARDRAILS_RUN_SYNC") + os.environ["GUARDRAILS_PROCESS_COUNT"] = "1" + if os.environ.get("GUARDRAILS_RUN_SYNC"): + del os.environ["GUARDRAILS_RUN_SYNC"] + + with pytest.warns( + DeprecationWarning, + match=( + "GUARDRAILS_PROCESS_COUNT is deprecated" + " and will be removed in a future release." + " To force synchronous validation," + " please use GUARDRAILS_RUN_SYNC instead." + ), + ): + result = should_run_sync() + assert result is True + + if GUARDRAILS_PROCESS_COUNT_bak is not None: + os.environ["GUARDRAILS_PROCESS_COUNT"] = GUARDRAILS_PROCESS_COUNT_bak + else: + del os.environ["GUARDRAILS_PROCESS_COUNT"] + if GUARDRAILS_RUN_SYNC_bak is not None: + os.environ["GUARDRAILS_RUN_SYNC"] = GUARDRAILS_RUN_SYNC_bak + + def test_process_count_is_2(self): + GUARDRAILS_PROCESS_COUNT_bak = os.environ.get("GUARDRAILS_PROCESS_COUNT") + GUARDRAILS_RUN_SYNC_bak = os.environ.get("GUARDRAILS_RUN_SYNC") + os.environ["GUARDRAILS_PROCESS_COUNT"] = "2" + if os.environ.get("GUARDRAILS_RUN_SYNC"): + del os.environ["GUARDRAILS_RUN_SYNC"] + + with pytest.warns( + DeprecationWarning, + match=( + "GUARDRAILS_PROCESS_COUNT is deprecated" + " and will be removed in a future release." + " To force synchronous validation," + " please use GUARDRAILS_RUN_SYNC instead." 
+ ), + ): + result = should_run_sync() + assert result is False + + if GUARDRAILS_PROCESS_COUNT_bak is not None: + os.environ["GUARDRAILS_PROCESS_COUNT"] = GUARDRAILS_PROCESS_COUNT_bak + else: + del os.environ["GUARDRAILS_PROCESS_COUNT"] + if GUARDRAILS_RUN_SYNC_bak is not None: + os.environ["GUARDRAILS_RUN_SYNC"] = GUARDRAILS_RUN_SYNC_bak + + def test_guardrails_run_sync_is_true(self): + GUARDRAILS_PROCESS_COUNT_bak = os.environ.get("GUARDRAILS_PROCESS_COUNT") + GUARDRAILS_RUN_SYNC_bak = os.environ.get("GUARDRAILS_RUN_SYNC") + os.environ["GUARDRAILS_RUN_SYNC"] = "true" + if os.environ.get("GUARDRAILS_PROCESS_COUNT"): + del os.environ["GUARDRAILS_PROCESS_COUNT"] + + result = should_run_sync() + assert result is True + + if GUARDRAILS_PROCESS_COUNT_bak is not None: + os.environ["GUARDRAILS_PROCESS_COUNT"] = GUARDRAILS_PROCESS_COUNT_bak + if GUARDRAILS_RUN_SYNC_bak is not None: + os.environ["GUARDRAILS_RUN_SYNC"] = GUARDRAILS_RUN_SYNC_bak + else: + del os.environ["GUARDRAILS_RUN_SYNC"] + + def test_guardrails_run_sync_is_false(self): + GUARDRAILS_PROCESS_COUNT_bak = os.environ.get("GUARDRAILS_PROCESS_COUNT") + GUARDRAILS_RUN_SYNC_bak = os.environ.get("GUARDRAILS_RUN_SYNC") + os.environ["GUARDRAILS_RUN_SYNC"] = "false" + if os.environ.get("GUARDRAILS_PROCESS_COUNT"): + del os.environ["GUARDRAILS_PROCESS_COUNT"] + + result = should_run_sync() + assert result is False + + if GUARDRAILS_PROCESS_COUNT_bak is not None: + os.environ["GUARDRAILS_PROCESS_COUNT"] = GUARDRAILS_PROCESS_COUNT_bak + if GUARDRAILS_RUN_SYNC_bak is not None: + os.environ["GUARDRAILS_RUN_SYNC"] = GUARDRAILS_RUN_SYNC_bak + else: + del os.environ["GUARDRAILS_RUN_SYNC"] + + def test_process_count_is_1_and_guardrails_run_sync_is_false(self): + GUARDRAILS_PROCESS_COUNT_bak = os.environ.get("GUARDRAILS_PROCESS_COUNT") + GUARDRAILS_RUN_SYNC_bak = os.environ.get("GUARDRAILS_RUN_SYNC") + os.environ["GUARDRAILS_PROCESS_COUNT"] = "1" + os.environ["GUARDRAILS_RUN_SYNC"] = "false" + + with pytest.warns( + DeprecationWarning, + match=( + "GUARDRAILS_PROCESS_COUNT is deprecated" + " and will be removed in a future release." + " To force synchronous validation," + " please use GUARDRAILS_RUN_SYNC instead." + ), + ): + result = should_run_sync() + assert result is True + + if GUARDRAILS_PROCESS_COUNT_bak is not None: + os.environ["GUARDRAILS_PROCESS_COUNT"] = GUARDRAILS_PROCESS_COUNT_bak + else: + del os.environ["GUARDRAILS_PROCESS_COUNT"] + if GUARDRAILS_RUN_SYNC_bak is not None: + os.environ["GUARDRAILS_RUN_SYNC"] = GUARDRAILS_RUN_SYNC_bak + else: + del os.environ["GUARDRAILS_RUN_SYNC"] + + def test_process_count_is_2_and_guardrails_run_sync_is_true(self): + GUARDRAILS_PROCESS_COUNT_bak = os.environ.get("GUARDRAILS_PROCESS_COUNT") + GUARDRAILS_RUN_SYNC_bak = os.environ.get("GUARDRAILS_RUN_SYNC") + os.environ["GUARDRAILS_PROCESS_COUNT"] = "2" + os.environ["GUARDRAILS_RUN_SYNC"] = "true" + + with pytest.warns( + DeprecationWarning, + match=( + "GUARDRAILS_PROCESS_COUNT is deprecated" + " and will be removed in a future release." + " To force synchronous validation," + " please use GUARDRAILS_RUN_SYNC instead." 
+ ), + ): + result = should_run_sync() + assert result is True + + if GUARDRAILS_PROCESS_COUNT_bak is not None: + os.environ["GUARDRAILS_PROCESS_COUNT"] = GUARDRAILS_PROCESS_COUNT_bak + else: + del os.environ["GUARDRAILS_PROCESS_COUNT"] + if GUARDRAILS_RUN_SYNC_bak is not None: + os.environ["GUARDRAILS_RUN_SYNC"] = GUARDRAILS_RUN_SYNC_bak + else: + del os.environ["GUARDRAILS_RUN_SYNC"] + + +class TestGetLoop: + def test_raises_if_loop_is_running(self): + loop = get_event_loop() + + async def callback(): + # NOTE: This means only AsyncGuard will parallelize validators + # if it's called within an async function. + with pytest.raises(RuntimeError, match="An event loop is already running."): + get_loop() + + loop.run_until_complete(callback()) + + @pytest.mark.skipif(uvloop is None, reason="uvloop is not installed") + def test_uvloop_is_used_when_installed(self): + loop = get_loop() + assert isinstance(loop, uvloop.Loop) + + @pytest.mark.skipif(uvloop is not None, reason="uvloop is installed") + def test_asyncio_default_is_used_otherwise(self): + loop = get_loop() + assert isinstance(loop, _UnixSelectorEventLoop) + + +class TestValidate: + def test_forced_sync(self, mocker): + GUARDRAILS_PROCESS_COUNT_bak = os.environ.get("GUARDRAILS_PROCESS_COUNT") + GUARDRAILS_RUN_SYNC_bak = os.environ.get("GUARDRAILS_RUN_SYNC") + os.environ["GUARDRAILS_RUN_SYNC"] = "true" + if os.environ.get("GUARDRAILS_PROCESS_COUNT"): + del os.environ["GUARDRAILS_PROCESS_COUNT"] + + from guardrails.validator_service import validate, SequentialValidatorService + + mocker.spy(SequentialValidatorService, "__init__") + mocker.spy(SequentialValidatorService, "validate") + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + value, metadata = validate( + value="value", + metadata={}, + validator_map={}, + iteration=iteration, + ) + + assert value == "value" + assert metadata == {} + SequentialValidatorService.__init__.assert_called_once() + SequentialValidatorService.validate.assert_called_once() + + if GUARDRAILS_PROCESS_COUNT_bak is not None: + os.environ["GUARDRAILS_PROCESS_COUNT"] = GUARDRAILS_PROCESS_COUNT_bak + if GUARDRAILS_RUN_SYNC_bak is not None: + os.environ["GUARDRAILS_RUN_SYNC"] = GUARDRAILS_RUN_SYNC_bak + else: + del os.environ["GUARDRAILS_RUN_SYNC"] + + def test_async(self, mocker): + from guardrails.validator_service import validate, AsyncValidatorService + + mocker.spy(AsyncValidatorService, "__init__") + mocker.spy(AsyncValidatorService, "validate") + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + value, metadata = validate( + value="value", + metadata={}, + validator_map={}, + iteration=iteration, + ) + + assert value == "value" + assert metadata == {} + AsyncValidatorService.__init__.assert_called_once() + AsyncValidatorService.validate.assert_called_once() + + def test_sync_busy_loop(self, mocker): + from guardrails.validator_service import validate, SequentialValidatorService + + mocker.spy(SequentialValidatorService, "__init__") + mocker.spy(SequentialValidatorService, "validate") + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + loop = get_event_loop() + + async def callback(): + with pytest.warns( + Warning, + match=( + "Could not obtain an event loop." + " Falling back to synchronous validation." 
+ ), + ): + value, metadata = validate( + value="value", + metadata={}, + validator_map={}, + iteration=iteration, + ) + assert value == "value" + assert metadata == {} + + loop.run_until_complete(callback()) + + SequentialValidatorService.__init__.assert_called_once() + SequentialValidatorService.validate.assert_called_once() + + +@pytest.mark.asyncio +async def test_async_validate(mocker): + from guardrails.validator_service import async_validate, AsyncValidatorService + + mocker.spy(AsyncValidatorService, "__init__") + mocker.spy(AsyncValidatorService, "async_validate") + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + value, metadata = await async_validate( + value="value", + metadata={}, + validator_map={}, + iteration=iteration, + ) + + assert value == "value" + assert metadata == {} + AsyncValidatorService.__init__.assert_called_once() + AsyncValidatorService.async_validate.assert_called_once() diff --git a/tests/unit_tests/test_async_guard.py b/tests/unit_tests/test_async_guard.py index 78f63dd3c..d8b331b8d 100644 --- a/tests/unit_tests/test_async_guard.py +++ b/tests/unit_tests/test_async_guard.py @@ -445,55 +445,86 @@ def test_use_many_tuple(): ) -@pytest.mark.asyncio -async def test_validate(): - guard: AsyncGuard = ( - AsyncGuard() - .use(OneLine) - .use( - LowerCase(on_fail=OnFailAction.FIX), on="output" - ) # default on="output", still explicitly set - .use(TwoWords) - .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) - ) +# TODO: Move to integration tests; these are not unit tests... +class TestValidate: + @pytest.mark.asyncio + async def test_output_only_success(self): + guard: AsyncGuard = ( + AsyncGuard() + .use(OneLine) + .use( + LowerCase(on_fail=OnFailAction.FIX), on="output" + ) # default on="output", still explicitly set + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - llm_output: str = "Oh Canada" # bc it meets our criteria - response = await guard.validate(llm_output) + llm_output: str = "Oh Canada" # bc it meets our criteria - assert response.validation_passed is True - assert response.validated_output == llm_output.lower() - llm_output_2 = "Star Spangled Banner" # to stick with the theme + response = await guard.validate(llm_output) - response_2 = await guard.validate(llm_output_2) + assert response.validation_passed is True + assert response.validated_output == llm_output.lower() - assert response_2.validation_passed is False - assert response_2.validated_output is None + @pytest.mark.asyncio + async def test_output_only_failure(self): + guard: AsyncGuard = ( + AsyncGuard() + .use(OneLine) + .use( + LowerCase(on_fail=OnFailAction.FIX), on="output" + ) # default on="output", still explicitly set + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - # Test with a combination of prompt, output, instructions and msg_history validators - # Should still only use the output validators to validate the output - guard: AsyncGuard = ( - AsyncGuard() - .use(OneLine, on="prompt") - .use(LowerCase, on="instructions") - .use(UpperCase, on="msg_history") - .use(LowerCase, on="output", on_fail=OnFailAction.FIX) - .use(TwoWords, on="output") - .use(ValidLength, 0, 12, on="output") - ) + llm_output = "Star Spangled Banner" # to stick with the theme + + response = await guard.validate(llm_output) - llm_output: str = "Oh Canada" # bc it meets our criteria + assert response.validation_passed is False + assert response.validated_output is None - response = await guard.validate(llm_output) + 
@pytest.mark.asyncio + async def test_on_many_success(self): + # Test with a combination of prompt, output, + # instructions and msg_history validators + # Should still only use the output validators to validate the output + guard: AsyncGuard = ( + AsyncGuard() + .use(OneLine, on="prompt") + .use(LowerCase, on="instructions") + .use(UpperCase, on="msg_history") + .use(LowerCase, on="output", on_fail=OnFailAction.FIX) + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - assert response.validation_passed is True - assert response.validated_output == llm_output.lower() + llm_output: str = "Oh Canada" # bc it meets our criteria + + response = await guard.validate(llm_output) + + assert response.validation_passed is True + assert response.validated_output == llm_output.lower() + + @pytest.mark.asyncio + async def test_on_many_failure(self): + guard: AsyncGuard = ( + AsyncGuard() + .use(OneLine, on="prompt") + .use(LowerCase, on="instructions") + .use(UpperCase, on="msg_history") + .use(LowerCase, on="output", on_fail=OnFailAction.FIX) + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - llm_output_2 = "Star Spangled Banner" # to stick with the theme + llm_output = "Star Spangled Banner" # to stick with the theme - response_2 = await guard.validate(llm_output_2) + response = await guard.validate(llm_output) - assert response_2.validation_passed is False - assert response_2.validated_output is None + assert response.validation_passed is False + assert response.validated_output is None def test_use_and_use_many(): diff --git a/tests/unit_tests/test_async_validator_service.py b/tests/unit_tests/test_async_validator_service.py deleted file mode 100644 index 2c566035a..000000000 --- a/tests/unit_tests/test_async_validator_service.py +++ /dev/null @@ -1,345 +0,0 @@ -import asyncio - -import pytest - -from guardrails.classes.history.iteration import Iteration -from guardrails.classes.validation.validator_logs import ValidatorLogs -from guardrails.validator_base import OnFailAction -from guardrails.validator_service import AsyncValidatorService -from guardrails.classes.validation.validation_result import PassResult - -from .mocks import MockLoop -from .mocks.mock_validator import create_mock_validator - -avs = AsyncValidatorService() - - -def test_validate_with_running_loop(mocker): - iteration = Iteration( - call_id="mock-call", - index=0, - ) - with pytest.raises(RuntimeError) as e_info: - mock_loop = MockLoop(True) - mocker.patch("asyncio.get_event_loop", return_value=mock_loop) - avs.validate( - value=True, - metadata={}, - validator_map={}, - iteration=iteration, - absolute_path="$", - reference_path="$", - ) - - assert ( - str(e_info) - == "Async event loop found, please call `validate_async` instead." 
- ) - - -def test_validate_without_running_loop(mocker): - mock_loop = MockLoop(False) - mocker.patch("asyncio.get_event_loop", return_value=mock_loop) - async_validate_mock = mocker.MagicMock( - return_value=("async_validate_mock", {"async": True}) - ) - mocker.patch.object(avs, "async_validate", async_validate_mock) - loop_spy = mocker.spy(mock_loop, "run_until_complete") - - iteration = Iteration( - call_id="mock-call", - index=0, - ) - - validated_value, validated_metadata = avs.validate( - value=True, - metadata={}, - validator_map={}, - iteration=iteration, - absolute_path="$", - reference_path="$", - ) - - assert loop_spy.call_count == 1 - async_validate_mock.assert_called_once_with( - True, {}, {}, iteration, "$", "$", stream=False - ) - assert validated_value == "async_validate_mock" - assert validated_metadata == {"async": True} - - -@pytest.mark.asyncio -async def test_async_validate_with_children(mocker): - validate_children_mock = mocker.patch.object(avs, "validate_children") - - run_validators_mock = mocker.patch.object(avs, "run_validators") - run_validators_mock.return_value = ("run_validators_mock", {"async": True}) - - value = {"a": 1} - - iteration = Iteration( - call_id="mock-call", - index=0, - ) - - validated_value, validated_metadata = await avs.async_validate( - value=value, - metadata={}, - validator_map={}, - iteration=iteration, - absolute_path="$", - reference_path="$", - ) - - assert validate_children_mock.call_count == 1 - validate_children_mock.assert_called_once_with( - value, {}, {}, iteration, "$", "$", stream=False - ) - - assert run_validators_mock.call_count == 1 - run_validators_mock.assert_called_once_with( - iteration, {}, value, {}, "$", "$", stream=False - ) - - assert validated_value == "run_validators_mock" - assert validated_metadata == {"async": True} - - -@pytest.mark.asyncio -async def test_async_validate_without_children(mocker): - validate_children_mock = mocker.patch.object(avs, "validate_children") - - run_validators_mock = mocker.patch.object(avs, "run_validators") - run_validators_mock.return_value = ("run_validators_mock", {"async": True}) - - iteration = Iteration( - call_id="mock-call", - index=0, - ) - - validated_value, validated_metadata = await avs.async_validate( - value="Hello world!", - metadata={}, - validator_map={}, - iteration=iteration, - absolute_path="$", - reference_path="$", - ) - - assert validate_children_mock.call_count == 0 - - assert run_validators_mock.call_count == 1 - run_validators_mock.assert_called_once_with( - iteration, {}, "Hello world!", {}, "$", "$", stream=False - ) - - assert validated_value == "run_validators_mock" - assert validated_metadata == {"async": True} - - -@pytest.mark.asyncio -async def test_validate_children(mocker): - async def mock_async_validate(v, md, *args, **kwargs): - return (f"new-{v}", md) - - async_validate_mock = mocker.patch.object( - avs, "async_validate", side_effect=mock_async_validate - ) - - gather_spy = mocker.spy(asyncio, "gather") - - validator_map = { - "$.mock-parent-key": [], - "$.mock-parent-key.child-one-key": [], - "$.mock-parent-key.child-two-key": [], - } - - value = { - "mock-parent-key": { - "child-one-key": "child-one-value", - "child-two-key": "child-two-value", - } - } - - iteration = Iteration( - call_id="mock-call", - index=0, - ) - - validated_value, validated_metadata = await avs.validate_children( - value=value.get("mock-parent-key"), - metadata={}, - validator_map=validator_map, - iteration=iteration, - abs_parent_path="$.mock-parent-key", - 
ref_parent_path="$.mock-parent-key", - ) - - assert gather_spy.call_count == 1 - - assert async_validate_mock.call_count == 2 - async_validate_mock.assert_any_call( - "child-one-value", - {}, - validator_map, - iteration, - "$.mock-parent-key.child-one-key", - "$.mock-parent-key.child-one-key", - stream=False, - ) - async_validate_mock.assert_any_call( - "child-two-value", - {}, - validator_map, - iteration, - "$.mock-parent-key.child-two-key", - "$.mock-parent-key.child-two-key", - stream=False, - ) - - assert validated_value == { - "child-one-key": "new-child-one-value", - "child-two-key": "new-child-two-value", - } - assert validated_metadata == {} - - -@pytest.mark.asyncio -async def test_run_validators(mocker): - group_validators_mock = mocker.patch.object(avs, "group_validators") - fix_validator_type = create_mock_validator("fix_validator", OnFailAction.FIX) - fix_validator = fix_validator_type() - noop_validator_type = create_mock_validator("noop_validator") - noop_validator_1 = noop_validator_type() - noop_validator_type = create_mock_validator("noop_validator") - noop_validator_2 = noop_validator_type() - noop_validator_2.run_in_separate_process = True - group_validators_mock.return_value = [ - (OnFailAction.FIX, [fix_validator]), - (OnFailAction.NOOP, [noop_validator_1, noop_validator_2]), - ] - - def mock_run_validator( - iteration, validator, value, metadata, property_path, stream - ): - return ValidatorLogs( - registered_name=validator.name, - validator_name=validator.name, - value_before_validation=value, - validation_result=PassResult(), - property_path=property_path, - ) - - run_validator_mock = mocker.patch.object( - avs, "run_validator", side_effect=mock_run_validator - ) - - mock_loop = MockLoop(True) - run_in_executor_spy = mocker.spy(mock_loop, "run_in_executor") - get_running_loop_mock = mocker.patch( - "asyncio.get_running_loop", return_value=mock_loop - ) - - async def mock_gather(*args): - return args - - asyancio_gather_mock = mocker.patch("asyncio.gather", side_effect=mock_gather) - - iteration = Iteration( - call_id="mock-call", - index=0, - ) - - value, metadata = await avs.run_validators( - iteration=iteration, - validator_map={}, - value=True, - metadata={}, - absolute_property_path="$", - reference_property_path="$", - ) - - assert get_running_loop_mock.call_count == 1 - - assert group_validators_mock.call_count == 1 - group_validators_mock.assert_called_once_with([]) - - assert run_in_executor_spy.call_count == 1 - run_in_executor_spy.assert_called_once_with( - avs.multiprocessing_executor, - run_validator_mock, - iteration, - noop_validator_2, - True, - {}, - "$", - False, - ) - - assert run_validator_mock.call_count == 3 - - assert asyancio_gather_mock.call_count == 1 - - assert value is True - assert metadata == {} - - -@pytest.mark.asyncio -async def test_run_validators_with_override(mocker): - group_validators_mock = mocker.patch.object(avs, "group_validators") - override_validator_type = create_mock_validator("override") - override_validator = override_validator_type() - override_validator.override_value_on_pass = True - - group_validators_mock.return_value = [("exception", [override_validator])] - - run_validator_mock = mocker.patch.object(avs, "run_validator") - run_validator_mock.return_value = ValidatorLogs( - registered_name="override", - validator_name="override", - value_before_validation="mock-value", - validation_result=PassResult(value_override="override"), - property_path="$", - ) - - mock_loop = MockLoop(True) - run_in_executor_spy = 
mocker.spy(mock_loop, "run_in_executor") - get_running_loop_mock = mocker.patch( - "asyncio.get_running_loop", return_value=mock_loop - ) - - asyancio_gather_mock = mocker.patch("asyncio.gather") - - iteration = Iteration( - call_id="mock-call", - index=0, - ) - - value, metadata = await avs.run_validators( - iteration=iteration, - validator_map={}, - value=True, - metadata={}, - absolute_property_path="$", - reference_property_path="$", - ) - - assert get_running_loop_mock.call_count == 1 - - assert group_validators_mock.call_count == 1 - group_validators_mock.assert_called_once_with([]) - - assert run_in_executor_spy.call_count == 0 - - assert run_validator_mock.call_count == 1 - - assert asyancio_gather_mock.call_count == 0 - - assert value == "override" - assert metadata == {} - - -# TODO -@pytest.mark.asyncio -async def test_run_validators_with_failures(mocker): - assert True is True diff --git a/tests/unit_tests/test_guard.py b/tests/unit_tests/test_guard.py index f3e951929..aad985d78 100644 --- a/tests/unit_tests/test_guard.py +++ b/tests/unit_tests/test_guard.py @@ -492,56 +492,82 @@ def test_use_many_tuple(): ) -def test_validate(): - guard: Guard = ( - Guard() - .use(OneLine) - .use( - LowerCase(on_fail=OnFailAction.FIX), on="output" - ) # default on="output", still explicitly set - .use(TwoWords) - .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) - ) +# TODO: Move to integration tests; these are not unit tests... +class TestValidate: + def test_output_only_success(self): + guard: Guard = ( + Guard() + .use(OneLine) + .use( + LowerCase(on_fail=OnFailAction.FIX), on="output" + ) # default on="output", still explicitly set + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - llm_output: str = "Oh Canada" # bc it meets our criteria + llm_output: str = "Oh Canada" # bc it meets our criteria - response = guard.validate(llm_output) + response = guard.validate(llm_output) + + assert response.validation_passed is True + assert response.validated_output == llm_output.lower() + + def test_output_only_failure(self): + guard: Guard = ( + Guard() + .use(OneLine) + .use( + LowerCase(on_fail=OnFailAction.FIX), on="output" + ) # default on="output", still explicitly set + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - assert response.validation_passed is True - assert response.validated_output == llm_output.lower() + llm_output = "Star Spangled Banner" # to stick with the theme - llm_output_2 = "Star Spangled Banner" # to stick with the theme + response = guard.validate(llm_output) - response_2 = guard.validate(llm_output_2) + assert response.validation_passed is False + assert response.validated_output is None - assert response_2.validation_passed is False - assert response_2.validated_output is None + def test_on_many_success(self): + # Test with a combination of prompt, output, + # instructions and msg_history validators + # Should still only use the output validators to validate the output + guard: Guard = ( + Guard() + .use(OneLine, on="prompt") + .use(LowerCase, on="instructions") + .use(UpperCase, on="msg_history") + .use(LowerCase, on="output", on_fail=OnFailAction.FIX) + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - # Test with a combination of prompt, output, instructions and msg_history validators - # Should still only use the output validators to validate the output - guard: Guard = ( - Guard() - .use(OneLine, on="prompt") - .use(LowerCase, on="instructions") - .use(UpperCase, 
on="msg_history") - .use(LowerCase, on="output", on_fail=OnFailAction.FIX) - .use(TwoWords, on="output") - .use(ValidLength, 0, 12, on="output") - ) + llm_output: str = "Oh Canada" # bc it meets our criteria - llm_output: str = "Oh Canada" # bc it meets our criteria + response = guard.validate(llm_output) - response = guard.validate(llm_output) + assert response.validation_passed is True + assert response.validated_output == llm_output.lower() - assert response.validation_passed is True - assert response.validated_output == llm_output.lower() + def test_on_many_failure(self): + guard: Guard = ( + Guard() + .use(OneLine, on="prompt") + .use(LowerCase, on="instructions") + .use(UpperCase, on="msg_history") + .use(LowerCase, on="output", on_fail=OnFailAction.FIX) + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - llm_output_2 = "Star Spangled Banner" # to stick with the theme + llm_output = "Star Spangled Banner" # to stick with the theme - response_2 = guard.validate(llm_output_2) + response = guard.validate(llm_output) - assert response_2.validation_passed is False - assert response_2.validated_output is None + assert response.validation_passed is False + assert response.validated_output is None def test_use_and_use_many(): diff --git a/tests/unit_tests/test_validator_base.py b/tests/unit_tests/test_validator_base.py index 1b30aaee2..78069c349 100644 --- a/tests/unit_tests/test_validator_base.py +++ b/tests/unit_tests/test_validator_base.py @@ -1,4 +1,5 @@ import json +import re from typing import Any, Dict, List import pytest @@ -209,103 +210,213 @@ def test_to_xml_attrib(min, max, expected_xml): assert xml_validator == expected_xml -def custom_fix_on_fail_handler(value: Any, fail_results: List[FailResult]): +def custom_deprecated_on_fail_handler(value: Any, fail_results: List[FailResult]): + return value + " deprecated" + + +def custom_fix_on_fail_handler(value: Any, fail_result: FailResult): return value + " " + value -def custom_reask_on_fail_handler(value: Any, fail_results: List[FailResult]): - return FieldReAsk(incorrect_value=value, fail_results=fail_results) +def custom_reask_on_fail_handler(value: Any, fail_result: FailResult): + return FieldReAsk(incorrect_value=value, fail_results=[fail_result]) -def custom_exception_on_fail_handler(value: Any, fail_results: List[FailResult]): +def custom_exception_on_fail_handler(value: Any, fail_result: FailResult): raise ValidationError("Something went wrong!") -def custom_filter_on_fail_handler(value: Any, fail_results: List[FailResult]): +def custom_filter_on_fail_handler(value: Any, fail_result: FailResult): return Filter() -def custom_refrain_on_fail_handler(value: Any, fail_results: List[FailResult]): +def custom_refrain_on_fail_handler(value: Any, fail_result: FailResult): return Refrain() -@pytest.mark.parametrize( - "custom_reask_func, expected_result", - [ - ( - custom_fix_on_fail_handler, - {"pet_type": "dog dog", "name": "Fido"}, - ), - ( - custom_reask_on_fail_handler, - FieldReAsk( - incorrect_value="dog", - path=["pet_type"], - fail_results=[ - FailResult( - error_message="must be exactly two words", - fix_value="dog dog", - ) - ], +class TestCustomOnFailHandler: + def test_deprecated_on_fail_handler(self): + prompt = """ + What kind of pet should I get and what should I name it? 
+
+        ${gr.complete_json_suffix_v2}
+        """
+
+        output = """
+        {
+            "pet_type": "dog",
+            "name": "Fido"
+        }
+        """
+        expected_result = {"pet_type": "dog deprecated", "name": "Fido"}
+
+        with pytest.warns(
+            DeprecationWarning,
+            match=re.escape(  # Because of square brackets in the message
+                "Specifying a List[FailResult] as the second argument"
+                " for a custom on_fail handler is deprecated. "
+                "Please use FailResult instead."
             ),
-        ),
-        (
-            custom_exception_on_fail_handler,
-            ValidationError,
-        ),
-        (
-            custom_filter_on_fail_handler,
-            None,
-        ),
-        (
-            custom_refrain_on_fail_handler,
-            None,
-        ),
-    ],
-)
-# @pytest.mark.parametrize(
-#     "validator_spec",
-#     [
-#         lambda val_func: TwoWords(on_fail=val_func),
-#         # This was never supported even pre-0.5.x.
-#         # Trying this with function calling will throw.
-#         lambda val_func: ("two-words", val_func),
-#     ],
-# )
-def test_custom_on_fail_handler(
-    custom_reask_func,
-    expected_result,
-):
-    prompt = """
-    What kind of pet should I get and what should I name it?
+        ):
+            validator: Validator = TwoWords(on_fail=custom_deprecated_on_fail_handler)  # type: ignore

-    ${gr.complete_json_suffix_v2}
-    """
+        class Pet(BaseModel):
+            pet_type: str = Field(description="Species of pet", validators=[validator])
+            name: str = Field(description="a unique pet name")

-    output = """
-    {
-        "pet_type": "dog",
-        "name": "Fido"
-    }
-    """
+        guard = Guard.from_pydantic(output_class=Pet, prompt=prompt)
+
+        response = guard.parse(output, num_reasks=0)
+        assert response.validation_passed is True
+        assert response.validated_output == expected_result
+
+    def test_custom_fix(self):
+        prompt = """
+        What kind of pet should I get and what should I name it?

-    validator: Validator = TwoWords(on_fail=custom_reask_func)
+        ${gr.complete_json_suffix_v2}
+        """
+
+        output = """
+        {
+            "pet_type": "dog",
+            "name": "Fido"
+        }
+        """
+        expected_result = {"pet_type": "dog dog", "name": "Fido"}

-    class Pet(BaseModel):
-        pet_type: str = Field(description="Species of pet", validators=[validator])
-        name: str = Field(description="a unique pet name")
+        validator: Validator = TwoWords(on_fail=custom_fix_on_fail_handler)
+
+        class Pet(BaseModel):
+            pet_type: str = Field(description="Species of pet", validators=[validator])
+            name: str = Field(description="a unique pet name")
+
+        guard = Guard.from_pydantic(output_class=Pet, prompt=prompt)
+
+        response = guard.parse(output, num_reasks=0)
+        assert response.validation_passed is True
+        assert response.validated_output == expected_result
+
+    def test_custom_reask(self):
+        prompt = """
+        What kind of pet should I get and what should I name it?
+
+        ${gr.complete_json_suffix_v2}
+        """
+
+        output = """
+        {
+            "pet_type": "dog",
+            "name": "Fido"
+        }
+        """
+        expected_result = FieldReAsk(
+            incorrect_value="dog",
+            path=["pet_type"],
+            fail_results=[
+                FailResult(
+                    error_message="must be exactly two words",
+                    fix_value="dog dog",
+                )
+            ],
+        )
+
+        validator: Validator = TwoWords(on_fail=custom_reask_on_fail_handler)
+
+        class Pet(BaseModel):
+            pet_type: str = Field(description="Species of pet", validators=[validator])
+            name: str = Field(description="a unique pet name")
+
+        guard = Guard.from_pydantic(output_class=Pet, prompt=prompt)
+
+        response = guard.parse(output, num_reasks=0)
+
+        # Why? Because we have a bad habit of applying every fix value
+        # to the output even if the user doesn't ask us to.
+ assert response.validation_passed is True + assert guard.history.first.iterations.first.reasks[0] == expected_result + + def test_custom_exception(self): + prompt = """ + What kind of pet should I get and what should I name it? + + ${gr.complete_json_suffix_v2} + """ + + output = """ + { + "pet_type": "dog", + "name": "Fido" + } + """ + + validator: Validator = TwoWords(on_fail=custom_exception_on_fail_handler) + + class Pet(BaseModel): + pet_type: str = Field(description="Species of pet", validators=[validator]) + name: str = Field(description="a unique pet name") + + guard = Guard.from_pydantic(output_class=Pet, prompt=prompt) - guard = Guard.from_pydantic(output_class=Pet, prompt=prompt) - if isinstance(expected_result, type) and issubclass(expected_result, Exception): with pytest.raises(ValidationError) as excinfo: guard.parse(output, num_reasks=0) assert str(excinfo.value) == "Something went wrong!" - else: + + def test_custom_filter(self): + prompt = """ + What kind of pet should I get and what should I name it? + + ${gr.complete_json_suffix_v2} + """ + + output = """ + { + "pet_type": "dog", + "name": "Fido" + } + """ + + validator: Validator = TwoWords(on_fail=custom_filter_on_fail_handler) + + class Pet(BaseModel): + pet_type: str = Field(description="Species of pet", validators=[validator]) + name: str = Field(description="a unique pet name") + + guard = Guard.from_pydantic(output_class=Pet, prompt=prompt) + + response = guard.parse(output, num_reasks=0) + + # NOTE: This doesn't seem right. + # Shouldn't pass if filtering is successful on the target property? + assert response.validation_passed is False + assert response.validated_output is None + + def test_custom_refrain(self): + prompt = """ + What kind of pet should I get and what should I name it? 
+ + ${gr.complete_json_suffix_v2} + """ + + output = """ + { + "pet_type": "dog", + "name": "Fido" + } + """ + + validator: Validator = TwoWords(on_fail=custom_refrain_on_fail_handler) + + class Pet(BaseModel): + pet_type: str = Field(description="Species of pet", validators=[validator]) + name: str = Field(description="a unique pet name") + + guard = Guard.from_pydantic(output_class=Pet, prompt=prompt) + response = guard.parse(output, num_reasks=0) - if isinstance(expected_result, FieldReAsk): - assert guard.history.first.iterations.first.reasks[0] == expected_result - else: - assert response.validated_output == expected_result + + assert response.validation_passed is False + assert response.validated_output is None class Pet(BaseModel): diff --git a/tests/unit_tests/test_validator_service.py b/tests/unit_tests/test_validator_service.py deleted file mode 100644 index 36b723382..000000000 --- a/tests/unit_tests/test_validator_service.py +++ /dev/null @@ -1,97 +0,0 @@ -import pytest - -import guardrails.validator_service as vs -from guardrails.classes.history.iteration import Iteration - -from .mocks import MockAsyncValidatorService, MockLoop, MockSequentialValidatorService - - -iteration = Iteration( - call_id="mock-call", - index=0, -) - - -@pytest.mark.asyncio -async def test_async_validate(mocker): - mocker.patch( - "guardrails.validator_service.AsyncValidatorService", - new=MockAsyncValidatorService, - ) - validated_value, validated_metadata = await vs.async_validate( - value=True, - metadata={}, - validator_map={}, - iteration=iteration, - ) - - assert validated_value == "MockAsyncValidatorService.async_validate" - assert validated_metadata == {"async": True} - - -def test_validate_with_running_loop(mocker): - mockLoop = MockLoop(True) - mocker.patch( - "guardrails.validator_service.AsyncValidatorService", - new=MockAsyncValidatorService, - ) - mocker.patch( - "guardrails.validator_service.SequentialValidatorService", - new=MockSequentialValidatorService, - ) - mocker.patch("asyncio.get_event_loop", return_value=mockLoop) - - validated_value, validated_metadata = vs.validate( - value=True, - metadata={}, - validator_map={}, - iteration=iteration, - ) - - assert validated_value == "MockSequentialValidatorService.validate" - assert validated_metadata == {"sync": True} - - -def test_validate_without_running_loop(mocker): - mockLoop = MockLoop(False) - mocker.patch( - "guardrails.validator_service.AsyncValidatorService", - new=MockAsyncValidatorService, - ) - mocker.patch( - "guardrails.validator_service.SequentialValidatorService", - new=MockSequentialValidatorService, - ) - mocker.patch("asyncio.get_event_loop", return_value=mockLoop) - validated_value, validated_metadata = vs.validate( - value=True, - metadata={}, - validator_map={}, - iteration=iteration, - ) - - assert validated_value == "MockAsyncValidatorService.validate" - assert validated_metadata == {"sync": True} - - -def test_validate_loop_runtime_error(mocker): - mocker.patch( - "guardrails.validator_service.AsyncValidatorService", - new=MockAsyncValidatorService, - ) - mocker.patch( - "guardrails.validator_service.SequentialValidatorService", - new=MockSequentialValidatorService, - ) - # raise RuntimeError in `get_event_loop` - mocker.patch("asyncio.get_event_loop", side_effect=RuntimeError) - - validated_value, validated_metadata = vs.validate( - value=True, - metadata={}, - validator_map={}, - iteration=iteration, - ) - - assert validated_value == "MockSequentialValidatorService.validate" - assert validated_metadata == 
{"sync": True} diff --git a/tests/unit_tests/utils/test_serialization_utils.py b/tests/unit_tests/utils/test_serialization_utils.py new file mode 100644 index 000000000..948604dd2 --- /dev/null +++ b/tests/unit_tests/utils/test_serialization_utils.py @@ -0,0 +1,169 @@ +import pytest +from datetime import datetime +from guardrails.utils.serialization_utils import serialize, deserialize + + +class TestSerializeAndDeserialize: + def test_string(self): + data = "value" + + serialized_data = serialize(data) + assert serialized_data == '"value"' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data == data + + def test_int(self): + data = 1 + + serialized_data = serialize(data) + assert serialized_data == "1" + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data == data + + def test_float(self): + data = 1.0 + + serialized_data = serialize(data) + assert serialized_data == "1.0" + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data == data + + def test_bool(self): + data = True + + serialized_data = serialize(data) + assert serialized_data == "true" + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data == data + + def test_datetime(self): + data = datetime(2024, 9, 10, 0, 0, 0) + + serialized_data = serialize(data) + assert serialized_data == '"2024-09-10T00:00:00"' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data == data + + def test_dictionary(self): + data = {"key": "value"} + + serialized_data = serialize(data) + assert serialized_data == '{"key": "value"}' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data == data + + def test_list(self): + data = ["value1", "value2"] + + serialized_data = serialize(data) + assert serialized_data == '["value1", "value2"]' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data == data + + def test_simple_class(self): + class TestClass: + def __init__(self, key: str): + self.key = key + + data = TestClass("value") + + serialized_data = serialize(data) + assert serialized_data == '{"key": "value"}' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data.key == data.key + + def test_nested_classes_not_supported(self): + class TestClass: + def __init__(self, value: str): + self.value = value + + class TestClass2: + def __init__(self, value: TestClass): + self.value = value + + data = TestClass2(TestClass("value")) + + serialized_data = serialize(data) + assert serialized_data == '{"value": {"value": "value"}}' + + deserialized_data = deserialize(data, serialized_data) + with pytest.raises(AttributeError) as excinfo: + assert deserialized_data.value.value == data.value.value + + assert str(excinfo.value) == "'dict' object has no attribute 'value'" + + def test_simple_dataclass(self): + from dataclasses import dataclass + + @dataclass + class TestClass: + key: str + + data = TestClass("value") + + serialized_data = serialize(data) + assert serialized_data == '{"key": "value"}' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data.key == data.key + + def test_nested_dataclasses_not_supported(self): + from dataclasses import dataclass + + @dataclass + class TestClass: + value: str + + @dataclass + class TestClass2: + value: TestClass + + data = TestClass2(TestClass("value")) + + serialized_data = serialize(data) + assert serialized_data == '{"value": {"value": "value"}}' + + 
deserialized_data = deserialize(data, serialized_data) + with pytest.raises(AttributeError) as excinfo: + assert deserialized_data.value.value == data.value.value + + assert str(excinfo.value) == "'dict' object has no attribute 'value'" + + def test_simple_pydantic_model(self): + from pydantic import BaseModel + + class TestClass(BaseModel): + key: str + + data = TestClass(key="value") + + serialized_data = serialize(data) + assert serialized_data == '{"key": "value"}' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data.key == data.key + + def test_nested_pydantic_models(self): + from pydantic import BaseModel + + class TestClass(BaseModel): + value: str + + class TestClass2(BaseModel): + value: TestClass + + data = TestClass2(value=TestClass(value="value")) + + serialized_data = serialize(data) + assert serialized_data == '{"value": {"value": "value"}}' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data.value.value == data.value.value diff --git a/tests/unit_tests/validator_service/test_async_validator_service.py b/tests/unit_tests/validator_service/test_async_validator_service.py new file mode 100644 index 000000000..7c0ebd368 --- /dev/null +++ b/tests/unit_tests/validator_service/test_async_validator_service.py @@ -0,0 +1,776 @@ +from datetime import datetime +from unittest.mock import MagicMock, call + +from guardrails.actions.filter import Filter +from guardrails.validator_service.validator_service_base import ValidatorRun +import pytest + +from guardrails.classes.history.iteration import Iteration +from guardrails.classes.validation.validator_logs import ValidatorLogs +from guardrails.validator_base import OnFailAction, Validator +from guardrails.validator_service.async_validator_service import AsyncValidatorService +from guardrails.classes.validation.validation_result import FailResult, PassResult + + +avs = AsyncValidatorService() + + +def test_validate(mocker): + mock_loop = mocker.MagicMock() + mock_loop.run_until_complete = mocker.MagicMock(return_value=(True, {})) + # loop_spy = mocker.spy(mock_loop, "run_until_complete", return_value=(True, {})) + async_validate_mock = mocker.patch.object(avs, "async_validate") + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + avs.validate( + value=True, + metadata={}, + validator_map={}, + iteration=iteration, + absolute_path="$", + reference_path="$", + loop=mock_loop, + ) + + assert mock_loop.run_until_complete.call_count == 1 + async_validate_mock.assert_called_once_with( + True, {}, {}, iteration, "$", "$", stream=False + ) + + +class TestAsyncValidate: + @pytest.mark.asyncio + async def test_with_dictionary(self, mocker): + validate_children_mock = mocker.patch.object(avs, "validate_children") + + run_validators_mock = mocker.patch.object( + avs, "run_validators", return_value=("run_validators_mock", {"async": True}) + ) + + value = {"a": 1} + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + validated_value, validated_metadata = await avs.async_validate( + value=value, + metadata={}, + validator_map={}, + iteration=iteration, + absolute_path="$", + reference_path="$", + ) + + assert validate_children_mock.call_count == 1 + validate_children_mock.assert_called_once_with( + value, {}, {}, iteration, "$", "$", stream=False + ) + + assert run_validators_mock.call_count == 1 + run_validators_mock.assert_called_once_with( + iteration, {}, value, {}, "$", "$", stream=False + ) + + assert validated_value == "run_validators_mock" + assert 
validated_metadata == {"async": True}
+
+    @pytest.mark.asyncio
+    async def test_with_list(self, mocker):
+        validate_children_mock = mocker.patch.object(avs, "validate_children")
+
+        run_validators_mock = mocker.patch.object(
+            avs, "run_validators", return_value=("run_validators_mock", {"async": True})
+        )
+
+        value = ["a"]
+
+        iteration = Iteration(
+            call_id="mock-call",
+            index=0,
+        )
+
+        validated_value, validated_metadata = await avs.async_validate(
+            value=value,
+            metadata={},
+            validator_map={},
+            iteration=iteration,
+            absolute_path="$",
+            reference_path="$",
+        )
+
+        assert validate_children_mock.call_count == 1
+        validate_children_mock.assert_called_once_with(
+            value, {}, {}, iteration, "$", "$", stream=False
+        )
+
+        assert run_validators_mock.call_count == 1
+        run_validators_mock.assert_called_once_with(
+            iteration, {}, value, {}, "$", "$", stream=False
+        )
+
+        assert validated_value == "run_validators_mock"
+        assert validated_metadata == {"async": True}
+
+    @pytest.mark.asyncio
+    async def test_without_children(self, mocker):
+        validate_children_mock = mocker.patch.object(avs, "validate_children")
+
+        run_validators_mock = mocker.patch.object(avs, "run_validators")
+        run_validators_mock.return_value = ("run_validators_mock", {"async": True})
+
+        iteration = Iteration(
+            call_id="mock-call",
+            index=0,
+        )
+
+        validated_value, validated_metadata = await avs.async_validate(
+            value="Hello world!",
+            metadata={},
+            validator_map={},
+            iteration=iteration,
+            absolute_path="$",
+            reference_path="$",
+        )
+
+        assert validate_children_mock.call_count == 0
+
+        assert run_validators_mock.call_count == 1
+        run_validators_mock.assert_called_once_with(
+            iteration, {}, "Hello world!", {}, "$", "$", stream=False
+        )
+
+        assert validated_value == "run_validators_mock"
+        assert validated_metadata == {"async": True}
+
+
+class TestValidateChildren:
+    @pytest.mark.asyncio
+    async def test_with_list(self, mocker):
+        mock_async_validate = mocker.patch.object(
+            avs,
+            "async_validate",
+            side_effect=[
+                (
+                    "mock-child-1-value",
+                    {
+                        "mock-child-1-metadata": "child-1-metadata",
+                        "mock-shared-metadata": "shared-metadata-1",
+                    },
+                ),
+                (
+                    "mock-child-2-value",
+                    {
+                        "mock-child-2-metadata": "child-2-metadata",
+                        "mock-shared-metadata": "shared-metadata-2",
+                    },
+                ),
+            ],
+        )
+
+        iteration = Iteration(
+            call_id="mock-call",
+            index=0,
+        )
+
+        validator_map = ({"$.*": [MagicMock(spec=Validator)]},)
+        value, metadata = await avs.validate_children(
+            value=["mock-child-1", "mock-child-2"],
+            metadata={"mock-shared-metadata": "shared-metadata"},
+            validator_map=validator_map,
+            iteration=iteration,
+            abs_parent_path="$",
+            ref_parent_path="$",
+        )
+
+        assert mock_async_validate.call_count == 2
+        mock_async_validate.assert_has_calls(
+            [
+                call(
+                    "mock-child-1",
+                    {
+                        "mock-shared-metadata": "shared-metadata",
+                    },
+                    validator_map,
+                    iteration,
+                    "$.0",
+                    "$.*",
+                    stream=False,
+                ),
+                call(
+                    "mock-child-2",
+                    {
+                        "mock-shared-metadata": "shared-metadata",
+                    },
+                    validator_map,
+                    iteration,
+                    "$.1",
+                    "$.*",
+                    stream=False,
+                ),
+            ]
+        )
+
+        assert value == ["mock-child-1-value", "mock-child-2-value"]
+        assert metadata == {
+            "mock-child-1-metadata": "child-1-metadata",
+            "mock-child-2-metadata": "child-2-metadata",
+            # NOTE: This is overridden based on who finishes last
+            "mock-shared-metadata": "shared-metadata-2",
+        }
+
+    @pytest.mark.asyncio
+    async def test_with_dictionary(self, mocker):
+        mock_async_validate = mocker.patch.object(
+            avs,
+            "async_validate",
+            side_effect=[
+                (
"mock-child-1-value", + { + "mock-child-1-metadata": "child-1-metadata", + "mock-shared-metadata": "shared-metadata-1", + }, + ), + ( + "mock-child-2-value", + { + "mock-child-2-metadata": "child-2-metadata", + "mock-shared-metadata": "shared-metadata-2", + }, + ), + ], + ) + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + validator_map = ( + { + "$.child-1": [MagicMock(spec=Validator)], + "$.child-2": [MagicMock(spec=Validator)], + }, + ) + value, metadata = await avs.validate_children( + value={"child-1": "mock-child-1", "child-2": "mock-child-2"}, + metadata={"mock-shared-metadata": "shared-metadata"}, + validator_map=validator_map, + iteration=iteration, + abs_parent_path="$", + ref_parent_path="$", + ) + + assert mock_async_validate.call_count == 2 + mock_async_validate.assert_has_calls( + [ + call( + "mock-child-1", + { + "mock-shared-metadata": "shared-metadata", + }, + validator_map, + iteration, + "$.child-1", + "$.child-1", + stream=False, + ), + call( + "mock-child-2", + { + "mock-shared-metadata": "shared-metadata", + }, + validator_map, + iteration, + "$.child-2", + "$.child-2", + stream=False, + ), + ] + ) + + assert value == { + "child-1": "mock-child-1-value", + "child-2": "mock-child-2-value", + } + assert metadata == { + "mock-child-1-metadata": "child-1-metadata", + "mock-child-2-metadata": "child-2-metadata", + # NOTE: This is overriden based on who finishes last + "mock-shared-metadata": "shared-metadata-2", + } + + +class TestRunValidators: + @pytest.mark.asyncio + async def test_filter_exits_early(self, mocker): + mock_run_validator = mocker.patch.object( + avs, + "run_validator", + side_effect=[ + ValidatorRun( + value="mock-value", + metadata={}, + on_fail_action="noop", + validator_logs=ValidatorLogs( + registered_name="noop_validator", + validator_name="noop_validator", + value_before_validation="mock-value", + validation_result=PassResult(), + property_path="$", + ), + ), + ValidatorRun( + value=Filter(), + metadata={}, + on_fail_action="filter", + validator_logs=ValidatorLogs( + registered_name="filter_validator", + validator_name="filter_validator", + value_before_validation="mock-value", + validation_result=FailResult(error_message="mock-error"), + property_path="$", + ), + ), + ], + ) + mock_merge_results = mocker.patch.object(avs, "merge_results") + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + value, metadata = await avs.run_validators( + iteration=iteration, + validator_map={ + "$": [ + MagicMock(spec=Validator), + MagicMock(spec=Validator), + ] + }, + value=True, + metadata={}, + absolute_property_path="$", + reference_property_path="$", + ) + + assert mock_run_validator.call_count == 2 + assert mock_merge_results.call_count == 0 + + assert isinstance(value, Filter) + assert metadata == {} + + @pytest.mark.asyncio + async def test_calls_merge(self, mocker): + mock_run_validator = mocker.patch.object( + avs, + "run_validator", + side_effect=[ + ValidatorRun( + value="mock-value", + metadata={}, + on_fail_action="noop", + validator_logs=ValidatorLogs( + registered_name="noop_validator", + validator_name="noop_validator", + value_before_validation="mock-value", + validation_result=PassResult(), + property_path="$", + ), + ), + ValidatorRun( + value="mock-fix-value", + metadata={}, + on_fail_action="fix", + validator_logs=ValidatorLogs( + registered_name="fix_validator", + validator_name="fix_validator", + value_before_validation="mock-value", + validation_result=FailResult( + error_message="mock-error", 
fix_value="mock-fix-value" + ), + property_path="$", + ), + ), + ], + ) + mock_merge_results = mocker.patch.object( + avs, "merge_results", return_value="mock-fix-value" + ) + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + value, metadata = await avs.run_validators( + iteration=iteration, + validator_map={ + "$": [ + MagicMock(spec=Validator), + MagicMock(spec=Validator), + ] + }, + value=True, + metadata={}, + absolute_property_path="$", + reference_property_path="$", + ) + + assert mock_run_validator.call_count == 2 + assert mock_merge_results.call_count == 1 + + assert value == "mock-fix-value" + assert metadata == {} + + @pytest.mark.asyncio + async def test_returns_value_if_no_results(self, mocker): + mock_run_validator = mocker.patch.object(avs, "run_validator") + mock_merge_results = mocker.patch.object(avs, "merge_results") + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + value, metadata = await avs.run_validators( + iteration=iteration, + validator_map={}, + value=True, + metadata={}, + absolute_property_path="$", + reference_property_path="$", + ) + + assert mock_run_validator.call_count == 0 + assert mock_merge_results.call_count == 0 + + assert value is True + assert metadata == {} + + +class TestRunValidator: + @pytest.mark.asyncio + async def test_pass_result(self, mocker): + validator_logs = ValidatorLogs( + validator_name="mock-validator", + registered_name="mock-validator", + instance_id=1234, + property_path="$", + value_before_validation="value", + start_time=datetime(2024, 9, 10, 9, 54, 0, 38391), + value_after_validation="value", + ) + mock_before_run_validator = mocker.patch.object( + avs, "before_run_validator", return_value=validator_logs + ) + + validation_result = PassResult() + mock_run_validator_async = mocker.patch.object( + avs, "run_validator_async", return_value=validation_result + ) + + mock_after_run_validator = mocker.patch.object( + avs, "after_run_validator", return_value=validator_logs + ) + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + validator = MagicMock(spec=Validator) + validator.on_fail_descriptor = "noop" + + result = await avs.run_validator( + iteration=iteration, + validator=validator, + value="value", + metadata={}, + absolute_property_path="$", + ) + + assert mock_before_run_validator.call_count == 1 + mock_before_run_validator.assert_called_once_with( + iteration, validator, "value", "$" + ) + + assert mock_run_validator_async.call_count == 1 + mock_run_validator_async.assert_called_once_with( + validator, "value", {}, False, validation_session_id=iteration.id + ) + + assert mock_after_run_validator.call_count == 1 + mock_after_run_validator.assert_called_once_with( + validator, validator_logs, validation_result + ) + + assert isinstance(result, ValidatorRun) + assert result.value == "value" + assert result.metadata == {} + assert result.validator_logs == validator_logs + + @pytest.mark.asyncio + async def test_pass_result_with_override(self, mocker): + validator_logs = ValidatorLogs( + validator_name="mock-validator", + registered_name="mock-validator", + instance_id=1234, + property_path="$", + value_before_validation="value", + start_time=datetime(2024, 9, 10, 9, 54, 0, 38391), + value_after_validation="value", + ) + mock_before_run_validator = mocker.patch.object( + avs, "before_run_validator", return_value=validator_logs + ) + + validation_result = PassResult(value_override="override") + mock_run_validator_async = mocker.patch.object( + avs, "run_validator_async", 
return_value=validation_result + ) + + mock_after_run_validator = mocker.patch.object( + avs, "after_run_validator", return_value=validator_logs + ) + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + validator = MagicMock(spec=Validator) + validator.on_fail_descriptor = "noop" + + result = await avs.run_validator( + iteration=iteration, + validator=validator, + value="value", + metadata={}, + absolute_property_path="$", + ) + + assert mock_before_run_validator.call_count == 1 + mock_before_run_validator.assert_called_once_with( + iteration, validator, "value", "$" + ) + + assert mock_run_validator_async.call_count == 1 + mock_run_validator_async.assert_called_once_with( + validator, "value", {}, False, validation_session_id=iteration.id + ) + + assert mock_after_run_validator.call_count == 1 + mock_after_run_validator.assert_called_once_with( + validator, validator_logs, validation_result + ) + + assert isinstance(result, ValidatorRun) + assert result.value == "override" + assert result.metadata == {} + assert result.validator_logs == validator_logs + + @pytest.mark.asyncio + async def test_fail_result(self, mocker): + validator_logs = ValidatorLogs( + validator_name="mock-validator", + registered_name="mock-validator", + instance_id=1234, + property_path="$", + value_before_validation="value", + start_time=datetime(2024, 9, 10, 9, 54, 0, 38391), + value_after_validation="value", + ) + mock_before_run_validator = mocker.patch.object( + avs, "before_run_validator", return_value=validator_logs + ) + + validation_result = FailResult(error_message="mock-error") + mock_run_validator_async = mocker.patch.object( + avs, "run_validator_async", return_value=validation_result + ) + + mock_after_run_validator = mocker.patch.object( + avs, "after_run_validator", return_value=validator_logs + ) + + mock_perform_correction = mocker.patch.object( + avs, "perform_correction", return_value="corrected-value" + ) + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + validator = MagicMock(spec=Validator) + validator.on_fail_descriptor = "noop" + + result = await avs.run_validator( + iteration=iteration, + validator=validator, + value="value", + metadata={}, + absolute_property_path="$", + ) + + assert mock_before_run_validator.call_count == 1 + mock_before_run_validator.assert_called_once_with( + iteration, validator, "value", "$" + ) + + assert mock_run_validator_async.call_count == 1 + mock_run_validator_async.assert_called_once_with( + validator, "value", {}, False, validation_session_id=iteration.id + ) + + assert mock_after_run_validator.call_count == 1 + mock_after_run_validator.assert_called_once_with( + validator, validator_logs, validation_result + ) + + assert mock_perform_correction.call_count == 1 + mock_perform_correction.assert_called_once_with( + validation_result, "value", validator, rechecked_value=None + ) + + assert isinstance(result, ValidatorRun) + assert result.value == "corrected-value" + assert result.metadata == {} + assert result.validator_logs == validator_logs + + @pytest.mark.asyncio + async def test_fail_result_with_fix_reask(self, mocker): + validator_logs = ValidatorLogs( + validator_name="mock-validator", + registered_name="mock-validator", + instance_id=1234, + property_path="$", + value_before_validation="value", + start_time=datetime(2024, 9, 10, 9, 54, 0, 38391), + value_after_validation="value", + ) + mock_before_run_validator = mocker.patch.object( + avs, "before_run_validator", return_value=validator_logs + ) + + validation_result = 
FailResult( + error_message="mock-error", fix_value="fixed-value" + ) + rechecked_result = PassResult() + mock_run_validator_async = mocker.patch.object( + avs, + "run_validator_async", + side_effect=[validation_result, rechecked_result], + ) + + mock_after_run_validator = mocker.patch.object( + avs, "after_run_validator", return_value=validator_logs + ) + + mock_perform_correction = mocker.patch.object( + avs, "perform_correction", return_value="fixed-value" + ) + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + validator = MagicMock(spec=Validator) + validator.on_fail_descriptor = OnFailAction.FIX_REASK + + result = await avs.run_validator( + iteration=iteration, + validator=validator, + value="value", + metadata={}, + absolute_property_path="$", + ) + + assert mock_before_run_validator.call_count == 1 + mock_before_run_validator.assert_called_once_with( + iteration, validator, "value", "$" + ) + + assert mock_run_validator_async.call_count == 2 + mock_run_validator_async.assert_has_calls( + [ + call(validator, "value", {}, False, validation_session_id=iteration.id), + call( + validator, + "fixed-value", + {}, + False, + validation_session_id=iteration.id, + ), + ] + ) + + assert mock_after_run_validator.call_count == 1 + mock_after_run_validator.assert_called_once_with( + validator, validator_logs, validation_result + ) + + assert mock_perform_correction.call_count == 1 + mock_perform_correction.assert_called_once_with( + validation_result, "value", validator, rechecked_value=rechecked_result + ) + + assert isinstance(result, ValidatorRun) + assert result.value == "fixed-value" + assert result.metadata == {} + assert result.validator_logs == validator_logs + + +class TestRunValidatorAsync: + @pytest.mark.asyncio + async def test_happy_path(self, mocker): + mock_validator = MagicMock(spec=Validator) + + validation_result = PassResult() + mock_execute_validator = mocker.patch.object( + avs, "execute_validator", return_value=validation_result + ) + + result = await avs.run_validator_async( + validator=mock_validator, + value="value", + metadata={}, + stream=False, + validation_session_id="mock-session", + ) + + assert result == validation_result + + assert mock_execute_validator.call_count == 1 + mock_execute_validator.assert_called_once_with( + mock_validator, "value", {}, False, validation_session_id="mock-session" + ) + + @pytest.mark.asyncio + async def test_result_is_none(self, mocker): + mock_validator = MagicMock(spec=Validator) + + validation_result = None + mock_execute_validator = mocker.patch.object( + avs, "execute_validator", return_value=validation_result + ) + + result = await avs.run_validator_async( + validator=mock_validator, + value="value", + metadata={}, + stream=False, + validation_session_id="mock-session", + ) + + assert isinstance(result, PassResult) + + assert mock_execute_validator.call_count == 1 + mock_execute_validator.assert_called_once_with( + mock_validator, "value", {}, False, validation_session_id="mock-session" + ) diff --git a/tests/unit_tests/validator_service/test_validator_service.py b/tests/unit_tests/validator_service/test_validator_service.py new file mode 100644 index 000000000..7d0be1ae0 --- /dev/null +++ b/tests/unit_tests/validator_service/test_validator_service.py @@ -0,0 +1,173 @@ +from unittest.mock import AsyncMock +import pytest + +import guardrails.validator_service as vs +from guardrails.classes.history.iteration import Iteration + + +iteration = Iteration( + call_id="mock-call", + index=0, +) + + +class 
TestShouldRunSync: + def test_process_count_of_1(self, mocker): + mocker.patch( + "guardrails.validator_service.os.environ.get", side_effect=["1", "false"] + ) + assert vs.should_run_sync() is True + + def test_run_sync_set_to_true(self, mocker): + mocker.patch( + "guardrails.validator_service.os.environ.get", side_effect=["10", "True"] + ) + assert vs.should_run_sync() is True + + def test_should_run_sync_default(self, mocker): + mocker.patch( + "guardrails.validator_service.os.environ.get", side_effect=["10", "false"] + ) + assert vs.should_run_sync() is False + + +class TestGetLoop: + def test_get_loop_with_running_loop(self, mocker): + mocker.patch( + "guardrails.validator_service.asyncio.get_running_loop", + return_value="running loop", + ) + with pytest.raises(RuntimeError): + vs.get_loop() + + def test_get_loop_without_running_loop(self, mocker): + mocker.patch( + "guardrails.validator_service.asyncio.get_running_loop", + side_effect=RuntimeError, + ) + mocker.patch( + "guardrails.validator_service.asyncio.get_event_loop", + return_value="event loop", + ) + assert vs.get_loop() == "event loop" + + def test_get_loop_with_uvloop(self, mocker): + mocker.patch("guardrails.validator_service.uvloop") + mock_event_loop_policy = mocker.patch( + "guardrails.validator_service.uvloop.EventLoopPolicy" + ) + mocker.patch( + "guardrails.validator_service.asyncio.get_running_loop", + side_effect=RuntimeError, + ) + mocker.patch( + "guardrails.validator_service.asyncio.get_event_loop", + return_value="event loop", + ) + mock_set_event_loop_policy = mocker.patch("asyncio.set_event_loop_policy") + + assert vs.get_loop() == "event loop" + + mock_event_loop_policy.assert_called_once() + mock_set_event_loop_policy.assert_called_once_with( + mock_event_loop_policy.return_value + ) + + +class TestValidate: + def test_validate_with_sync(self, mocker): + mocker.patch("guardrails.validator_service.should_run_sync", return_value=True) + mocker.patch("guardrails.validator_service.SequentialValidatorService") + mocker.patch("guardrails.validator_service.AsyncValidatorService") + mocker.patch("guardrails.validator_service.get_loop") + mocker.patch("guardrails.validator_service.warnings") + + vs.validate( + value=True, + metadata={}, + validator_map={}, + iteration=iteration, + ) + + vs.SequentialValidatorService.assert_called_once_with(True) + vs.SequentialValidatorService.return_value.validate.assert_called_once_with( + True, + {}, + {}, + iteration, + "$", + "$", + loop=None, + ) + + def test_validate_with_async(self, mocker): + mocker.patch("guardrails.validator_service.should_run_sync", return_value=False) + mocker.patch("guardrails.validator_service.SequentialValidatorService") + mocker.patch("guardrails.validator_service.AsyncValidatorService") + mocker.patch("guardrails.validator_service.get_loop", return_value="event loop") + mocker.patch("guardrails.validator_service.warnings") + + vs.validate( + value=True, + metadata={}, + validator_map={}, + iteration=iteration, + ) + + vs.AsyncValidatorService.assert_called_once_with(True) + vs.AsyncValidatorService.return_value.validate.assert_called_once_with( + True, + {}, + {}, + iteration, + "$", + "$", + loop="event loop", + ) + + def test_validate_with_no_available_event_loop(self, mocker): + mocker.patch("guardrails.validator_service.should_run_sync", return_value=False) + mocker.patch("guardrails.validator_service.SequentialValidatorService") + mocker.patch("guardrails.validator_service.AsyncValidatorService") + 
mocker.patch("guardrails.validator_service.get_loop", side_effect=RuntimeError) + mock_warn = mocker.patch("guardrails.validator_service.warnings.warn") + + vs.validate( + value=True, + metadata={}, + validator_map={}, + iteration=iteration, + ) + + mock_warn.assert_called_once_with( + "Could not obtain an event loop. Falling back to synchronous validation." + ) + + vs.SequentialValidatorService.assert_called_once_with(True) + vs.SequentialValidatorService.return_value.validate.assert_called_once_with( + True, + {}, + {}, + iteration, + "$", + "$", + loop=None, + ) + + +@pytest.mark.asyncio +async def test_async_validate(mocker): + mocker.patch( + "guardrails.validator_service.AsyncValidatorService", return_value=AsyncMock() + ) + await vs.async_validate( + value=True, + metadata={}, + validator_map={}, + iteration=iteration, + ) + + vs.AsyncValidatorService.assert_called_once_with(True) + vs.AsyncValidatorService.return_value.async_validate.assert_called_once_with( + True, {}, {}, iteration, "$", "$", False + )