Skip to content

Commit a0a693e

Browse files
committed
add async tests
1 parent 3524f2d commit a0a693e

File tree

3 files changed

+190
-32
lines changed

3 files changed

+190
-32
lines changed

guardrails/guard.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from guardrails.rail import Rail
3030
from guardrails.run import AsyncRunner, Runner
3131
from guardrails.schema import Schema, StringSchema
32-
from guardrails.utils.reask_utils import sub_reasks_with_fixed_values
3332
from guardrails.validators import Validator
3433

3534
logger = logging.getLogger(__name__)

guardrails/run.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,7 @@
1111
from guardrails.prompt import Instructions, Prompt
1212
from guardrails.schema import Schema, StringSchema
1313
from guardrails.utils.llm_response import LLMResponse
14-
from guardrails.utils.reask_utils import (
15-
FieldReAsk,
16-
NonParseableReAsk,
17-
ReAsk,
18-
reasks_to_dict,
19-
sub_reasks_with_fixed_values,
20-
)
14+
from guardrails.utils.reask_utils import NonParseableReAsk, ReAsk, reasks_to_dict
2115
from guardrails.validator_base import ValidatorError
2216

2317
logger = logging.getLogger(__name__)
@@ -361,7 +355,8 @@ def prepare(
361355
iteration.outputs.validation_output = validated_msg_history
362356
if isinstance(validated_msg_history, ReAsk):
363357
raise ValidatorError(
364-
f"Message history validation failed: {validated_msg_history}"
358+
f"Message history validation failed: "
359+
f"{validated_msg_history}"
365360
)
366361
if validated_msg_history != msg_str:
367362
raise ValidatorError("Message history validation failed")
@@ -698,6 +693,8 @@ async def async_run(
698693
output_schema,
699694
prompt_params=prompt_params,
700695
)
696+
except (ValidatorError, ValueError) as e:
697+
raise e
701698
except Exception as e:
702699
error_message = str(e)
703700

@@ -934,7 +931,8 @@ async def async_prepare(
934931
)
935932
if isinstance(validated_msg_history, ReAsk):
936933
raise ValidatorError(
937-
f"Message history validation failed: {validated_msg_history}"
934+
f"Message history validation failed: "
935+
f"{validated_msg_history}"
938936
)
939937
if validated_msg_history != msg_str:
940938
raise ValidatorError("Message history validation failed")
@@ -963,6 +961,7 @@ async def async_prepare(
963961
validated_prompt = await prompt_schema.async_validate(
964962
iteration, prompt.source, self.metadata
965963
)
964+
iteration.outputs.validation_output = validated_prompt
966965
if validated_prompt is None:
967966
raise ValidatorError("Prompt validation failed")
968967
if isinstance(validated_prompt, ReAsk):
@@ -981,6 +980,7 @@ async def async_prepare(
981980
validated_instructions = await instructions_schema.async_validate(
982981
iteration, instructions.source, self.metadata
983982
)
983+
iteration.outputs.validation_output = validated_instructions
984984
if validated_instructions is None:
985985
raise ValidatorError("Instructions validation failed")
986986
if isinstance(validated_instructions, ReAsk):

tests/unit_tests/test_validators.py

Lines changed: 181 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
from guardrails import Guard
1010
from guardrails.datatypes import DataType
1111
from guardrails.schema import StringSchema
12-
from guardrails.utils.openai_utils import OPENAI_VERSION, get_static_openai_create_func
12+
from guardrails.utils.openai_utils import (
13+
OPENAI_VERSION,
14+
get_static_openai_acreate_func,
15+
get_static_openai_create_func,
16+
)
1317
from guardrails.utils.reask_utils import FieldReAsk
1418
from guardrails.validator_base import (
1519
FailResult,
@@ -627,15 +631,7 @@ class Pet(BaseModel):
627631
name: str = Field(description="a unique pet name")
628632

629633

630-
def test_input_validation_fix(mocker):
631-
if OPENAI_VERSION.startswith("0"):
632-
mocker.patch("openai.ChatCompletion.create", new=mock_chat_completion)
633-
else:
634-
mocker.patch(
635-
"openai.resources.chat.completions.Completions.create",
636-
new=mock_chat_completion,
637-
)
638-
634+
def test_input_validation_fix():
639635
# fix returns an amended value for prompt/instructions validation,
640636
guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation(
641637
validators=[TwoWords(on_fail="fix")]
@@ -674,7 +670,7 @@ def test_input_validation_fix(mocker):
674670

675671
# rail prompt validation
676672
guard = Guard.from_rail_string(
677-
f"""
673+
"""
678674
<rail version="0.1">
679675
<prompt
680676
validators="two-words"
@@ -694,7 +690,7 @@ def test_input_validation_fix(mocker):
694690

695691
# rail instructions validation
696692
guard = Guard.from_rail_string(
697-
f"""
693+
"""
698694
<rail version="0.1">
699695
<prompt>
700696
This is not two words
@@ -716,6 +712,89 @@ def test_input_validation_fix(mocker):
716712
assert guard.history.first.iterations.first.outputs.validation_output == "This also"
717713

718714

715+
@pytest.mark.asyncio
716+
@pytest.mark.skipif(not OPENAI_VERSION.startswith("0"), reason="Not supported in v1")
717+
async def test_async_input_validation_fix():
718+
# fix returns an amended value for prompt/instructions validation,
719+
guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation(
720+
validators=[TwoWords(on_fail="fix")]
721+
)
722+
await guard(
723+
get_static_openai_acreate_func(),
724+
prompt="What kind of pet should I get?",
725+
)
726+
assert guard.history.first.iterations.first.outputs.validation_output == "What kind"
727+
guard = Guard.from_pydantic(output_class=Pet).with_instructions_validation(
728+
validators=[TwoWords(on_fail="fix")]
729+
)
730+
await guard(
731+
get_static_openai_acreate_func(),
732+
prompt="What kind of pet should I get and what should I name it?",
733+
instructions="But really, what kind of pet should I get?",
734+
)
735+
assert (
736+
guard.history.first.iterations.first.outputs.validation_output == "But really,"
737+
)
738+
739+
# but raises for msg_history validation
740+
with pytest.raises(ValidatorError):
741+
guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation(
742+
validators=[TwoWords(on_fail="fix")]
743+
)
744+
await guard(
745+
get_static_openai_acreate_func(),
746+
msg_history=[
747+
{
748+
"role": "user",
749+
"content": "What kind of pet should I get?",
750+
}
751+
],
752+
)
753+
754+
# rail prompt validation
755+
guard = Guard.from_rail_string(
756+
"""
757+
<rail version="0.1">
758+
<prompt
759+
validators="two-words"
760+
on-fail-two-words="fix"
761+
>
762+
This is not two words
763+
</prompt>
764+
<output type="string">
765+
</output>
766+
</rail>
767+
"""
768+
)
769+
await guard(
770+
get_static_openai_acreate_func(),
771+
)
772+
assert guard.history.first.iterations.first.outputs.validation_output == "This is"
773+
774+
# rail instructions validation
775+
guard = Guard.from_rail_string(
776+
"""
777+
<rail version="0.1">
778+
<prompt>
779+
This is not two words
780+
</prompt>
781+
<instructions
782+
validators="two-words"
783+
on-fail-two-words="fix"
784+
>
785+
This also is not two words
786+
</instructions>
787+
<output type="string">
788+
</output>
789+
</rail>
790+
"""
791+
)
792+
await guard(
793+
get_static_openai_acreate_func(),
794+
)
795+
assert guard.history.first.iterations.first.outputs.validation_output == "This also"
796+
797+
719798
@pytest.mark.parametrize(
720799
"on_fail",
721800
[
@@ -725,15 +804,7 @@ def test_input_validation_fix(mocker):
725804
"exception",
726805
],
727806
)
728-
def test_input_validation_fail(mocker, on_fail):
729-
if OPENAI_VERSION.startswith("0"):
730-
mocker.patch("openai.ChatCompletion.create", new=mock_chat_completion)
731-
else:
732-
mocker.patch(
733-
"openai.resources.chat.completions.Completions.create",
734-
new=mock_chat_completion,
735-
)
736-
807+
def test_input_validation_fail(on_fail):
737808
# with_prompt_validation
738809
with pytest.raises(ValidatorError):
739810
guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation(
@@ -771,7 +842,7 @@ def test_input_validation_fail(mocker, on_fail):
771842
guard = Guard.from_rail_string(
772843
f"""
773844
<rail version="0.1">
774-
<prompt
845+
<prompt
775846
validators="two-words"
776847
on-fail-two-words="{on_fail}"
777848
>
@@ -810,6 +881,94 @@ def test_input_validation_fail(mocker, on_fail):
810881
)
811882

812883

884+
@pytest.mark.parametrize(
885+
"on_fail",
886+
[
887+
"reask",
888+
"filter",
889+
"refrain",
890+
"exception",
891+
],
892+
)
893+
@pytest.mark.asyncio
894+
@pytest.mark.skipif(not OPENAI_VERSION.startswith("0"), reason="Not supported in v1")
895+
async def test_input_validation_fail_async(mocker, on_fail):
896+
# with_prompt_validation
897+
with pytest.raises(ValidatorError):
898+
guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation(
899+
validators=[TwoWords(on_fail=on_fail)]
900+
)
901+
await guard(
902+
get_static_openai_acreate_func(),
903+
prompt="What kind of pet should I get?",
904+
)
905+
# with_instructions_validation
906+
with pytest.raises(ValidatorError):
907+
guard = Guard.from_pydantic(output_class=Pet).with_instructions_validation(
908+
validators=[TwoWords(on_fail=on_fail)]
909+
)
910+
await guard(
911+
get_static_openai_acreate_func(),
912+
prompt="What kind of pet should I get and what should I name it?",
913+
instructions="What kind of pet should I get?",
914+
)
915+
# with_msg_history_validation
916+
with pytest.raises(ValidatorError):
917+
guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation(
918+
validators=[TwoWords(on_fail=on_fail)]
919+
)
920+
await guard(
921+
get_static_openai_acreate_func(),
922+
msg_history=[
923+
{
924+
"role": "user",
925+
"content": "What kind of pet should I get?",
926+
}
927+
],
928+
)
929+
# rail prompt validation
930+
guard = Guard.from_rail_string(
931+
f"""
932+
<rail version="0.1">
933+
<prompt
934+
validators="two-words"
935+
on-fail-two-words="{on_fail}"
936+
>
937+
This is not two words
938+
</prompt>
939+
<output type="string">
940+
</output>
941+
</rail>
942+
"""
943+
)
944+
with pytest.raises(ValidatorError):
945+
await guard(
946+
get_static_openai_acreate_func(),
947+
)
948+
# rail instructions validation
949+
guard = Guard.from_rail_string(
950+
f"""
951+
<rail version="0.1">
952+
<prompt>
953+
This is not two words
954+
</prompt>
955+
<instructions
956+
validators="two-words"
957+
on-fail-two-words="{on_fail}"
958+
>
959+
This also is not two words
960+
</instructions>
961+
<output type="string">
962+
</output>
963+
</rail>
964+
"""
965+
)
966+
with pytest.raises(ValidatorError):
967+
await guard(
968+
get_static_openai_acreate_func(),
969+
)
970+
971+
813972
def test_input_validation_mismatch_raise():
814973
# prompt validation, msg_history argument
815974
guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation(

0 commit comments

Comments
 (0)