# 3 tests
# 1. Test streaming with OpenAICallable (mock openai.Completion.create)
# 2. Test streaming with OpenAIChatCallable (mock openai.ChatCompletion.create)
# 3. Test string schema streaming
# Using the LowerCase Validator, and a custom validator to show new streaming behavior
import asyncio
from typing import Any, Callable, Dict, List, Optional, Union

import pytest

import guardrails as gd
from guardrails.utils.casting_utils import to_int
from guardrails.validator_base import (
    ErrorSpan,
    FailResult,
    OnFailAction,
    PassResult,
    ValidationResult,
    Validator,
    register_validator,
)
from tests.integration_tests.test_assets.validators import LowerCase, MockDetectPII


@register_validator(name="minsentencelength", data_type=["string", "list"])
class MinSentenceLengthValidator(Validator):
    def __init__(
        self,
        min: Optional[int] = None,
        max: Optional[int] = None,
        on_fail: Optional[Callable] = None,
    ):
        super().__init__(
            on_fail=on_fail,
            min=min,
            max=max,
        )
        self._min = to_int(min)
        self._max = to_int(max)

    def sentence_split(self, value):
        # Naive splitter: break on periods and re-attach one to each sentence.
        return list(map(lambda x: x + ".", value.split(".")[:-1]))

    def validate(self, value: Union[str, List], metadata: Dict) -> ValidationResult:
        sentences = self.sentence_split(value)
        error_spans = []
        index = 0
        for sentence in sentences:
            if len(sentence) < self._min:
                error_spans.append(
                    ErrorSpan(
                        start=index,
                        end=index + len(sentence),
                        reason=f"Sentence has length less than {self._min}. "
                        f"Please return a longer output, "
                        f"that is longer than {self._min} characters.",
                    )
                )
            if len(sentence) > self._max:
                error_spans.append(
                    ErrorSpan(
                        start=index,
                        end=index + len(sentence),
                        reason=f"Sentence has length greater than {self._max}. "
                        f"Please return a shorter output, "
                        f"that is shorter than {self._max} characters.",
                    )
                )
            index = index + len(sentence)
        if len(error_spans) > 0:
            return FailResult(
                validated_chunk=value,
                error_spans=error_spans,
                error_message=f"Sentence lengths must be between {self._min} "
                f"and {self._max} characters.",
            )
        return PassResult(validated_chunk=value)

    def validate_stream(self, chunk: Any, metadata: Dict, **kwargs) -> ValidationResult:
        # Fall back to the default chunk-accumulation streaming behavior.
        return super().validate_stream(chunk, metadata, **kwargs)

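# A rough sketch of how the validator above behaves on a plain string
# (hypothetical min/max values, for illustration only; not exercised by the
# tests below):
#
#   validator = MinSentenceLengthValidator(min=5, max=40)
#   result = validator.validate("Hi. This sentence is long enough.", {})
#   # "Hi." is 3 characters (< 5), so `result` should be a FailResult carrying
#   # one ErrorSpan with start=0, end=3.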

# Minimal stand-ins for the chunk objects returned by an OpenAI v1-style
# streaming completion.
class Delta:
    content: str

    def __init__(self, content):
        self.content = content


class Choice:
    text: str
    finish_reason: str
    index: int
    delta: Delta

    def __init__(self, text, delta, finish_reason, index=0):
        self.index = index
        self.delta = delta
        self.text = text
        self.finish_reason = finish_reason


class MockOpenAIV1ChunkResponse:
    choices: list
    model: str

    def __init__(self, choices, model):
        self.choices = choices
        self.model = model


class Response:
    def __init__(self, chunks):
        self.chunks = chunks

        # Wrap the raw text chunks in an async generator that mimics a
        # streaming completion response.
        async def gen():
            for chunk in self.chunks:
                yield MockOpenAIV1ChunkResponse(
                    choices=[
                        Choice(
                            delta=Delta(content=chunk),
                            text=chunk,
                            finish_reason=None,
                        )
                    ],
                    model="OpenAI model name",
                )
                await asyncio.sleep(0)  # Yield control to the event loop

        self.completion_stream = gen()

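# A rough sketch of how the mocked stream is consumed (illustrative only, not
# part of the tests):
#
#   response = Response(["Hello, ", "world!"])
#   async for chunk in response.completion_stream:
#       print(chunk.choices[0].delta.content)  # "Hello, " then "world!"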

POETRY_CHUNKS = [
    "John, under ",
    "GOLDEN bridges",
    ", roams,\n",
    "SAN Francisco's ",
    "hills, his HOME.\n",
    "Dreams of",
    " FOG, and salty AIR,\n",
    "In his HEART",
    ", he's always THERE.",
]
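# Concatenated, the chunks above form the poem:
#   John, under GOLDEN bridges, roams,
#   SAN Francisco's hills, his HOME.
#   Dreams of FOG, and salty AIR,
#   In his HEART, he's always THERE.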


@pytest.mark.asyncio
async def test_filter_behavior(mocker):
    mocker.patch(
        "litellm.acompletion",
        return_value=Response(POETRY_CHUNKS),
    )

    guard = gd.AsyncGuard().use_many(
        MockDetectPII(
            on_fail=OnFailAction.FIX,
            pii_entities="pii",
            replace_map={"John": "<PERSON>", "SAN Francisco's": "<LOCATION>"},
        ),
        LowerCase(on_fail=OnFailAction.FILTER),
    )
    prompt = """Write me a 4 line poem about John in San Francisco.
    Make every third word all caps."""
    gen = await guard(
        model="gpt-3.5-turbo",
        max_tokens=10,
        temperature=0,
        stream=True,
        prompt=prompt,
    )

    text = ""
    final_res = None
    async for res in gen:
        final_res = res
        text += res.validated_output

    assert final_res.raw_llm_output == ", he's always THERE."
    # TODO deep dive this
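    # With OnFailAction.FILTER, validated chunks that fail the LowerCase
    # validator are dropped from the validated stream (raw_llm_output above
    # still reflects the last raw chunk), so the accumulated `text` is
    # missing pieces of the raw poem.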
    assert text == (
        "John, under GOLDEN bridges, roams,\n"
        "SAN Francisco's Dreams of FOG, and salty AIR,\n"
        "In his HEART"
    )