
Commit 66a568b

more tests
1 parent da2dc63 commit 66a568b

File tree

9 files changed: +75 -60 lines changed

tests/integration_tests/mock_llm_outputs.py
tests/integration_tests/test_assets/lists_object.py
tests/integration_tests/test_assets/python_rail/validator_parallelism.rail
tests/integration_tests/test_assets/python_rail/validator_parallelism_prompt_1.txt
tests/integration_tests/test_assets/python_rail/validator_parallelism_prompt_2.txt
tests/integration_tests/test_assets/python_rail/validator_parallelism_prompt_3.txt
tests/integration_tests/test_guard.py
tests/integration_tests/test_multi_reask.py
tests/integration_tests/test_pydantic.py

tests/integration_tests/mock_llm_outputs.py

Lines changed: 9 additions & 6 deletions
@@ -13,7 +13,7 @@ class MockLiteLLMCallableOther(LiteLLMCallable):
     # NOTE: this class normally overrides `llm_providers.LiteLLMCallable`,
     # which compiles instructions and prompt into a single prompt;
     # here the instructions are passed into kwargs and ignored
-    def _invoke_llm(self, prompt, *args, **kwargs):
+    def _invoke_llm(self, messages, *args, **kwargs):
         """Mock the OpenAI API call to Completion.create."""
 
         _rail_to_compiled_prompt = {  # noqa
@@ -43,16 +43,16 @@ def _invoke_llm(self, prompt, *args, **kwargs):
         }
 
         try:
-            output = mock_llm_responses[prompt]
+            output = mock_llm_responses[messages[0]["content"]]
             return LLMResponse(
                 output=output,
                 prompt_token_count=123,
                 response_token_count=1234,
             )
         except KeyError:
-            print("Unrecognized prompt!")
-            print(prompt)
-            raise ValueError("Compiled prompt not found")
+            print("Unrecognized messages!")
+            print(messages)
+            raise ValueError("Compiled messages not found")
 
 
 class MockAsyncLiteLLMCallable(AsyncLiteLLMCallable):
@@ -129,7 +129,10 @@ def _invoke_llm(
 
         try:
             if messages:
-                key = (messages[0]["content"], messages[1]["content"])
+                if len(messages) == 2:
+                    key = (messages[0]["content"], messages[1]["content"])
+                elif len(messages) == 1:
+                    key = (messages[0]["content"], None)
                 out_text = mock_llm_responses[key]
             if prompt and instructions and not messages:
                 out_text = mock_llm_responses[(prompt, instructions)]
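For orientation, a minimal standalone sketch (not the repo's fixture) of the lookup scheme the updated mocks use: a lone user message is keyed as (content, None), while a two-message exchange is keyed by both contents. The dictionary entries below are hypothetical placeholders.

from typing import Optional, Tuple

# Hypothetical stand-in for mock_llm_responses; contents are placeholders.
mock_llm_responses = {
    ("What is in stock?", None): '[{"name": "apple", "price": 1.0}]',
    ("What is in stock?", "Respond as JSON."): '[{"name": "banana", "price": 0.5}]',
}

def lookup(messages: list) -> str:
    """Mirror the keying above: message-content tuples instead of a compiled prompt."""
    key: Tuple[str, Optional[str]]
    if len(messages) == 2:
        key = (messages[0]["content"], messages[1]["content"])
    elif len(messages) == 1:
        key = (messages[0]["content"], None)
    else:
        raise ValueError("Compiled messages not found")
    return mock_llm_responses[key]

print(lookup([{"role": "user", "content": "What is in stock?"}]))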

tests/integration_tests/test_assets/lists_object.py

Lines changed: 4 additions & 6 deletions
@@ -2,11 +2,7 @@
 
 from pydantic import BaseModel
 
-LIST_PROMPT = """Create a list of items that may be found in a grocery store.
-
-Json Output:
-
-"""
+LIST_PROMPT = """Create a list of items that may be found in a grocery store."""
 
 
 LIST_OUTPUT = """[{"name": "apple", "price": 1.0}, {"name": "banana", "price": 0.5}, {"name": "orange", "price": 1.5}]"""  # noqa: E501
@@ -28,6 +24,8 @@ class Item(BaseModel):
         <float name="price" />
     </object>
 </output>
-<prompt>Create a list of items that may be found in a grocery store.</prompt>
+<messages>
+<message role="user">Create a list of items that may be found in a grocery store.</message>
+</messages>
 </rail>
 """

tests/integration_tests/test_assets/python_rail/validator_parallelism.rail

Lines changed: 4 additions & 2 deletions
@@ -11,10 +11,12 @@
     on-fail-length="reask"
 />
 
-<prompt>
+<messages>
+<message role="user">
 Say hullo to my little friend
 
 ${gr.complete_string_suffix}
-</prompt>
+</message>
+</messages>
 
 </rail>

tests/integration_tests/test_assets/python_rail/validator_parallelism_prompt_1.txt

Lines changed: 0 additions & 4 deletions
@@ -14,7 +14,3 @@ Your generated response should satisfy the following properties:
 
 Don't talk; just go.
 
-
-
-String Output:
-
tests/integration_tests/test_assets/python_rail/validator_parallelism_prompt_2.txt

Lines changed: 0 additions & 4 deletions
@@ -22,7 +22,3 @@ Your generated response should satisfy the following properties:
 - length: min=1 max=10
 
 Don't talk; just go.
-
-
-String Output:
-

tests/integration_tests/test_assets/python_rail/validator_parallelism_prompt_3.txt

Lines changed: 0 additions & 4 deletions
@@ -18,7 +18,3 @@ Your generated response should satisfy the following properties:
 - length: min=1 max=10
 
 Don't talk; just go.
-
-
-String Output:
-

tests/integration_tests/test_guard.py

Lines changed: 25 additions & 18 deletions
@@ -26,6 +26,7 @@
 )
 
 from .mock_llm_outputs import (
+    MockLiteLLMCallableOther,
     MockLiteLLMCallable,
     entity_extraction,
     lists_object,
@@ -173,10 +174,10 @@ def test_entity_extraction_with_reask(
     )
 
     content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
-    guard = guard_initializer(rail, prompt)
+    guard = guard_initializer(rail, messages=[{"role": "user", "content": prompt}])
 
     final_output: ValidationOutcome = guard(
-        llm_api=openai.completions.create,
+        model="gpt-3.5-turbo",
         prompt_params={"document": content[:6000]},
         num_reasks=1,
         max_tokens=2000,
@@ -259,7 +260,7 @@ def test_entity_extraction_with_noop(mocker, rail, prompt):
     mocker.patch("guardrails.llm_providers.LiteLLMCallable", new=MockLiteLLMCallable)
 
     content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
-    guard = guard_initializer(rail, prompt)
+    guard = guard_initializer(rail, messages=[{"role": "user", "content": prompt}])
     final_output = guard(
         llm_api=openai.completions.create,
         prompt_params={"document": content[:6000]},
@@ -305,7 +306,7 @@ def test_entity_extraction_with_filter(mocker, rail, prompt):
     mocker.patch("guardrails.llm_providers.LiteLLMCallable", new=MockLiteLLMCallable)
 
     content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
-    guard = guard_initializer(rail, prompt)
+    guard = guard_initializer(rail, messages=[{"role": "user", "content": prompt}])
     final_output = guard(
         llm_api=openai.completions.create,
         prompt_params={"document": content[:6000]},
@@ -340,7 +341,7 @@ def test_entity_extraction_with_fix(mocker, rail, prompt):
     mocker.patch("guardrails.llm_providers.LiteLLMCallable", new=MockLiteLLMCallable)
 
     content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
-    guard = guard_initializer(rail, prompt)
+    guard = guard_initializer(rail, messages=[{"role": "user", "content": prompt}])
     final_output = guard(
         llm_api=openai.completions.create,
         prompt_params={"document": content[:6000]},
@@ -376,7 +377,7 @@ def test_entity_extraction_with_refrain(mocker, rail, prompt):
     mocker.patch("guardrails.llm_providers.LiteLLMCallable", new=MockLiteLLMCallable)
 
     content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
-    guard = guard_initializer(rail, prompt)
+    guard = guard_initializer(rail, messages=[{"role": "user", "content": prompt}])
     final_output = guard(
         llm_api=openai.completions.create,
         prompt_params={"document": content[:6000]},
@@ -857,11 +858,12 @@ def test_in_memory_validator_log_is_not_duplicated(mocker):
     try:
         content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
         guard = guard_initializer(
-            entity_extraction.PYDANTIC_RAIL_WITH_NOOP, entity_extraction.PYDANTIC_PROMPT
+            entity_extraction.PYDANTIC_RAIL_WITH_NOOP,
+            messages=[{"role": "user", "content": entity_extraction.PYDANTIC_PROMPT}],
         )
 
         guard(
-            llm_api=openai.completions.create,
+            model="gpt-3.5-turbo",
             prompt_params={"document": content[:6000]},
             num_reasks=1,
         )
@@ -942,11 +944,13 @@ def test_guard_with_top_level_list_return_type(mocker, rail, prompt):
     # Create a Guard with a top level list return type
 
     # Mock the LLM
-    mocker.patch("guardrails.llm_providers.LiteLLMCallable", new=MockLiteLLMCallable)
+    mocker.patch(
+        "guardrails.llm_providers.LiteLLMCallable", new=MockLiteLLMCallableOther
+    )
 
-    guard = guard_initializer(rail, prompt=prompt)
+    guard = guard_initializer(rail, messages=[{"role": "user", "content": prompt}])
 
-    output = guard(llm_api=openai.completions.create)
+    output = guard(model="gpt-3.5-turbo")
 
     # Validate the output
     assert output.validated_output == [
@@ -1002,7 +1006,7 @@ def test_string_output(mocker):
 
     guard = gd.Guard.from_rail_string(string.RAIL_SPEC_FOR_STRING)
     final_output = guard(
-        llm_api=openai.completions.create,
+        model="gpt-3.5-turbo",
         prompt_params={"ingredients": "tomato, cheese, sour cream"},
         num_reasks=1,
     )
@@ -1015,7 +1019,7 @@
     assert call.iterations.length == 1
 
     # For original prompt and output
-    assert call.compiled_prompt == string.COMPILED_PROMPT
+    assert call.compiled_messages[1]["content"]._source == string.COMPILED_PROMPT
     assert call.raw_outputs.last == string.LLM_OUTPUT
     assert mock_invoke_llm.call_count == 1
     mock_invoke_llm = None
@@ -1138,7 +1142,7 @@ def test_string_reask(mocker):
 
     guard = gd.Guard.from_rail_string(string.RAIL_SPEC_FOR_STRING_REASK)
     final_output = guard(
-        llm_api=openai.completions.create,
+        model="gpt-3.5-turbo",
         prompt_params={"ingredients": "tomato, cheese, sour cream"},
         num_reasks=1,
         max_tokens=100,
@@ -1152,15 +1156,18 @@
     assert call.iterations.length == 2
 
     # For orginal prompt and output
-    assert call.compiled_instructions == string.COMPILED_INSTRUCTIONS
-    assert call.compiled_prompt == string.COMPILED_PROMPT
+    assert call.compiled_messages[0]["content"]._source == string.COMPILED_INSTRUCTIONS
+    assert call.compiled_messages[1]["content"]._source == string.COMPILED_PROMPT
     assert call.iterations.first.raw_output == string.LLM_OUTPUT
     assert call.iterations.first.validation_response == string.VALIDATED_OUTPUT_REASK
 
     # For re-asked prompt and output
-    assert call.iterations.last.inputs.prompt == gd.Prompt(string.COMPILED_PROMPT_REASK)
+    assert (
+        call.iterations.last.inputs.messages[1]["content"]
+        == string.COMPILED_PROMPT_REASK
+    )
     # Same thing as above
-    assert call.reask_prompts.last == string.COMPILED_PROMPT_REASK
+    assert call.reask_messages[0][1]["content"] == string.COMPILED_PROMPT_REASK
 
     assert call.raw_outputs.last == string.LLM_OUTPUT_REASK
     assert call.guarded_output == string.LLM_OUTPUT_REASK
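Outside the diff, a rough sketch of the calling convention these tests migrate to, assuming a guardrails build from around this commit: guards are invoked with a model name (routed through LiteLLM, as the mocked LiteLLMCallable suggests) instead of llm_api=openai.completions.create, and the compiled prompt is read back from the call log as a messages list. The rail spec below is illustrative rather than one of the test assets, and running it needs provider credentials.

import guardrails as gd

# Illustrative rail spec using the <messages> block this commit migrates to.
RAIL = """<rail version="0.1">
<output type="string" />
<messages>
<message role="user">
Name one fruit.

${gr.complete_string_suffix}
</message>
</messages>
</rail>"""

guard = gd.Guard.from_rail_string(RAIL)

# New style: pass a model name instead of llm_api=openai.completions.create.
outcome = guard(model="gpt-3.5-turbo", num_reasks=1)

# The call log now exposes compiled messages rather than a compiled prompt;
# attribute names follow the assertions in the tests above.
call = guard.history.last
print(call.compiled_messages[-1]["content"]._source)
print(outcome.validated_output)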

tests/integration_tests/test_multi_reask.py

Lines changed: 12 additions & 4 deletions
@@ -1,4 +1,3 @@
-import openai
 import guardrails as gd
 from guardrails.classes.llm.llm_response import LLMResponse
 
@@ -44,21 +43,30 @@ def test_multi_reask(mocker):
 
     assert len(call.iterations) == 3
 
-    assert call.compiled_prompt == python_rail.VALIDATOR_PARALLELISM_PROMPT_1
+    assert (
+        call.compiled_messages[0]["content"]._source
+        == python_rail.VALIDATOR_PARALLELISM_PROMPT_1
+    )
     assert call.raw_outputs.first == python_rail.VALIDATOR_PARALLELISM_RESPONSE_1
     assert (
         call.iterations.first.validation_response
         == python_rail.VALIDATOR_PARALLELISM_REASK_1
     )
 
-    assert call.reask_prompts.first == python_rail.VALIDATOR_PARALLELISM_PROMPT_2
+    assert (
+        call.reask_messages[0][1]["content"]
+        == python_rail.VALIDATOR_PARALLELISM_PROMPT_2
+    )
     assert call.raw_outputs.at(1) == python_rail.VALIDATOR_PARALLELISM_RESPONSE_2
     assert (
         call.iterations.at(1).validation_response
         == python_rail.VALIDATOR_PARALLELISM_REASK_2
     )
 
-    assert call.reask_prompts.last == python_rail.VALIDATOR_PARALLELISM_PROMPT_3
+    assert (
+        call.reask_messages[1][1]["content"]
+        == python_rail.VALIDATOR_PARALLELISM_PROMPT_3
+    )
     assert call.raw_outputs.last == python_rail.VALIDATOR_PARALLELISM_RESPONSE_3
     # The output here fails some validators but passes others.
     # Since those that it fails in the end are noop fixes, validation fails.

tests/integration_tests/test_pydantic.py

Lines changed: 21 additions & 12 deletions
@@ -1,6 +1,5 @@
 import json
 from typing import Dict, List
-import openai
 import pytest
 from pydantic import BaseModel
 
@@ -36,10 +35,10 @@ def test_pydantic_with_reask(mocker):
         ),
     ]
 
-    guard = gd.Guard.from_pydantic(ListOfPeople,
-                                   messages=[{
-                                       "role": "user",
-                                       "content": VALIDATED_RESPONSE_REASK_PROMPT}])
+    guard = gd.Guard.from_pydantic(
+        ListOfPeople,
+        messages=[{"role": "user", "content": VALIDATED_RESPONSE_REASK_PROMPT}],
+    )
     final_output = guard(
         model="text-davinci-003",
         max_tokens=512,
@@ -124,10 +123,15 @@ def test_pydantic_with_full_schema_reask(mocker):
         ),
     ]
 
-    guard = gd.Guard.from_pydantic(ListOfPeople, messages=[{
-        "content": VALIDATED_RESPONSE_REASK_PROMPT,
-        "role": "user",
-    }])
+    guard = gd.Guard.from_pydantic(
+        ListOfPeople,
+        messages=[
+            {
+                "content": VALIDATED_RESPONSE_REASK_PROMPT,
+                "role": "user",
+            }
+        ],
+    )
     final_output = guard(
         model="gpt-3.5-turbo",
         max_tokens=512,
@@ -153,9 +157,14 @@
     )
 
     # For re-asked prompt and output
-    assert call.iterations.at(1).inputs.messages[0]["content"]._source == pydantic.COMPILED_INSTRUCTIONS_CHAT
-    assert call.iterations.at(1).inputs.messages[1]["content"]._source == pydantic.COMPILED_PROMPT_FULL_REASK
-
+    assert (
+        call.iterations.at(1).inputs.messages[0]["content"]._source
+        == pydantic.COMPILED_PROMPT_FULL_REASK_1
+    )
+    assert (
+        call.iterations.at(1).inputs.messages[1]["content"]._source
+        == pydantic.COMPILED_INSTRUCTIONS_CHAT
+    )
     assert call.iterations.at(1).raw_output == pydantic.LLM_OUTPUT_FULL_REASK_1
     assert (
         call.iterations.at(1).validation_response == pydantic.VALIDATED_OUTPUT_REASK_2
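As a companion to the diff above, a hedged sketch of building a Guard from a Pydantic model with the new messages keyword; the Item model mirrors the one in lists_object.py, while the prompt text and invocation arguments are placeholders rather than the test's assets.

from pydantic import BaseModel

import guardrails as gd

class Item(BaseModel):
    name: str
    price: float

# messages= replaces the older prompt= keyword when constructing the Guard.
guard = gd.Guard.from_pydantic(
    Item,
    messages=[{"role": "user", "content": "Suggest one grocery item as JSON."}],
)

# Invocation likewise takes a model name; uncomment to make a real LLM call.
# outcome = guard(model="gpt-3.5-turbo", max_tokens=512)
# print(outcome.validated_output)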
