Skip to content

Commit 2cbaedb

Browse files
committed
renaming predictors AND prompt creation
1 parent 5272703 commit 2cbaedb

File tree

15 files changed

+146
-101
lines changed

15 files changed

+146
-101
lines changed

promptolution/helpers.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from promptolution.optimizers.evoprompt_de import EvoPromptDE
2929
from promptolution.optimizers.evoprompt_ga import EvoPromptGA
3030
from promptolution.optimizers.opro import OPRO
31-
from promptolution.predictors.classifier import FirstOccurrenceClassifier, MarkerBasedClassifier
31+
from promptolution.predictors.first_occurence_predictor import FirstOccurrencePredictor
32+
from promptolution.predictors.maker_based_predictor import MarkerBasedPredictor
3233
from promptolution.tasks.classification_tasks import ClassificationTask
3334
from promptolution.utils.logging import get_logger
3435
from promptolution.utils.templates import (
@@ -272,23 +273,23 @@ def get_predictor(downstream_llm=None, type: "PredictorType" = "marker", *args,
272273
"""Factory function to create and return a predictor instance.
273274
274275
This function supports two types of predictors:
275-
1. FirstOccurrenceClassifier: A predictor that classifies based on first occurrence of the label.
276-
2. MarkerBasedClassifier: A predictor that classifies based on a marker.
276+
1. FirstOccurrencePredictor: A predictor that classifies based on first occurrence of the label.
277+
2. MarkerBasedPredictor: A predictor that classifies based on a marker.
277278
278279
Args:
279280
downstream_llm: The language model to use for prediction.
280281
type (Literal["first_occurrence", "marker"]): The type of predictor to create:
281-
- "first_occurrence" for FirstOccurrenceClassifier
282-
- "marker" (default) for MarkerBasedClassifier
282+
- "first_occurrence" for FirstOccurrencePredictor
283+
- "marker" (default) for MarkerBasedPredictor
283284
*args: Variable length argument list passed to the predictor constructor.
284285
**kwargs: Arbitrary keyword arguments passed to the predictor constructor.
285286
286287
Returns:
287-
An instance of FirstOccurrenceClassifier or MarkerBasedClassifier.
288+
An instance of FirstOccurrencePredictor or MarkerBasedPredictor.
288289
"""
289290
if type == "first_occurrence":
290-
return FirstOccurrenceClassifier(downstream_llm, *args, **kwargs)
291+
return FirstOccurrencePredictor(downstream_llm, *args, **kwargs)
291292
elif type == "marker":
292-
return MarkerBasedClassifier(downstream_llm, *args, **kwargs)
293+
return MarkerBasedPredictor(downstream_llm, *args, **kwargs)
293294
else:
294295
raise ValueError(f"Invalid predictor type: '{type}'")

promptolution/optimizers/capo.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88

99
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
1010

11-
from promptolution.utils.formatting import extract_from_tag
12-
1311
if TYPE_CHECKING: # pragma: no cover
1412
from promptolution.utils.callbacks import BaseCallback
1513
from promptolution.llms.base_llm import BaseLLM
@@ -19,6 +17,7 @@
1917
from promptolution.utils.test_statistics import TestStatistics
2018

2119
from promptolution.optimizers.base_optimizer import BaseOptimizer
20+
from promptolution.utils.formatting import extract_from_tag
2221
from promptolution.utils.logging import get_logger
2322
from promptolution.utils.prompt import Prompt, sort_prompts_by_scores
2423
from promptolution.utils.templates import CAPO_CROSSOVER_TEMPLATE, CAPO_FEWSHOT_TEMPLATE, CAPO_MUTATION_TEMPLATE
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Module for LLM predictors."""
22

33

4-
from promptolution.predictors.classifier import FirstOccurrenceClassifier, MarkerBasedClassifier
4+
from promptolution.predictors.first_occurence_predictor import FirstOccurrencePredictor
5+
from promptolution.predictors.maker_based_predictor import MarkerBasedPredictor

promptolution/predictors/base_predictor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
if TYPE_CHECKING: # pragma: no cover
1111
from promptolution.utils.config import ExperimentConfig
1212

13+
1314
PredictorType = Literal["first_occurrence", "marker"]
1415

1516

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Module for the FirstOccurencePredictor."""
2+
3+
from typing import TYPE_CHECKING, List, Optional
4+
5+
from promptolution.predictors.base_predictor import BasePredictor
6+
7+
if TYPE_CHECKING: # pragma: no cover
8+
from promptolution.llms.base_llm import BaseLLM
9+
from promptolution.utils.config import ExperimentConfig
10+
11+
12+
class FirstOccurrencePredictor(BasePredictor):
13+
"""A predictor class for classification tasks using language models.
14+
15+
This class takes a language model and a list of classes, and provides a method
16+
to predict classes for given prompts and input data. The class labels are extracted
17+
by matching the words in the prediction with the list of valid class labels.
18+
The first occurrence of a valid class label in the prediction is used as the predicted class.
19+
If no valid class label is found, the first class label in the list is used as the default prediction.
20+
21+
Attributes:
22+
llm: The language model used for generating predictions.
23+
classes (List[str]): The list of valid class labels.
24+
config (ExperimentConfig, optional): Configuration for the classifier, overriding defaults.
25+
26+
Inherits from:
27+
BasePredictor: The base class for predictors in the promptolution library.
28+
"""
29+
30+
def __init__(self, llm: "BaseLLM", classes: List[str], config: Optional["ExperimentConfig"] = None) -> None:
31+
"""Initialize the FirstOccurrencePredictor.
32+
33+
Args:
34+
llm: The language model to use for predictions.
35+
classes (List[str]): The list of valid class labels.
36+
config (ExperimentConfig, optional): Configuration for the classifier, overriding defaults.
37+
"""
38+
assert all([c.islower() for c in classes]), "Class labels should be lowercase."
39+
self.classes = classes
40+
41+
self.extraction_description = (
42+
f"The task is to classify the texts into one of those classes: {', '.join(classes)}."
43+
"The first occurrence of a valid class label in the prediction is used as the predicted class."
44+
)
45+
46+
super().__init__(llm, config)
47+
48+
def _extract_preds(self, preds: List[str]) -> List[str]:
49+
"""Extract class labels from the predictions, based on the list of valid class labels.
50+
51+
Args:
52+
preds: The raw predictions from the language model.
53+
"""
54+
result = []
55+
for pred in preds:
56+
predicted_class = self.classes[0] # use first class as default pred
57+
for word in pred.split():
58+
word = "".join([c for c in word if c.isalnum()]).lower()
59+
if word in self.classes:
60+
predicted_class = word
61+
break
62+
63+
result.append(predicted_class)
64+
65+
return result

promptolution/predictors/classifier.py renamed to promptolution/predictors/maker_based_predictor.py

Lines changed: 5 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
"""Module for classification predictors."""
1+
"""Module for the MarkerBasedPredictor."""
22

3-
4-
import numpy as np
5-
6-
from typing import TYPE_CHECKING, Any, List, Optional
3+
from typing import TYPE_CHECKING, List, Optional
74

85
from promptolution.predictors.base_predictor import BasePredictor
96
from promptolution.utils.formatting import extract_from_tag
@@ -13,64 +10,8 @@
1310
from promptolution.utils.config import ExperimentConfig
1411

1512

16-
class FirstOccurrenceClassifier(BasePredictor):
17-
"""A predictor class for classification tasks using language models.
18-
19-
This class takes a language model and a list of classes, and provides a method
20-
to predict classes for given prompts and input data. The class labels are extracted
21-
by matching the words in the prediction with the list of valid class labels.
22-
The first occurrence of a valid class label in the prediction is used as the predicted class.
23-
If no valid class label is found, the first class label in the list is used as the default prediction.
24-
25-
Attributes:
26-
llm: The language model used for generating predictions.
27-
classes (List[str]): The list of valid class labels.
28-
config (ExperimentConfig, optional): Configuration for the classifier, overriding defaults.
29-
30-
Inherits from:
31-
BasePredictor: The base class for predictors in the promptolution library.
32-
"""
33-
34-
def __init__(self, llm: "BaseLLM", classes: List[str], config: Optional["ExperimentConfig"] = None) -> None:
35-
"""Initialize the FirstOccurrenceClassifier.
36-
37-
Args:
38-
llm: The language model to use for predictions.
39-
classes (List[str]): The list of valid class labels.
40-
config (ExperimentConfig, optional): Configuration for the classifier, overriding defaults.
41-
"""
42-
assert all([c.islower() for c in classes]), "Class labels should be lowercase."
43-
self.classes = classes
44-
45-
self.extraction_description = (
46-
f"The task is to classify the texts into one of those classes: {', '.join(classes)}."
47-
"The first occurrence of a valid class label in the prediction is used as the predicted class."
48-
)
49-
50-
super().__init__(llm, config)
51-
52-
def _extract_preds(self, preds: List[str]) -> List[str]:
53-
"""Extract class labels from the predictions, based on the list of valid class labels.
54-
55-
Args:
56-
preds: The raw predictions from the language model.
57-
"""
58-
result = []
59-
for pred in preds:
60-
predicted_class = self.classes[0] # use first class as default pred
61-
for word in pred.split():
62-
word = "".join([c for c in word if c.isalnum()]).lower()
63-
if word in self.classes:
64-
predicted_class = word
65-
break
66-
67-
result.append(predicted_class)
68-
69-
return result
70-
71-
72-
class MarkerBasedClassifier(BasePredictor):
73-
"""A predictor class for classification tasks using language models.
13+
class MarkerBasedPredictor(BasePredictor):
14+
"""A predictor class task using language models.
7415
7516
This class takes a language model and a list of classes, and provides a method
7617
to predict classes for given prompts and input data. The class labels are extracted.
@@ -92,7 +33,7 @@ def __init__(
9233
end_marker: str = "</final_answer>",
9334
config: Optional["ExperimentConfig"] = None,
9435
) -> None:
95-
"""Initialize the MarkerBasedClassifier.
36+
"""Initialize the MarkerBasedPredictor.
9637
9738
Args:
9839
llm: The language model to use for predictions.

promptolution/tasks/judge_tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def _evaluate(self, xs: List[str], ys: List[str], preds: List[str]) -> List[floa
132132
judge_responses = self.judge_llm.get_response(prompts)
133133
scores_str = extract_from_tag(judge_responses, "<final_score>", "</final_score>")
134134
scores = []
135-
for score_str, judge_response in zip(scores_str, judge_responses):
135+
for score_str in scores_str:
136136
try:
137137
# only numeric chars, - or . are allowed
138138
score_str = "".join(filter(lambda c: c.isdigit() or c in "-.", score_str))

promptolution/utils/prompt_creation.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from promptolution.tasks.classification_tasks import ClassificationTask
1515
from promptolution.utils.templates import (
1616
PROMPT_CREATION_TEMPLATE,
17+
PROMPT_CREATION_TEMPLATE_FROM_TASK_DESCRIPTION,
1718
PROMPT_CREATION_TEMPLATE_TD,
1819
PROMPT_VARIATION_TEMPLATE,
1920
)
@@ -50,7 +51,7 @@ def create_prompts_from_samples(
5051
llm: "BaseLLM",
5152
meta_prompt: Optional[str] = None,
5253
n_samples: int = 3,
53-
task_description: Optional[str] = None,
54+
task_description: str = None,
5455
n_prompts: int = 1,
5556
get_uniform_labels: bool = False,
5657
) -> List[str]:
@@ -119,3 +120,33 @@ def create_prompts_from_samples(
119120
prompts = extract_from_tag(prompts, "<prompt>", "</prompt>")
120121

121122
return prompts
123+
124+
125+
def create_prompts_from_task_description(
126+
task_description: str,
127+
llm: "BaseLLM",
128+
meta_prompt: Optional[str] = None,
129+
n_prompts: int = 1,
130+
) -> List[str]:
131+
"""Generate a set of prompts from a given task description.
132+
133+
Args:
134+
task_description (str): The description of the task to generate prompts for.
135+
llm (BaseLLM): The language model to use for generating the prompts.
136+
meta_prompt (str): The meta prompt to use for generating the prompts.
137+
If None, a default meta prompt is used.
138+
n_prompts (int): The number of prompts to generate.
139+
140+
Returns:
141+
List[str]: A list of generated prompts.
142+
"""
143+
if meta_prompt is None:
144+
meta_prompt = PROMPT_CREATION_TEMPLATE_FROM_TASK_DESCRIPTION
145+
146+
meta_prompt = meta_prompt.replace("<task_desc>", task_description)
147+
148+
meta_prompts = [meta_prompt for _ in range(n_prompts)]
149+
prompts = llm.get_response(meta_prompts)
150+
prompts = extract_from_tag(prompts, "<prompt>", "</prompt>")
151+
152+
return prompts

promptolution/utils/templates.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,12 @@
138138
139139
The instruction was"""
140140

141+
PROMPT_CREATION_TEMPLATE_FROM_TASK_DESCRIPTION = """Please create a prompt for the following task, not using any placeholders, working universally, for any datapoint-specific instructions following each system prompt.
142+
143+
Task: <task_desc>
144+
145+
Explicitly state this expected format as part of the prompts."""
146+
141147

142148
DOWNSTREAM_TEMPLATE = "<instruction>"
143149

tests/predictors/test_classifiers.py renamed to tests/predictors/test_predictors.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import numpy as np
22
import pytest
33

4-
from promptolution.helpers import FirstOccurrenceClassifier, MarkerBasedClassifier
4+
from promptolution.helpers import FirstOccurrencePredictor, MarkerBasedPredictor
55

66

77
def test_first_occurrence_classifier(mock_downstream_llm, mock_df):
8-
"""Test the FirstOccurrenceClassifier."""
8+
"""Test the FirstOccurrencePredictor."""
99
# Create classifier
10-
classifier = FirstOccurrenceClassifier(llm=mock_downstream_llm, classes=mock_df["y"].values)
10+
classifier = FirstOccurrencePredictor(llm=mock_downstream_llm, classes=mock_df["y"].values)
1111

1212
# Test with multiple inputs
1313
xs = ["I love this product!", "I hate this product!", "This product is okay.", "ja ne"]
@@ -25,9 +25,9 @@ def test_first_occurrence_classifier(mock_downstream_llm, mock_df):
2525

2626

2727
def test_marker_based_classifier(mock_downstream_llm, mock_df):
28-
"""Test the MarkerBasedClassifier."""
28+
"""Test the MarkerBasedPredictor."""
2929
# Create classifier
30-
classifier = MarkerBasedClassifier(
30+
classifier = MarkerBasedPredictor(
3131
llm=mock_downstream_llm,
3232
classes=mock_df["y"].values,
3333
begin_marker="<final_answer>",
@@ -56,9 +56,9 @@ def test_marker_based_classifier(mock_downstream_llm, mock_df):
5656

5757

5858
def test_marker_based_without_classes(mock_downstream_llm):
59-
"""Test MarkerBasedClassifier without predefined classes."""
59+
"""Test MarkerBasedPredictor without predefined classes."""
6060
# Create classifier without classes
61-
classifier = MarkerBasedClassifier(
61+
predictor = MarkerBasedPredictor(
6262
llm=mock_downstream_llm,
6363
classes=None, # No class restrictions
6464
begin_marker="<final_answer>",
@@ -70,7 +70,7 @@ def test_marker_based_without_classes(mock_downstream_llm):
7070
prompts = ["Classify:"] * len(xs)
7171

7272
# Make predictions
73-
predictions = classifier.predict(prompts, xs)
73+
predictions = predictor.predict(prompts, xs)
7474

7575
# Verify shape and content - should accept any value between markers
7676
assert len(predictions) == 4
@@ -83,7 +83,7 @@ def test_marker_based_without_classes(mock_downstream_llm):
8383
def test_multiple_prompts_with_classifiers(mock_downstream_llm, mock_df):
8484
"""Test using multiple prompts with classifiers."""
8585
# Create classifier
86-
classifier = FirstOccurrenceClassifier(llm=mock_downstream_llm, classes=mock_df["y"].values)
86+
classifier = FirstOccurrencePredictor(llm=mock_downstream_llm, classes=mock_df["y"].values)
8787

8888
# Test with multiple prompts
8989
prompts = ["Classify:", "Classify:", "Rate:", "Rate:"]
@@ -103,7 +103,7 @@ def test_multiple_prompts_with_classifiers(mock_downstream_llm, mock_df):
103103
def test_sequence_return_with_classifiers(mock_downstream_llm, mock_df):
104104
"""Test return_seq parameter with classifiers."""
105105
# Create classifier
106-
classifier = MarkerBasedClassifier(llm=mock_downstream_llm, classes=mock_df["y"].values)
106+
classifier = MarkerBasedPredictor(llm=mock_downstream_llm, classes=mock_df["y"].values)
107107

108108
# Test with return_seq=True
109109
prompts = ["Classify:"]
@@ -128,15 +128,15 @@ def test_invalid_class_labels(mock_downstream_llm):
128128

129129
# Should raise an assertion error
130130
with pytest.raises(AssertionError):
131-
FirstOccurrenceClassifier(llm=mock_downstream_llm, classes=invalid_classes)
131+
FirstOccurrencePredictor(llm=mock_downstream_llm, classes=invalid_classes)
132132

133133
with pytest.raises(AssertionError):
134-
MarkerBasedClassifier(llm=mock_downstream_llm, classes=invalid_classes)
134+
MarkerBasedPredictor(llm=mock_downstream_llm, classes=invalid_classes)
135135

136136

137137
def test_marker_based_missing_markers(mock_downstream_llm):
138-
"""Test MarkerBasedClassifier behavior when markers are missing."""
139-
classifier = MarkerBasedClassifier(llm=mock_downstream_llm, classes=["will", "not", "be", "used"])
138+
"""Test MarkerBasedPredictor behavior when markers are missing."""
139+
classifier = MarkerBasedPredictor(llm=mock_downstream_llm, classes=["will", "not", "be", "used"])
140140

141141
# When markers are missing, it should default to first class
142142
prompts = ["Classify:"]

0 commit comments

Comments
 (0)