Skip to content

Commit 5ced4e6

Browse files
authored
Fix custom onfail handler (#421)
* formatattr: Pass through custom on_fail methods * validator_service: Pass all FailResults to custom fail handler instead of only first * test_validators: Add custom on_fail handler test
1 parent 6c32b35 commit 5ced4e6

File tree

3 files changed

+108
-4
lines changed

3 files changed

+108
-4
lines changed

guardrails/formatattr.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class Config:
3232
format: Optional[str]
3333

3434
# The on-fail handlers.
35-
on_fail_handlers: Dict[str, str]
35+
on_fail_handlers: Mapping[str, Union[str, Callable]]
3636

3737
# The validator arguments.
3838
validator_args: Mapping[str, Union[Dict[str, Any], List[Any]]]
@@ -65,7 +65,10 @@ def from_validators(
6565
validator_args = val.get_args()
6666
validators_with_args[validator_name] = validator_args
6767
# Set the on-fail attribute based on the on_fail value
68-
on_fail = val.on_fail_descriptor
68+
if val.on_fail_descriptor == "custom":
69+
on_fail = val.on_fail_method
70+
else:
71+
on_fail = val.on_fail_descriptor
6972
on_fails[val.rail_alias] = on_fail
7073
elif isinstance(val, tuple) and len(val) == 2:
7174
validator, on_fail = val

guardrails/validator_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def perform_correction(
4747
if on_fail_descriptor == "custom":
4848
if validator.on_fail_method is None:
4949
raise ValueError("on_fail is 'custom' but on_fail_method is None")
50-
return validator.on_fail_method(value, results[0])
50+
return validator.on_fail_method(value, results)
5151
if on_fail_descriptor == "reask":
5252
return FieldReAsk(
5353
incorrect_value=value,

tests/unit_tests/test_validators.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# noqa:W291
22
import os
3-
from typing import Any, Dict
3+
from typing import Any, Dict, List
44

55
import openai
66
import pytest
@@ -16,6 +16,7 @@
1616
PassResult,
1717
Refrain,
1818
ValidationResult,
19+
ValidatorError,
1920
check_refrain_in_dict,
2021
filter_in_dict,
2122
register_validator,
@@ -503,3 +504,103 @@ def test_detect_secrets():
503504
# Check if mod_value is same as code_snippet,
504505
# as there are no secrets in code_snippet
505506
assert mod_value == NO_SECRETS_CODE_SNIPPET
507+
508+
509+
def custom_fix_on_fail_handler(value: Any, fail_results: List[FailResult]):
510+
return value + " " + value
511+
512+
513+
def custom_reask_on_fail_handler(value: Any, fail_results: List[FailResult]):
514+
return FieldReAsk(incorrect_value=value, fail_results=fail_results)
515+
516+
517+
def custom_exception_on_fail_handler(value: Any, fail_results: List[FailResult]):
518+
raise ValidatorError("Something went wrong!")
519+
520+
521+
def custom_filter_on_fail_handler(value: Any, fail_results: List[FailResult]):
522+
return Filter()
523+
524+
525+
def custom_refrain_on_fail_handler(value: Any, fail_results: List[FailResult]):
526+
return Refrain()
527+
528+
529+
@pytest.mark.parametrize(
530+
"validator_func, expected_result",
531+
[
532+
(
533+
custom_fix_on_fail_handler,
534+
{"pet_type": "dog dog", "name": "Fido"},
535+
),
536+
(
537+
custom_reask_on_fail_handler,
538+
FieldReAsk(
539+
incorrect_value="dog",
540+
path=["pet_type"],
541+
fail_results=[
542+
FailResult(
543+
error_message="must be exactly two words",
544+
fix_value="dog",
545+
)
546+
],
547+
),
548+
),
549+
(
550+
custom_exception_on_fail_handler,
551+
ValidatorError,
552+
),
553+
(
554+
custom_filter_on_fail_handler,
555+
{"name": "Fido"},
556+
),
557+
(
558+
custom_refrain_on_fail_handler,
559+
{},
560+
),
561+
],
562+
)
563+
@pytest.mark.parametrize(
564+
"validator_spec",
565+
[
566+
lambda val_func: TwoWords(on_fail=val_func),
567+
lambda val_func: ("two-words", val_func),
568+
],
569+
)
570+
def test_custom_on_fail_handler(
571+
validator_spec,
572+
validator_func,
573+
expected_result,
574+
):
575+
prompt = """
576+
What kind of pet should I get and what should I name it?
577+
578+
${gr.complete_json_suffix_v2}
579+
"""
580+
581+
output = """
582+
{
583+
"pet_type": "dog",
584+
"name": "Fido"
585+
}
586+
"""
587+
588+
class Pet(BaseModel):
589+
pet_type: str = Field(
590+
description="Species of pet", validators=[validator_spec(validator_func)]
591+
)
592+
name: str = Field(description="a unique pet name")
593+
594+
guard = Guard.from_pydantic(output_class=Pet, prompt=prompt)
595+
if isinstance(expected_result, type) and issubclass(expected_result, Exception):
596+
with pytest.raises(expected_result):
597+
guard.parse(output)
598+
else:
599+
validated_output = guard.parse(output, num_reasks=0)
600+
if isinstance(expected_result, FieldReAsk):
601+
assert (
602+
guard.guard_state.all_histories[0].history[0].reasks[0]
603+
== expected_result
604+
)
605+
else:
606+
assert validated_output == expected_result

0 commit comments

Comments
 (0)