33import os
44from concurrent .futures import ProcessPoolExecutor
55from datetime import datetime
6- from typing import Any , Dict , List , Optional , Tuple
6+ from typing import Any , Dict , List , Optional , Tuple , Union
77
88from guardrails .classes .history import Iteration
99from guardrails .datatypes import FieldValidation
1717from guardrails .validator_base import (
1818 FailResult ,
1919 Filter ,
20+ OnFailAction ,
2021 PassResult ,
2122 Refrain ,
2223 ValidationResult ,
@@ -59,13 +60,13 @@ def perform_correction(
5960 results : List [FailResult ],
6061 value : Any ,
6162 validator : Validator ,
62- on_fail_descriptor : str ,
63+ on_fail_descriptor : Union [ OnFailAction , str ] ,
6364 ):
64- if on_fail_descriptor == "fix" :
65+ if on_fail_descriptor == OnFailAction . FIX :
6566 # FIXME: Should we still return fix_value if it is None?
6667 # I think we should warn and return the original value.
6768 return results [0 ].fix_value
68- elif on_fail_descriptor == "fix_reask" :
69+ elif on_fail_descriptor == OnFailAction . FIX_REASK :
6970 # FIXME: Same thing here
7071 fixed_value = results [0 ].fix_value
7172 result = self .execute_validator (
@@ -83,21 +84,21 @@ def perform_correction(
8384 if validator .on_fail_method is None :
8485 raise ValueError ("on_fail is 'custom' but on_fail_method is None" )
8586 return validator .on_fail_method (value , results )
86- if on_fail_descriptor == "reask" :
87+ if on_fail_descriptor == OnFailAction . REASK :
8788 return FieldReAsk (
8889 incorrect_value = value ,
8990 fail_results = results ,
9091 )
91- if on_fail_descriptor == "exception" :
92+ if on_fail_descriptor == OnFailAction . EXCEPTION :
9293 raise ValidationError (
9394 "Validation failed for field with errors: "
9495 + ", " .join ([result .error_message for result in results ])
9596 )
96- if on_fail_descriptor == "filter" :
97+ if on_fail_descriptor == OnFailAction . FILTER :
9798 return Filter ()
98- if on_fail_descriptor == "refrain" :
99+ if on_fail_descriptor == OnFailAction . REFRAIN :
99100 return Refrain ()
100- if on_fail_descriptor == "noop" :
101+ if on_fail_descriptor == OnFailAction . NOOP :
101102 return value
102103 else :
103104 raise ValueError (
@@ -251,7 +252,11 @@ def group_validators(self, validators):
251252 validators , key = lambda v : (v .on_fail_descriptor , v .override_value_on_pass )
252253 )
253254 for (on_fail_descriptor , override_on_pass ), group in groups :
254- if override_on_pass or on_fail_descriptor in ["fix" , "fix_reask" , "custom" ]:
255+ if override_on_pass or on_fail_descriptor in [
256+ OnFailAction .FIX ,
257+ OnFailAction .FIX_REASK ,
258+ "custom" ,
259+ ]:
255260 for validator in group :
256261 yield on_fail_descriptor , [validator ]
257262 else :
0 commit comments