|
1 | 1 | import json |
| 2 | +import re |
2 | 3 | from typing import Any, Dict, List |
3 | 4 |
|
4 | 5 | import pytest |
@@ -209,106 +210,66 @@ def test_to_xml_attrib(min, max, expected_xml): |
209 | 210 | assert xml_validator == expected_xml |
210 | 211 |
|
211 | 212 |
|
212 | | -def custom_fix_on_fail_handler(value: Any, fail_results: List[FailResult]): |
| 213 | +def custom_deprecated_on_fail_handler(value: Any, fail_results: List[FailResult]): |
| 214 | + return value + " deprecated" |
| 215 | + |
| 216 | + |
| 217 | +def custom_fix_on_fail_handler(value: Any, fail_result: FailResult): |
213 | 218 | return value + " " + value |
214 | 219 |
|
215 | 220 |
|
216 | | -def custom_reask_on_fail_handler(value: Any, fail_results: List[FailResult]): |
217 | | - return FieldReAsk(incorrect_value=value, fail_results=fail_results) |
| 221 | +def custom_reask_on_fail_handler(value: Any, fail_result: FailResult): |
| 222 | + return FieldReAsk(incorrect_value=value, fail_results=[fail_result]) |
218 | 223 |
|
219 | 224 |
|
220 | | -def custom_exception_on_fail_handler(value: Any, fail_results: List[FailResult]): |
| 225 | +def custom_exception_on_fail_handler(value: Any, fail_result: FailResult): |
221 | 226 | raise ValidationError("Something went wrong!") |
222 | 227 |
|
223 | 228 |
|
224 | | -def custom_filter_on_fail_handler(value: Any, fail_results: List[FailResult]): |
| 229 | +def custom_filter_on_fail_handler(value: Any, fail_result: FailResult): |
225 | 230 | return Filter() |
226 | 231 |
|
227 | 232 |
|
228 | | -def custom_refrain_on_fail_handler(value: Any, fail_results: List[FailResult]): |
| 233 | +def custom_refrain_on_fail_handler(value: Any, fail_result: FailResult): |
229 | 234 | return Refrain() |
230 | 235 |
|
231 | 236 |
|
232 | | -@pytest.mark.parametrize( |
233 | | - "custom_reask_func, expected_result", |
234 | | - [ |
235 | | - ( |
236 | | - custom_fix_on_fail_handler, |
237 | | - {"pet_type": "dog dog", "name": "Fido"}, |
238 | | - ), |
239 | | - ( |
240 | | - custom_reask_on_fail_handler, |
241 | | - FieldReAsk( |
242 | | - incorrect_value="dog", |
243 | | - path=["pet_type"], |
244 | | - fail_results=[ |
245 | | - FailResult( |
246 | | - error_message="must be exactly two words", |
247 | | - fix_value="dog dog", |
248 | | - ) |
249 | | - ], |
250 | | - ), |
251 | | - ), |
252 | | - ( |
253 | | - custom_exception_on_fail_handler, |
254 | | - ValidationError, |
255 | | - ), |
256 | | - ( |
257 | | - custom_filter_on_fail_handler, |
258 | | - None, |
259 | | - ), |
260 | | - ( |
261 | | - custom_refrain_on_fail_handler, |
262 | | - None, |
263 | | - ), |
264 | | - ], |
265 | | -) |
266 | | -# @pytest.mark.parametrize( |
267 | | -# "validator_spec", |
268 | | -# [ |
269 | | -# lambda val_func: TwoWords(on_fail=val_func), |
270 | | -# # This was never supported even pre-0.5.x. |
271 | | -# # Trying this with function calling will throw. |
272 | | -# lambda val_func: ("two-words", val_func), |
273 | | -# ], |
274 | | -# ) |
275 | | -def test_custom_on_fail_handler( |
276 | | - custom_reask_func, |
277 | | - expected_result, |
278 | | -): |
279 | | - prompt = """ |
280 | | - What kind of pet should I get and what should I name it? |
| 237 | +class TestCustomOnFailHandler: |
| 238 | + def test_deprecated_on_fail_handler(self): |
| 239 | + prompt = """ |
| 240 | + What kind of pet should I get and what should I name it? |
281 | 241 |
|
282 | | - ${gr.complete_json_suffix_v2} |
283 | | - """ |
| 242 | + ${gr.complete_json_suffix_v2} |
| 243 | + """ |
284 | 244 |
|
285 | | - output = """ |
286 | | - { |
287 | | - "pet_type": "dog", |
288 | | - "name": "Fido" |
289 | | - } |
290 | | - """ |
| 245 | + output = """ |
| 246 | + { |
| 247 | + "pet_type": "dog", |
| 248 | + "name": "Fido" |
| 249 | + } |
| 250 | + """ |
| 251 | + expected_result = {"pet_type": "dog deprecated", "name": "Fido"} |
| 252 | + |
| 253 | + with pytest.warns( |
| 254 | + DeprecationWarning, |
| 255 | + match=re.escape( # Becuase of square brackets in the message |
| 256 | + "Specifying a List[FailResult] as the second argument" |
| 257 | + " for a custom on_fail handler is deprecated. " |
| 258 | + "Please use FailResult instead." |
| 259 | + ), |
| 260 | + ): |
| 261 | + validator: Validator = TwoWords(on_fail=custom_deprecated_on_fail_handler) # type: ignore |
291 | 262 |
|
292 | | - validator: Validator = TwoWords(on_fail=custom_reask_func) |
| 263 | + class Pet(BaseModel): |
| 264 | + pet_type: str = Field(description="Species of pet", validators=[validator]) |
| 265 | + name: str = Field(description="a unique pet name") |
293 | 266 |
|
294 | | - class Pet(BaseModel): |
295 | | - pet_type: str = Field(description="Species of pet", validators=[validator]) |
296 | | - name: str = Field(description="a unique pet name") |
| 267 | + guard = Guard.from_pydantic(output_class=Pet, prompt=prompt) |
297 | 268 |
|
298 | | - guard = Guard.from_pydantic(output_class=Pet, prompt=prompt) |
299 | | - if isinstance(expected_result, type) and issubclass(expected_result, Exception): |
300 | | - with pytest.raises(ValidationError) as excinfo: |
301 | | - guard.parse(output, num_reasks=0) |
302 | | - assert str(excinfo.value) == "Something went wrong!" |
303 | | - else: |
304 | 269 | response = guard.parse(output, num_reasks=0) |
305 | | - if isinstance(expected_result, FieldReAsk): |
306 | | - assert guard.history.first.iterations.first.reasks[0] == expected_result |
307 | | - else: |
308 | | - assert response.validated_output == expected_result |
309 | | - |
| 270 | + assert response.validation_passed is True |
| 271 | + assert response.validated_output == expected_result |
310 | 272 |
|
311 | | -class TestCustomOnFailHandler: |
312 | 273 | def test_custom_fix(self): |
313 | 274 | prompt = """ |
314 | 275 | What kind of pet should I get and what should I name it? |
|
0 commit comments