|
1 | 1 | import openai |
2 | 2 | import os |
3 | 3 | import pytest |
4 | | -from guardrails import Guard, settings |
| 4 | +from guardrails import AsyncGuard, Guard, settings |
5 | 5 |
|
6 | 6 | # OpenAI compatible Guardrails API Guard |
7 | 7 | openai.base_url = "http://127.0.0.1:8000/guards/test-guard/openai/v1/" |
@@ -32,6 +32,59 @@ def test_guard_validation(mock_llm_output, validation_output, validation_passed, |
32 | 32 | assert validation_outcome.validated_output == validation_output |
33 | 33 |
|
34 | 34 |
|
| 35 | +@pytest.mark.asyncio |
| 36 | +async def test_async_guard_validation(): |
| 37 | + settings.use_server = True |
| 38 | + guard = AsyncGuard(name="test-guard") |
| 39 | + |
| 40 | + validation_outcome = await guard( |
| 41 | + model="gpt-4o-mini", |
| 42 | + messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}], |
| 43 | + temperature=0.0, |
| 44 | + ) |
| 45 | + |
| 46 | + assert validation_outcome.validation_passed is True # noqa: E712 |
| 47 | + assert validation_outcome.validated_output == "Citrus fruit," |
| 48 | + |
| 49 | + |
| 50 | +@pytest.mark.asyncio |
| 51 | +async def test_async_streaming_guard_validation(): |
| 52 | + settings.use_server = True |
| 53 | + guard = AsyncGuard(name="test-guard") |
| 54 | + |
| 55 | + async_iterator = await guard( |
| 56 | + model="gpt-4o-mini", |
| 57 | + messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}], |
| 58 | + stream=True, |
| 59 | + temperature=0.0, |
| 60 | + ) |
| 61 | + |
| 62 | + full_output = "" |
| 63 | + async for validation_chunk in async_iterator: |
| 64 | + full_output += validation_chunk.validated_output |
| 65 | + |
| 66 | + assert full_output == "Citrus fruit,Citrus fruit," |
| 67 | + |
| 68 | + |
| 69 | +@pytest.mark.asyncio |
| 70 | +async def test_sync_streaming_guard_validation(): |
| 71 | + settings.use_server = True |
| 72 | + guard = Guard(name="test-guard") |
| 73 | + |
| 74 | + iterator = guard( |
| 75 | + model="gpt-4o-mini", |
| 76 | + messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}], |
| 77 | + stream=True, |
| 78 | + temperature=0.0, |
| 79 | + ) |
| 80 | + |
| 81 | + full_output = "" |
| 82 | + for validation_chunk in iterator: |
| 83 | + full_output += validation_chunk.validated_output |
| 84 | + |
| 85 | + assert full_output == "Citrus fruit,Citrus fruit," |
| 86 | + |
| 87 | + |
35 | 88 | @pytest.mark.parametrize( |
36 | 89 | "message_content, output, validation_passed, error", |
37 | 90 | [ |
|
0 commit comments