Skip to content

Commit 762e1ea

Browse files
authored
Bug/onfail enum match (#701)
* fix on_fail enum matching * lint fix * fix noop behavior * fix async old openai tests * lint
1 parent a00e761 commit 762e1ea

File tree

6 files changed

+108
-9
lines changed

6 files changed

+108
-9
lines changed

guardrails/classes/history/call.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,10 +269,28 @@ def guarded_output(self) -> Optional[Union[str, Dict]]:
269269
validated output may be "fixed" values that were corrected
270270
during validation.
271271
272-
This will only have a value if the Guard is in a passing state.
272+
This will only have a value if the Guard is in a passing state
273+
OR if the action is no-op.
273274
"""
274275
if self.status == pass_status:
275276
return self.fixed_output
277+
last_iteration = self.iterations.last
278+
if (
279+
not self.status == pass_status
280+
and last_iteration
281+
and last_iteration.failed_validations
282+
):
283+
# check that all failed validations are noop or none
284+
all_noop = True
285+
for failed_validation in last_iteration.failed_validations:
286+
if (
287+
failed_validation.value_after_validation
288+
is not failed_validation.value_before_validation
289+
):
290+
all_noop = False
291+
break
292+
if all_noop:
293+
return last_iteration.guarded_output
276294

277295
@property
278296
@deprecated(

guardrails/validator_base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,14 +422,23 @@ def __init__(
422422
""",
423423
FutureWarning,
424424
)
425+
self.on_fail_descriptor: Union[str, OnFailAction] = "custom"
425426

426427
if on_fail is None:
427428
on_fail = OnFailAction.NOOP
428429
if isinstance(on_fail, OnFailAction):
429430
self.on_fail_descriptor = on_fail
430431
self.on_fail_method = None
432+
elif (
433+
isinstance(on_fail, str)
434+
and OnFailAction.__members__.get(on_fail.upper()) is not None
435+
):
436+
self.on_fail_descriptor = (
437+
OnFailAction.__members__.get(on_fail.upper())
438+
or "" # this default isn't needed, it's just for pyright
439+
)
440+
self.on_fail_method = None
431441
else:
432-
self.on_fail_descriptor = "custom"
433442
self.on_fail_method = on_fail
434443

435444
# Store the kwargs for the validator.

tests/integration_tests/test_async.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ async def test_entity_extraction_with_noop(mocker):
9797
# assert final_output.validated_output == entity_extraction.VALIDATED_OUTPUT_NOOP
9898

9999
assert final_output.validation_passed is False
100-
assert final_output.validated_output is None
100+
assert final_output.validated_output is not None
101+
assert final_output.validated_output["fees"]
102+
assert final_output.validated_output["interest_rates"]
101103

102104
call = guard.history.first
103105

@@ -131,7 +133,9 @@ async def test_entity_extraction_with_noop_pydantic(mocker):
131133

132134
# Assertions are made on the guard state object.
133135
assert final_output.validation_passed is False
134-
assert final_output.validated_output is None
136+
assert final_output.validated_output is not None
137+
assert final_output.validated_output["fees"]
138+
assert final_output.validated_output["interest_rates"]
135139

136140
call = guard.history.first
137141

tests/integration_tests/test_guard.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,11 @@ def test_entity_extraction_with_noop(mocker, rail, prompt):
238238

239239
# Assertions are made on the guard state object.
240240
assert final_output.validation_passed is False
241-
assert final_output.validated_output is None
241+
assert (
242+
final_output.validated_output is not None
243+
and validated_output.__get__("fees")
244+
and validated_output.__get__("interest_rates")
245+
)
242246

243247
call = guard.history.first
244248

@@ -248,7 +252,11 @@ def test_entity_extraction_with_noop(mocker, rail, prompt):
248252
# For orginal prompt and output
249253
assert call.compiled_prompt == entity_extraction.COMPILED_PROMPT
250254
assert call.raw_outputs.last == entity_extraction.LLM_OUTPUT
251-
assert call.guarded_output is None
255+
assert (
256+
call.guarded_output is not None
257+
and call.guarded_output["fees"]
258+
and call.guarded_output["interest_rates"]
259+
)
252260
assert call.validation_response == entity_extraction.VALIDATED_OUTPUT_NOOP
253261

254262

@@ -843,8 +851,10 @@ def invoke(
843851
model = MockModel()
844852
guard = (
845853
Guard()
846-
.use(RegexMatch("Ice cream", match_type="search"), on="output")
847-
.use(ReadingTime(0.05)) # 3 seconds
854+
.use(
855+
RegexMatch("Ice cream", match_type="search", on_fail="refrain"), on="output"
856+
)
857+
.use(ReadingTime(0.05, on_fail="refrain")) # 3 seconds
848858
)
849859
output_parser = StrOutputParser()
850860

tests/integration_tests/test_multi_reask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,5 @@ def test_multi_reask(mocker):
4343
# The output here fails some validators but passes others.
4444
# Since those that it fails in the end are noop fixes, validation fails.
4545
assert call.validation_response == python_rail.VALIDATOR_PARALLELISM_RESPONSE_3
46-
assert call.guarded_output is None
46+
assert call.guarded_output is not None and isinstance(call.guarded_output, str)
4747
assert call.status == "fail"
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import Any, Dict
2+
3+
from guardrails.guard import Guard
4+
from guardrails.validator_base import (
5+
FailResult,
6+
ValidationResult,
7+
Validator,
8+
register_validator,
9+
)
10+
11+
12+
@register_validator("failure", "string")
13+
class FailureValidator(Validator):
14+
def validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult:
15+
return FailResult(
16+
error_message=("Failed cuz this is the failure validator"),
17+
fix_value="FIXED",
18+
)
19+
20+
21+
# TODO: Add reask tests. Reask is fairly well covered through notebooks
22+
# but it's good to have it here too.
23+
def test_fix():
24+
guard = Guard().use(FailureValidator, on_fail="fix")
25+
res = guard.parse("hi")
26+
assert res.validated_output == "FIXED"
27+
assert res.validation_passed # Should this even be true though?
28+
29+
30+
def test_default_noop():
31+
guard = Guard().use(FailureValidator, on_fail="noop")
32+
res = guard.parse("hi")
33+
assert res.validated_output == "hi"
34+
assert not res.validation_passed
35+
36+
37+
def test_filter():
38+
guard = Guard().use(FailureValidator, on_fail="filter")
39+
res = guard.parse("hi")
40+
assert res.validated_output is None
41+
assert not res.validation_passed
42+
43+
44+
def test_refrain():
45+
guard = Guard().use(FailureValidator, on_fail="refrain")
46+
res = guard.parse("hi")
47+
assert res.validated_output is None
48+
assert not res.validation_passed
49+
50+
51+
def test_exception():
52+
guard = Guard().use(FailureValidator, on_fail="exception")
53+
try:
54+
guard.parse("hi")
55+
except Exception as e:
56+
assert "Failed cuz this is the failure validator" in str(e)
57+
else:
58+
assert False, "Expected an exception"

0 commit comments

Comments
 (0)