Skip to content

Commit af852d0

Browse files
authored
Fix unittests (#108)
1 parent 8b39578 commit af852d0

File tree

1 file changed

+40
-8
lines changed

1 file changed

+40
-8
lines changed

tests/test_properties.py

Lines changed: 40 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,8 @@
2222
)
2323
from tests.test_prompt import is_tlm_response
2424

25+
QUALITY_PRESETS_WITH_NO_CONSISTENCY_SAMPLES = ["base", "low", "medium"]
26+
2527
test_prompt_single = make_text_unique(TEST_PROMPT)
2628
test_prompt_batch = [make_text_unique(prompt) for prompt in TEST_PROMPT_BATCH]
2729

@@ -53,13 +55,23 @@ def _test_log_batch(responses: list[dict[str, Any]], options: dict[str, Any]) ->
5355
def _is_valid_prompt_response(
5456
response: dict[str, Any],
5557
options: dict[str, Any],
58+
quality_preset: str,
59+
model: str,
5660
allow_none_response: bool = False,
5761
allow_null_trustworthiness_score: bool = False,
5862
) -> bool:
5963
"""Returns true if prompt response is valid based on properties for prompt() functionality."""
6064
_test_log(response, options)
61-
if {"num_self_reflections", "num_consistency_samples"}.issubset(options) and (
62-
options["num_consistency_samples"] == 0 and options["num_self_reflections"] == 0
65+
if (
66+
{"num_self_reflections", "num_consistency_samples"}.issubset(options)
67+
and options["num_consistency_samples"] == 0
68+
and options["num_self_reflections"] == 0
69+
) or (
70+
{"num_self_reflections"}.issubset(options)
71+
and options["num_self_reflections"] == 0
72+
and not {"num_consistency_samples"}.issubset(options)
73+
and quality_preset in QUALITY_PRESETS_WITH_NO_CONSISTENCY_SAMPLES
74+
and model in MODELS_WITH_NO_PERPLEXITY_SCORE
6375
):
6476
print("Options dictinary called with strange parameters. Allowing none in response.")
6577
return is_tlm_response(
@@ -83,13 +95,13 @@ def _is_valid_get_trustworthiness_score_response(
8395
"""Returns true if trustworthiness score is valid based on properties for get_trustworthiness_score() functionality."""
8496
assert isinstance(response, dict)
8597

86-
quality_preset_keys = {"num_self_reflections"}
87-
consistency_sample_keys = {"num_consistency_samples", "num_self_reflections"}
88-
8998
if (
90-
(quality_preset_keys.issubset(options)) and options["num_self_reflections"] == 0 and quality_preset == "base"
99+
{"num_self_reflections"}.issubset(options)
100+
and options["num_self_reflections"] == 0
101+
and not {"num_consistency_samples"}.issubset(options)
102+
and quality_preset in QUALITY_PRESETS_WITH_NO_CONSISTENCY_SAMPLES
91103
) or (
92-
(consistency_sample_keys.issubset(options))
104+
{"num_consistency_samples", "num_self_reflections"}.issubset(options)
93105
and options["num_self_reflections"] == 0
94106
and options["num_consistency_samples"] == 0
95107
):
@@ -104,13 +116,17 @@ def _is_valid_get_trustworthiness_score_response(
104116
def _test_prompt_response(
105117
response: dict[str, Any],
106118
options: dict[str, Any],
119+
quality_preset: str,
120+
model: str,
107121
allow_none_response: bool = False,
108122
allow_null_trustworthiness_score: bool = False,
109123
) -> None:
110124
"""Property tests the responses of a prompt based on the options dictionary and returned responses."""
111125
assert _is_valid_prompt_response(
112126
response=response,
113127
options=options,
128+
quality_preset=quality_preset,
129+
model=model,
114130
allow_none_response=allow_none_response,
115131
allow_null_trustworthiness_score=allow_null_trustworthiness_score,
116132
)
@@ -119,6 +135,8 @@ def _test_prompt_response(
119135
def _test_batch_prompt_response(
120136
responses: list[dict[str, Any]],
121137
options: dict[str, Any],
138+
quality_preset: str,
139+
model: str,
122140
allow_none_response: bool = False,
123141
allow_null_trustworthiness_score: bool = False,
124142
) -> None:
@@ -131,6 +149,8 @@ def _test_batch_prompt_response(
131149
_is_valid_prompt_response(
132150
response,
133151
options,
152+
quality_preset,
153+
model,
134154
allow_none_response=allow_none_response,
135155
allow_null_trustworthiness_score=allow_null_trustworthiness_score,
136156
)
@@ -219,6 +239,8 @@ def test_prompt(tlm_dict: dict[str, Any], model: str, quality_preset: str) -> No
219239
_test_prompt_response(
220240
response,
221241
{},
242+
quality_preset,
243+
model,
222244
allow_null_trustworthiness_score=allow_null_trustworthiness_score,
223245
)
224246

@@ -231,6 +253,8 @@ def test_prompt(tlm_dict: dict[str, Any], model: str, quality_preset: str) -> No
231253
_test_batch_prompt_response(
232254
responses,
233255
options,
256+
quality_preset,
257+
model,
234258
allow_none_response=True,
235259
allow_null_trustworthiness_score=allow_null_trustworthiness_score,
236260
)
@@ -256,7 +280,13 @@ def test_prompt_async(tlm_dict: dict[str, Any], model: str, quality_preset: str)
256280
tlm_no_options_kwargs["constrain_outputs"] = TEST_CONSTRAIN_OUTPUTS_BINARY
257281
response = asyncio.run(_run_prompt_async(tlm_no_options, test_prompt_single, **tlm_no_options_kwargs))
258282
print("TLM Single Response:", response)
259-
_test_prompt_response(response, {}, allow_null_trustworthiness_score=allow_null_trustworthiness_score)
283+
_test_prompt_response(
284+
response,
285+
{},
286+
quality_preset,
287+
model,
288+
allow_null_trustworthiness_score=allow_null_trustworthiness_score,
289+
)
260290

261291
# test prompt with batch prompt
262292
tlm_kwargs = {}
@@ -267,6 +297,8 @@ def test_prompt_async(tlm_dict: dict[str, Any], model: str, quality_preset: str)
267297
_test_batch_prompt_response(
268298
responses,
269299
options,
300+
quality_preset,
301+
model,
270302
allow_null_trustworthiness_score=allow_null_trustworthiness_score,
271303
)
272304

0 commit comments

Comments (0)