Skip to content

Commit 9dfa0db

Browse files
committed
stricter typing for custom on_fail methods
1 parent 541cbdd commit 9dfa0db

File tree

3 files changed

+70
-82
lines changed

3 files changed

+70
-82
lines changed

guardrails/validator_base.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from string import Template
1313
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
1414
from warnings import warn
15+
import warnings
1516

1617
import nltk
1718
import requests
@@ -26,6 +27,7 @@
2627
from guardrails.logger import logger
2728
from guardrails.remote_inference import remote_inference
2829
from guardrails.types.on_fail import OnFailAction
30+
from guardrails.utils.safe_get import safe_get
2931
from guardrails.utils.hub_telemetry_utils import HubTelemetry
3032

3133
# See: https://github.com/guardrails-ai/guardrails/issues/829
@@ -78,7 +80,7 @@ class Validator:
7880

7981
def __init__(
8082
self,
81-
on_fail: Optional[Union[Callable, OnFailAction]] = None,
83+
on_fail: Optional[Union[Callable[[Any, FailResult], Any], OnFailAction]] = None,
8284
**kwargs,
8385
):
8486
self.creds = Credentials.from_rc_file()
@@ -127,7 +129,7 @@ def __init__(
127129
self.on_fail_method = None
128130
else:
129131
self.on_fail_descriptor = OnFailAction.CUSTOM
130-
self.on_fail_method = on_fail
132+
self._set_on_fail_method(on_fail)
131133

132134
# Store the kwargs for the validator.
133135
self._kwargs = kwargs
@@ -136,6 +138,31 @@ def __init__(
136138
self.rail_alias in validators_registry
137139
), f"Validator {self.__class__.__name__} is not registered. "
138140

141+
def _set_on_fail_method(self, on_fail: Callable[[Any, FailResult], Any]):
142+
"""Set the on_fail method for the validator."""
143+
on_fail_args = inspect.getfullargspec(on_fail)
144+
second_arg = safe_get(on_fail_args.args, 1)
145+
if second_arg is None:
146+
raise ValueError(
147+
"The on_fail method must take two arguments: "
148+
"the value being validated and the FailResult."
149+
)
150+
second_arg_type = on_fail_args.annotations.get(second_arg)
151+
if second_arg_type == List[FailResult]:
152+
warnings.warn(
153+
"Specifying a List[FailResult] as the second argument"
154+
" for a custom on_fail handler is deprecated. "
155+
"Please use FailResult instead.",
156+
DeprecationWarning,
157+
)
158+
159+
def on_fail_wrapper(value: Any, fail_result: FailResult) -> Any:
160+
return on_fail(value, [fail_result]) # type: ignore
161+
162+
self.on_fail_method = on_fail_wrapper
163+
else:
164+
self.on_fail_method = on_fail
165+
139166
def _validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult:
140167
"""User implementable function.
141168

guardrails/validator_service/validator_service_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def perform_correction(
9595
if on_fail_descriptor == OnFailAction.CUSTOM:
9696
if validator.on_fail_method is None:
9797
raise ValueError("on_fail is 'custom' but on_fail_method is None")
98-
return validator.on_fail_method(value, [result])
98+
return validator.on_fail_method(value, result)
9999
if on_fail_descriptor == OnFailAction.REASK:
100100
return FieldReAsk(
101101
incorrect_value=value,

tests/unit_tests/test_validator_base.py

Lines changed: 40 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import re
23
from typing import Any, Dict, List
34

45
import pytest
@@ -209,106 +210,66 @@ def test_to_xml_attrib(min, max, expected_xml):
209210
assert xml_validator == expected_xml
210211

211212

212-
def custom_fix_on_fail_handler(value: Any, fail_results: List[FailResult]):
213+
def custom_deprecated_on_fail_handler(value: Any, fail_results: List[FailResult]):
214+
return value + " deprecated"
215+
216+
217+
def custom_fix_on_fail_handler(value: Any, fail_result: FailResult):
213218
return value + " " + value
214219

215220

216-
def custom_reask_on_fail_handler(value: Any, fail_results: List[FailResult]):
217-
return FieldReAsk(incorrect_value=value, fail_results=fail_results)
221+
def custom_reask_on_fail_handler(value: Any, fail_result: FailResult):
222+
return FieldReAsk(incorrect_value=value, fail_results=[fail_result])
218223

219224

220-
def custom_exception_on_fail_handler(value: Any, fail_results: List[FailResult]):
225+
def custom_exception_on_fail_handler(value: Any, fail_result: FailResult):
221226
raise ValidationError("Something went wrong!")
222227

223228

224-
def custom_filter_on_fail_handler(value: Any, fail_results: List[FailResult]):
229+
def custom_filter_on_fail_handler(value: Any, fail_result: FailResult):
225230
return Filter()
226231

227232

228-
def custom_refrain_on_fail_handler(value: Any, fail_results: List[FailResult]):
233+
def custom_refrain_on_fail_handler(value: Any, fail_result: FailResult):
229234
return Refrain()
230235

231236

232-
@pytest.mark.parametrize(
233-
"custom_reask_func, expected_result",
234-
[
235-
(
236-
custom_fix_on_fail_handler,
237-
{"pet_type": "dog dog", "name": "Fido"},
238-
),
239-
(
240-
custom_reask_on_fail_handler,
241-
FieldReAsk(
242-
incorrect_value="dog",
243-
path=["pet_type"],
244-
fail_results=[
245-
FailResult(
246-
error_message="must be exactly two words",
247-
fix_value="dog dog",
248-
)
249-
],
250-
),
251-
),
252-
(
253-
custom_exception_on_fail_handler,
254-
ValidationError,
255-
),
256-
(
257-
custom_filter_on_fail_handler,
258-
None,
259-
),
260-
(
261-
custom_refrain_on_fail_handler,
262-
None,
263-
),
264-
],
265-
)
266-
# @pytest.mark.parametrize(
267-
# "validator_spec",
268-
# [
269-
# lambda val_func: TwoWords(on_fail=val_func),
270-
# # This was never supported even pre-0.5.x.
271-
# # Trying this with function calling will throw.
272-
# lambda val_func: ("two-words", val_func),
273-
# ],
274-
# )
275-
def test_custom_on_fail_handler(
276-
custom_reask_func,
277-
expected_result,
278-
):
279-
prompt = """
280-
What kind of pet should I get and what should I name it?
237+
class TestCustomOnFailHandler:
238+
def test_deprecated_on_fail_handler(self):
239+
prompt = """
240+
What kind of pet should I get and what should I name it?
281241
282-
${gr.complete_json_suffix_v2}
283-
"""
242+
${gr.complete_json_suffix_v2}
243+
"""
284244

285-
output = """
286-
{
287-
"pet_type": "dog",
288-
"name": "Fido"
289-
}
290-
"""
245+
output = """
246+
{
247+
"pet_type": "dog",
248+
"name": "Fido"
249+
}
250+
"""
251+
expected_result = {"pet_type": "dog deprecated", "name": "Fido"}
252+
253+
with pytest.warns(
254+
DeprecationWarning,
255+
match=re.escape( # Becuase of square brackets in the message
256+
"Specifying a List[FailResult] as the second argument"
257+
" for a custom on_fail handler is deprecated. "
258+
"Please use FailResult instead."
259+
),
260+
):
261+
validator: Validator = TwoWords(on_fail=custom_deprecated_on_fail_handler) # type: ignore
291262

292-
validator: Validator = TwoWords(on_fail=custom_reask_func)
263+
class Pet(BaseModel):
264+
pet_type: str = Field(description="Species of pet", validators=[validator])
265+
name: str = Field(description="a unique pet name")
293266

294-
class Pet(BaseModel):
295-
pet_type: str = Field(description="Species of pet", validators=[validator])
296-
name: str = Field(description="a unique pet name")
267+
guard = Guard.from_pydantic(output_class=Pet, prompt=prompt)
297268

298-
guard = Guard.from_pydantic(output_class=Pet, prompt=prompt)
299-
if isinstance(expected_result, type) and issubclass(expected_result, Exception):
300-
with pytest.raises(ValidationError) as excinfo:
301-
guard.parse(output, num_reasks=0)
302-
assert str(excinfo.value) == "Something went wrong!"
303-
else:
304269
response = guard.parse(output, num_reasks=0)
305-
if isinstance(expected_result, FieldReAsk):
306-
assert guard.history.first.iterations.first.reasks[0] == expected_result
307-
else:
308-
assert response.validated_output == expected_result
309-
270+
assert response.validation_passed is True
271+
assert response.validated_output == expected_result
310272

311-
class TestCustomOnFailHandler:
312273
def test_custom_fix(self):
313274
prompt = """
314275
What kind of pet should I get and what should I name it?

0 commit comments

Comments
 (0)