Skip to content

Commit 7386fae

Browse files
authored
Merge pull request #1124 from guardrails-ai/async_guard_from_template_and_tests
tests and enable async guards by default from create template
2 parents 5ad9dbd + f3f39d1 commit 7386fae

File tree

5 files changed

+59
-6
lines changed

5 files changed

+59
-6
lines changed

guardrails/cli/create.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def generate_template_config(
118118
guard_instantiations = []
119119

120120
for i, guard in enumerate(template["guards"]):
121-
guard_instantiations.append(f"guard{i} = Guard.from_dict(guards[{i}])")
121+
guard_instantiations.append(f"guard{i} = AsyncGuard.from_dict(guards[{i}])")
122122
guard_instantiations = "\n".join(guard_instantiations)
123123
# Interpolate variables
124124
output_content = template_content.format(

guardrails/cli/hub/template_config.py.template

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import os
3-
from guardrails import Guard
3+
from guardrails import AsyncGuard, Guard
44
from guardrails.hub import {VALIDATOR_IMPORTS}
55

66
try:

guardrails/guard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1215,7 +1215,7 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[O
12151215
error="The response from the server was empty!",
12161216
)
12171217

1218-
# TODO renable this when we have history support in
1218+
# TODO reenable this when we have history support in
12191219
# multi-node server environments
12201220
# guard_history = self._api_client.get_history(
12211221
# self.name, validation_output.call_id

server_ci/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import os
3-
from guardrails import Guard
3+
from guardrails import AsyncGuard
44

55
try:
66
file_path = os.path.join(os.getcwd(), "guard-template.json")
@@ -11,4 +11,4 @@
1111
SystemExit(1)
1212

1313
# instantiate guards
14-
guard0 = Guard.from_dict(guards[0])
14+
guard0 = AsyncGuard.from_dict(guards[0])

server_ci/tests/test_server.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import openai
22
import os
33
import pytest
4-
from guardrails import Guard, settings
4+
from guardrails import AsyncGuard, Guard, settings
55

66
# OpenAI compatible Guardrails API Guard
77
openai.base_url = "http://127.0.0.1:8000/guards/test-guard/openai/v1/"
@@ -32,6 +32,59 @@ def test_guard_validation(mock_llm_output, validation_output, validation_passed,
3232
assert validation_outcome.validated_output == validation_output
3333

3434

35+
@pytest.mark.asyncio
36+
async def test_async_guard_validation():
37+
settings.use_server = True
38+
guard = AsyncGuard(name="test-guard")
39+
40+
validation_outcome = await guard(
41+
model="gpt-4o-mini",
42+
messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}],
43+
temperature=0.0,
44+
)
45+
46+
assert validation_outcome.validation_passed is True # noqa: E712
47+
assert validation_outcome.validated_output == "Citrus fruit,"
48+
49+
50+
@pytest.mark.asyncio
51+
async def test_async_streaming_guard_validation():
52+
settings.use_server = True
53+
guard = AsyncGuard(name="test-guard")
54+
55+
async_iterator = await guard(
56+
model="gpt-4o-mini",
57+
messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}],
58+
stream=True,
59+
temperature=0.0,
60+
)
61+
62+
full_output = ""
63+
async for validation_chunk in async_iterator:
64+
full_output += validation_chunk.validated_output
65+
66+
assert full_output == "Citrus fruit,Citrus fruit,"
67+
68+
69+
@pytest.mark.asyncio
70+
async def test_sync_streaming_guard_validation():
71+
settings.use_server = True
72+
guard = Guard(name="test-guard")
73+
74+
iterator = guard(
75+
model="gpt-4o-mini",
76+
messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}],
77+
stream=True,
78+
temperature=0.0,
79+
)
80+
81+
full_output = ""
82+
for validation_chunk in iterator:
83+
full_output += validation_chunk.validated_output
84+
85+
assert full_output == "Citrus fruit,Citrus fruit,"
86+
87+
3588
@pytest.mark.parametrize(
3689
"message_content, output, validation_passed, error",
3790
[

0 commit comments

Comments (0)