2222)
2323from tests .test_prompt import is_tlm_response
2424
25+ QUALITY_PRESETS_WITH_NO_CONSISTENCY_SAMPLES = ["base" , "low" , "medium" ]
26+
2527test_prompt_single = make_text_unique (TEST_PROMPT )
2628test_prompt_batch = [make_text_unique (prompt ) for prompt in TEST_PROMPT_BATCH ]
2729
@@ -53,13 +55,23 @@ def _test_log_batch(responses: list[dict[str, Any]], options: dict[str, Any]) ->
5355def _is_valid_prompt_response (
5456 response : dict [str , Any ],
5557 options : dict [str , Any ],
58+ quality_preset : str ,
59+ model : str ,
5660 allow_none_response : bool = False ,
5761 allow_null_trustworthiness_score : bool = False ,
5862) -> bool :
5963 """Returns true if prompt response is valid based on properties for prompt() functionality."""
6064 _test_log (response , options )
61- if {"num_self_reflections" , "num_consistency_samples" }.issubset (options ) and (
62- options ["num_consistency_samples" ] == 0 and options ["num_self_reflections" ] == 0
65+ if (
66+ {"num_self_reflections" , "num_consistency_samples" }.issubset (options )
67+ and options ["num_consistency_samples" ] == 0
68+ and options ["num_self_reflections" ] == 0
69+ ) or (
70+ {"num_self_reflections" }.issubset (options )
71+ and options ["num_self_reflections" ] == 0
72+ and not {"num_consistency_samples" }.issubset (options )
73+ and quality_preset in QUALITY_PRESETS_WITH_NO_CONSISTENCY_SAMPLES
74+ and model in MODELS_WITH_NO_PERPLEXITY_SCORE
6375 ):
6476 print ("Options dictinary called with strange parameters. Allowing none in response." )
6577 return is_tlm_response (
@@ -83,13 +95,13 @@ def _is_valid_get_trustworthiness_score_response(
8395 """Returns true if trustworthiness score is valid based on properties for get_trustworthiness_score() functionality."""
8496 assert isinstance (response , dict )
8597
86- quality_preset_keys = {"num_self_reflections" }
87- consistency_sample_keys = {"num_consistency_samples" , "num_self_reflections" }
88-
8998 if (
90- (quality_preset_keys .issubset (options )) and options ["num_self_reflections" ] == 0 and quality_preset == "base"
99+ {"num_self_reflections" }.issubset (options )
100+ and options ["num_self_reflections" ] == 0
101+ and not {"num_consistency_samples" }.issubset (options )
102+ and quality_preset in QUALITY_PRESETS_WITH_NO_CONSISTENCY_SAMPLES
91103 ) or (
92- ( consistency_sample_keys .issubset (options ) )
104+ { "num_consistency_samples" , "num_self_reflections" } .issubset (options )
93105 and options ["num_self_reflections" ] == 0
94106 and options ["num_consistency_samples" ] == 0
95107 ):
@@ -104,13 +116,17 @@ def _is_valid_get_trustworthiness_score_response(
104116def _test_prompt_response (
105117 response : dict [str , Any ],
106118 options : dict [str , Any ],
119+ quality_preset : str ,
120+ model : str ,
107121 allow_none_response : bool = False ,
108122 allow_null_trustworthiness_score : bool = False ,
109123) -> None :
110124 """Property tests the responses of a prompt based on the options dictionary and returned responses."""
111125 assert _is_valid_prompt_response (
112126 response = response ,
113127 options = options ,
128+ quality_preset = quality_preset ,
129+ model = model ,
114130 allow_none_response = allow_none_response ,
115131 allow_null_trustworthiness_score = allow_null_trustworthiness_score ,
116132 )
@@ -119,6 +135,8 @@ def _test_prompt_response(
119135def _test_batch_prompt_response (
120136 responses : list [dict [str , Any ]],
121137 options : dict [str , Any ],
138+ quality_preset : str ,
139+ model : str ,
122140 allow_none_response : bool = False ,
123141 allow_null_trustworthiness_score : bool = False ,
124142) -> None :
@@ -131,6 +149,8 @@ def _test_batch_prompt_response(
131149 _is_valid_prompt_response (
132150 response ,
133151 options ,
152+ quality_preset ,
153+ model ,
134154 allow_none_response = allow_none_response ,
135155 allow_null_trustworthiness_score = allow_null_trustworthiness_score ,
136156 )
@@ -219,6 +239,8 @@ def test_prompt(tlm_dict: dict[str, Any], model: str, quality_preset: str) -> No
219239 _test_prompt_response (
220240 response ,
221241 {},
242+ quality_preset ,
243+ model ,
222244 allow_null_trustworthiness_score = allow_null_trustworthiness_score ,
223245 )
224246
@@ -231,6 +253,8 @@ def test_prompt(tlm_dict: dict[str, Any], model: str, quality_preset: str) -> No
231253 _test_batch_prompt_response (
232254 responses ,
233255 options ,
256+ quality_preset ,
257+ model ,
234258 allow_none_response = True ,
235259 allow_null_trustworthiness_score = allow_null_trustworthiness_score ,
236260 )
@@ -256,7 +280,13 @@ def test_prompt_async(tlm_dict: dict[str, Any], model: str, quality_preset: str)
256280 tlm_no_options_kwargs ["constrain_outputs" ] = TEST_CONSTRAIN_OUTPUTS_BINARY
257281 response = asyncio .run (_run_prompt_async (tlm_no_options , test_prompt_single , ** tlm_no_options_kwargs ))
258282 print ("TLM Single Response:" , response )
259- _test_prompt_response (response , {}, allow_null_trustworthiness_score = allow_null_trustworthiness_score )
283+ _test_prompt_response (
284+ response ,
285+ {},
286+ quality_preset ,
287+ model ,
288+ allow_null_trustworthiness_score = allow_null_trustworthiness_score ,
289+ )
260290
261291 # test prompt with batch prompt
262292 tlm_kwargs = {}
@@ -267,6 +297,8 @@ def test_prompt_async(tlm_dict: dict[str, Any], model: str, quality_preset: str)
267297 _test_batch_prompt_response (
268298 responses ,
269299 options ,
300+ quality_preset ,
301+ model ,
270302 allow_null_trustworthiness_score = allow_null_trustworthiness_score ,
271303 )
272304
0 commit comments