# 3 tests
# 1. Test streaming with OpenAICallable (mock openai.Completion.create)
# 2. Test streaming with OpenAIChatCallable (mock openai.ChatCompletion.create)
# 3. Test string schema streaming
# Using the LowerCase validator and a custom validator to show the new streaming behavior
import asyncio
from typing import Any, Callable, Dict, List, Optional, Union

import pytest
from pydantic import BaseModel, Field

import guardrails as gd
from guardrails.utils.casting_utils import to_int
from guardrails.validator_base import (
    ErrorSpan,
    FailResult,
    OnFailAction,
    PassResult,
    ValidationResult,
    Validator,
    register_validator,
)
from tests.integration_tests.test_assets.validators import LowerCase, MockDetectPII

expected_raw_output = {"statement": "I am DOING well, and I HOPE you aRe too."}
expected_fix_output = {"statement": "i am doing well, and i hope you are too."}
expected_noop_output = {"statement": "I am DOING well, and I HOPE you aRe too."}
expected_filter_refrain_output = {}


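# Custom validator: flags sentences whose length falls outside the configured
# min/max bounds and reports each offending sentence as an ErrorSpan.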
@register_validator(name="minsentencelength", data_type=["string", "list"])
class MinSentenceLengthValidator(Validator):
    def __init__(
        self,
        min: Optional[int] = None,
        max: Optional[int] = None,
        on_fail: Optional[Callable] = None,
    ):
        super().__init__(
            on_fail=on_fail,
            min=min,
            max=max,
        )
        self._min = to_int(min)
        self._max = to_int(max)

    def sentence_split(self, value):
        return list(map(lambda x: x + ".", value.split(".")[:-1]))

    def validate(self, value: Union[str, List], metadata: Dict) -> ValidationResult:
        sentences = self.sentence_split(value)
        error_spans = []
        index = 0
        for sentence in sentences:
            if len(sentence) < self._min:
                error_spans.append(
                    ErrorSpan(
                        start=index,
                        end=index + len(sentence),
                        reason=f"Sentence has length less than {self._min}. "
                        f"Please return a longer output, "
                        f"that is shorter than {self._max} characters.",
                    )
                )
            if len(sentence) > self._max:
                error_spans.append(
                    ErrorSpan(
                        start=index,
                        end=index + len(sentence),
                        reason=f"Sentence has length greater than {self._max}. "
                        f"Please return a shorter output, "
                        f"that is shorter than {self._max} characters.",
                    )
                )
            index = index + len(sentence)
        if len(error_spans) > 0:
            return FailResult(
                validated_chunk=value,
                error_spans=error_spans,
                error_message=f"Sentence has length less than {self._min}. "
                f"Please return a longer output, "
                f"that is shorter than {self._max} characters.",
            )
        return PassResult(validated_chunk=value)

    def validate_stream(self, chunk: Any, metadata: Dict, **kwargs) -> ValidationResult:
        return super().validate_stream(chunk, metadata, **kwargs)


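# Lightweight stand-ins for the pieces of an OpenAI v1 streaming chunk:
# each chunk carries a list of choices whose `delta.content` holds the streamed text.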
class Delta:
    content: str

    def __init__(self, content):
        self.content = content


class Choice:
    text: str
    finish_reason: str
    index: int
    delta: Delta

    def __init__(self, text, delta, finish_reason, index=0):
        self.index = index
        self.delta = delta
        self.text = text
        self.finish_reason = finish_reason


class MockOpenAIV1ChunkResponse:
    choices: list
    model: str

    def __init__(self, choices, model):
        self.choices = choices
        self.model = model


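# Wraps a list of text chunks in an async generator so the mocked litellm.acompletion
# call behaves like a streaming completion response.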
class Response:
    def __init__(self, chunks):
        self.chunks = chunks

        async def gen():
            for chunk in self.chunks:
                yield MockOpenAIV1ChunkResponse(
                    choices=[
                        Choice(
                            delta=Delta(content=chunk),
                            text=chunk,
                            finish_reason=None,
                        )
                    ],
                    model="OpenAI model name",
                )
                await asyncio.sleep(0)  # Yield control to the event loop

        self.completion_stream = gen()


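# Pydantic output schemas that pair the LowerCase validator with each OnFailAction
# (FIX, NOOP, FILTER, REFRAIN).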
class LowerCaseFix(BaseModel):
    statement: str = Field(
        description="Validates whether the text is in lower case.",
        validators=[LowerCase(on_fail=OnFailAction.FIX)],
    )


class LowerCaseNoop(BaseModel):
    statement: str = Field(
        description="Validates whether the text is in lower case.",
        validators=[LowerCase(on_fail=OnFailAction.NOOP)],
    )


class LowerCaseFilter(BaseModel):
    statement: str = Field(
        description="Validates whether the text is in lower case.",
        validators=[LowerCase(on_fail=OnFailAction.FILTER)],
    )


class LowerCaseRefrain(BaseModel):
    statement: str = Field(
        description="Validates whether the text is in lower case.",
        validators=[LowerCase(on_fail=OnFailAction.REFRAIN)],
    )


expected_minsentence_noop_output = ""


class MinSentenceLengthNoOp(BaseModel):
    statement: str = Field(
        description="Validates that each sentence is within the configured length bounds.",
        validators=[MinSentenceLengthValidator(on_fail=OnFailAction.NOOP)],
    )


STR_PROMPT = "Say something nice to me."

PROMPT = """
Say something nice to me.

${gr.complete_json_suffix}
"""

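# Chunks that the mocked streaming call yields one at a time.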
POETRY_CHUNKS = [
    '"John, under ',
    "GOLDEN bridges",
    ", roams,\n",
    "SAN Francisco's ",
    "hills, his HOME.\n",
    "Dreams of",
    " FOG, and salty AIR,\n",
    "In his HEART",
    ", he's always THERE.",
]


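# With OnFailAction.FILTER, chunks that fail the LowerCase validator are dropped from
# the validated stream: every chunk here fails, so the accumulated validated output
# stays empty while raw_llm_output still carries the final raw chunk.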
@pytest.mark.asyncio
async def test_filter_behavior(mocker):
    mocker.patch(
        "litellm.acompletion",
        return_value=Response(POETRY_CHUNKS),
    )

    guard = gd.AsyncGuard().use_many(
        MockDetectPII(
            on_fail=OnFailAction.FIX,
            pii_entities="pii",
            replace_map={"John": "<PERSON>", "SAN Francisco's": "<LOCATION>"},
        ),
        LowerCase(on_fail=OnFailAction.FILTER),
    )
    prompt = """Write me a 4 line poem about John in San Francisco.
    Make every third word all caps."""
    gen = await guard(
        model="gpt-3.5-turbo",
        max_tokens=10,
        temperature=0,
        stream=True,
        prompt=prompt,
    )

    text = ""
    final_res = None
    async for res in gen:
        final_res = res
        text = text + res.validated_output

    assert final_res.raw_llm_output == ", he's always THERE."
    assert text == ""
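

# Hypothetical sanity check (not part of the original file): drain the Response mock's
# async generator directly to confirm it replays POETRY_CHUNKS in order. This relies
# only on the mock classes defined above and assumes nothing about guardrails itself.
@pytest.mark.asyncio
async def test_mock_response_streams_all_chunks():
    response = Response(POETRY_CHUNKS)
    streamed = [
        chunk.choices[0].delta.content
        async for chunk in response.completion_stream
    ]
    assert streamed == POETRY_CHUNKS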