
Commit 484844a

fix tests
1 parent 50e2899 commit 484844a

5 files changed: +59 -7 lines changed

promptolution/helpers.py

Lines changed: 3 additions & 1 deletion

@@ -122,8 +122,10 @@ def run_evaluation(
     if isinstance(prompts[0], str):
         str_prompts = cast(List[str], prompts)
         prompts = [Prompt(p) for p in str_prompts]
+    else:
+        str_prompts = [p.construct_prompt() for p in cast(List[Prompt], prompts)]
     scores = task.evaluate(prompts, predictor, eval_strategy="full")
-    df = pd.DataFrame(dict(prompt=prompts, score=scores))
+    df = pd.DataFrame(dict(prompt=str_prompts, score=scores))
     df = df.sort_values("score", ascending=False, ignore_index=True)

     return df
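
With this change, run_evaluation returns plain prompt strings in the DataFrame whether the caller passes raw strings or Prompt objects. A minimal standalone sketch of the same normalization pattern, using a stand-in Prompt class (only construct_prompt() is assumed here; this is not the real promptolution.Prompt):

from typing import List, Union

import pandas as pd


class Prompt:
    """Stand-in for promptolution's Prompt; only construct_prompt() is assumed."""

    def __init__(self, instruction: str):
        self.instruction = instruction

    def construct_prompt(self) -> str:
        return self.instruction


def normalize_to_strings(prompts: List[Union[str, Prompt]]) -> List[str]:
    # Mirrors the patched logic: strings pass through, Prompt objects are rendered.
    if isinstance(prompts[0], str):
        return list(prompts)
    return [p.construct_prompt() for p in prompts]


str_prompts = normalize_to_strings([Prompt("Classify the sentiment."), Prompt("Is this positive?")])
df = pd.DataFrame(dict(prompt=str_prompts, score=[0.8, 0.7]))
df = df.sort_values("score", ascending=False, ignore_index=True)
print(df)  # the "prompt" column now holds strings, not Prompt objects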

promptolution/utils/prompt_creation.py

Lines changed: 24 additions & 5 deletions

@@ -7,6 +7,7 @@
 from typing import TYPE_CHECKING, List, Optional, Union

 from promptolution.utils.formatting import extract_from_tag
+from promptolution.utils.logging import get_logger

 if TYPE_CHECKING:  # pragma: no cover
     from promptolution.llms.base_llm import BaseLLM

@@ -18,8 +19,11 @@
     PROMPT_CREATION_TEMPLATE_FROM_TASK_DESCRIPTION,
     PROMPT_CREATION_TEMPLATE_TD,
     PROMPT_VARIATION_TEMPLATE,
+    default_prompts,
 )

+logger = get_logger(__name__)
+

 def create_prompt_variation(
     prompt: Union[List[str], str], llm: "BaseLLM", meta_prompt: Optional[str] = None

@@ -128,6 +132,7 @@ def create_prompts_from_task_description(
     llm: "BaseLLM",
     meta_prompt: Optional[str] = None,
     n_prompts: int = 10,
+    n_retries: int = 3,
 ) -> List[str]:
     """Generate a set of prompts from a given task description.

@@ -137,13 +142,27 @@
         meta_prompt (str): The meta prompt to use for generating the prompts.
             If None, a default meta prompt is used.
         n_prompts (int): The number of prompts to generate.
+        n_retries (int): The number of retries to attempt if prompt generation fails.
     """
     if meta_prompt is None:
         meta_prompt = PROMPT_CREATION_TEMPLATE_FROM_TASK_DESCRIPTION

     meta_prompt = meta_prompt.replace("<task_desc>", task_description).replace("<n_prompts>", str(n_prompts))
-
-    prompts_str = llm.get_response(meta_prompt)[0]
-    prompts = json.loads(prompts_str)
-
-    return prompts
+    final_prompts = None
+    for _ in range(n_retries):
+        prompts_str = llm.get_response(meta_prompt)[0]
+        try:
+            prompts = json.loads(prompts_str)
+            assert isinstance(prompts, list) and all(isinstance(p, str) for p in prompts) and len(prompts) == n_prompts
+            final_prompts = prompts
+            break
+        except (json.JSONDecodeError, AssertionError):
+            logger.warning("Failed to parse prompts JSON, retrying...")
+
+    if final_prompts is None:
+        logger.error(
+            f"Failed to generate prompts from task description after {n_retries} retries, returning default prompts."
+        )
+        final_prompts = default_prompts[:n_prompts]
+
+    return final_prompts
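
The new loop accepts the meta-LLM output only if it parses as JSON into a list of exactly n_prompts strings, and falls back to default_prompts after n_retries failures. A rough, self-contained sketch of the same retry-and-fallback pattern with a fake LLM in place of BaseLLM (FlakyLLM and generate_prompts are illustrative names, not promptolution APIs):

import json
import logging
from typing import List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

default_prompts = ["Give me your response within <final_answer> tags."]  # trimmed stand-in


class FlakyLLM:
    """Fake LLM: returns non-JSON text once, then a valid JSON list."""

    def __init__(self):
        self.calls = 0

    def get_response(self, prompt: str) -> List[str]:
        self.calls += 1
        if self.calls == 1:
            return ["Sure! Here are some prompts: ..."]  # not JSON -> triggers a retry
        return ['["Summarize the text.", "Classify the sentiment."]']


def generate_prompts(llm, meta_prompt: str, n_prompts: int = 2, n_retries: int = 3) -> List[str]:
    final_prompts = None
    for _ in range(n_retries):
        prompts_str = llm.get_response(meta_prompt)[0]
        try:
            prompts = json.loads(prompts_str)
            assert isinstance(prompts, list) and all(isinstance(p, str) for p in prompts)
            assert len(prompts) == n_prompts
            final_prompts = prompts
            break
        except (json.JSONDecodeError, AssertionError):
            logger.warning("Failed to parse prompts JSON, retrying...")
    if final_prompts is None:
        logger.error("Giving up after %d retries, returning default prompts.", n_retries)
        final_prompts = default_prompts[:n_prompts]
    return final_prompts


print(generate_prompts(FlakyLLM(), "Generate 2 prompts as a JSON list."))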

promptolution/utils/templates.py

Lines changed: 28 additions & 0 deletions

@@ -174,3 +174,31 @@

 Return the new prompt in the following format:
 <prompt>new prompt</prompt>"""
+
+
+default_prompts = [
+    "Give me your response within <final_answer> tags.",
+    "Please provide a thoughtful answer to my question and wrap your response in <final_answer> tags so I can easily identify it.",
+    "I need your expertise on this matter. Kindly structure your response within <final_answer> tags for better readability.",
+    "Analyze the following and present your findings enclosed in <final_answer> </final_answer> tags.",
+    "Consider this inquiry carefully. Your comprehensive response should be formatted within <final_answer> tags to facilitate extraction.",
+    "Respond succinctly. Ensure all content appears between <final_answer> and </final_answer> markers.",
+    "Would you mind addressing this request? Please place your entire response inside <final_answer> </final_answer> formatting.",
+    "I'm seeking your insights on a particular topic. Kindly ensure that your complete analysis is contained within <final_answer> tags for my convenience.",
+    "Examine this query thoroughly and deliver your conclusions. All output must be encapsulated in <final_answer> </final_answer> notation for processing purposes.",
+    "Help me understand this subject better. Your explanation should begin with <final_answer> and conclude with </final_answer> to maintain proper structure.",
+    "I require information on the following. Please format your response with <final_answer> tags at the beginning and end for clarity.",
+    "Contemplate this scenario and offer your perspective. Remember to enclose all content within <final_answer> tags as per requirements.",
+    "Elaborate on this concept, making sure to wrap the entirety of your explanation in <final_answer> </final_answer> markers for systematic review.",
+    "Describe your approach to this situation. Be thorough yet concise, and place your complete response between <final_answer> and </final_answer> tags.",
+    "Share your knowledge on this matter. Your entire response should be presented within <final_answer> tags to facilitate proper integration into my workflow.",
+    "Let's think step by step. Your answer should be enclosed within <final_answer> </final_answer> tags.",
+    "Provide a detailed response to the following question, ensuring that all information is contained within <final_answer> tags for easy extraction.",
+    "Kindly address the following topic, formatting your entire response between <final_answer> and </final_answer> markers for clarity and organization.",
+    "Offer your insights on this issue, making sure to encapsulate your full response within <final_answer> tags for seamless processing.",
+    "Delve into this subject and present your findings, ensuring that all content is wrapped in <final_answer> </final_answer> notation for systematic analysis.",
+    "Illuminate this topic with your expertise, formatting your complete explanation within <final_answer> tags for straightforward comprehension.",
+    "Provide your perspective on this matter, ensuring that your entire response is contained within <final_answer> tags for efficient review.",
+    "Analyze the following scenario and deliver your conclusions, making sure to enclose all output in <final_answer> </final_answer> markers for clarity.",
+    "Help me grasp this concept better by structuring your explanation between <final_answer> and </final_answer> tags for proper formatting.",
+]
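
Every one of these fallback prompts instructs the model to wrap its answer in <final_answer> tags, which keeps downstream answer extraction uniform. As a rough illustration of that extraction step (promptolution ships its own extract_from_tag helper; the regex below is not that implementation):

import re
from typing import Optional


def extract_final_answer(response: str) -> Optional[str]:
    # Grab whatever sits between the first <final_answer>...</final_answer> pair.
    match = re.search(r"<final_answer>(.*?)</final_answer>", response, flags=re.DOTALL)
    return match.group(1).strip() if match else None


response = "Reasoning...\n<final_answer>positive</final_answer>"
print(extract_final_answer(response))  # -> "positive"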

tests/helpers/test_helpers.py

Lines changed: 3 additions & 0 deletions

@@ -197,6 +197,8 @@ def test_run_evaluation(mock_get_task, mock_get_predictor, mock_get_llm, sample_
         "Is this text positive, negative, or neutral?",
     ]

+    prompts = [Prompt(p) for p in prompts]
+
     # Now this will work because mock_task is a MagicMock
     mock_task.evaluate.return_value = np.array([0.8, 0.7, 0.9])

@@ -298,6 +300,7 @@ def test_helpers_integration(sample_df, experiment_config):
     # Verify results
     assert isinstance(result, pd.DataFrame)
     assert len(result) == 2
+    print([p in result["prompt"].values for p in optimized_prompts_str])
     assert all(p in result["prompt"].values for p in optimized_prompts_str)

     # Verify optimization was called

tests/optimizers/test_capo.py

Lines changed: 1 addition & 1 deletion

@@ -209,7 +209,7 @@ def test_capo_crossover_prompt(mock_meta_llm, mock_predictor, initial_prompts, m
         .replace("<task_desc>", full_task_desc)
     )

-    assert mock_meta_llm.call_history[0]["prompts"][0] == expected_meta_prompt
+    assert str(mock_meta_llm.call_history[0]["prompts"][0]) == expected_meta_prompt


 def test_capo_mutate_prompt(mock_meta_llm, mock_predictor, initial_prompts, mock_task, mock_df):
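
The crossover meta-prompt recorded in call_history is now a Prompt object rather than a raw string, so the assertion compares its str() rendering. A tiny sketch of why the cast matters, assuming (as the test change implies) that Prompt.__str__ returns the constructed prompt text; the class below is a stand-in, not promptolution's Prompt:

class Prompt:
    """Stand-in: __str__ is assumed to render the full prompt text."""

    def __init__(self, text: str):
        self.text = text

    def __str__(self) -> str:
        return self.text


expected_meta_prompt = "Combine these prompts: ..."
recorded = Prompt("Combine these prompts: ...")

assert recorded != expected_meta_prompt        # object vs. str: not equal
assert str(recorded) == expected_meta_prompt   # comparing the rendered text passes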
