
Commit ae2a831

Add dspy (#151)
* update chat template structure
* update chat template structure
* refactor chat templates
* lint
* fix imports
* start dspy evolver
* try to generate
* update
* some working version
* refactor
* add signature
* fix lint
* Fix: add import error handling for dspy module
* lint
* fix typing
1 parent 4d2f31e commit ae2a831

File tree

2 files changed: +291 −0 lines changed
Lines changed: 281 additions & 0 deletions
@@ -0,0 +1,281 @@
"""Evolutionary strategy for augmenting utterances."""

import copy
import logging
import random
from collections import Counter
from pathlib import Path
from typing import Any

try:
    import dspy
except ImportError:
    import_error = "dspy is not installed. Please install it with `pip install dspy` or `pip install autointent[dspy]`."
    raise ImportError(import_error) from None

from datasets import Dataset as HFDataset
from datasets import concatenate_datasets
from dspy.evaluate.auto_evaluation import f1_score

from autointent import Dataset, Pipeline
from autointent.custom_types import Split

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

DEFAULT_SEARCH_SPACE = [
    {
        "node_type": "scoring",
        "target_metric": "scoring_roc_auc",
        "metrics": ["scoring_accuracy"],
        "search_space": [
            {
                "module_name": "linear",
                "embedder_config": ["sentence-transformers/all-MiniLM-L6-v2"],
            }
        ],
    },
    {
        "node_type": "decision",
        "target_metric": "decision_accuracy",
        "search_space": [
            {"module_name": "argmax"},
        ],
    },
]


def repetition_factor(true_text: str, augmented_text: str) -> float:
    """Calculate the ROUGE-1 F1 score between a ground-truth text and an augmented text.

    ROUGE-1 F1 is computed as:
        F1 = 2 * (precision * recall) / (precision + recall)
    where:
        - precision = (overlap in unigrams) / (total unigrams in augmented text)
        - recall = (overlap in unigrams) / (total unigrams in true text)

    Args:
        true_text: The ground-truth text.
        augmented_text: The augmented/generated text.

    Returns:
        float: The ROUGE-1 F1 score for the pair, or 0.0 if either text is empty.
    """
    true_tokens = true_text.split()
    aug_tokens = augmented_text.split()
    if not true_tokens or not aug_tokens:
        return 0.0
    true_counts = Counter(true_tokens)
    aug_counts = Counter(aug_tokens)
    # Calculate the token overlap using the minimum count for common tokens
    overlap = sum(min(true_counts[token], aug_counts[token]) for token in true_counts.keys() & aug_counts.keys())
    precision = overlap / len(aug_tokens)
    recall = overlap / len(true_tokens)
    return 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)
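
# Worked example (illustrative; these strings are not from the original commit):
# for true_text = "book a table for two" and augmented_text = "please book a table",
# the unigram overlap is 3 ("book", "a", "table"), giving precision = 3/4,
# recall = 3/5, and ROUGE-1 F1 = 2 * (3/4 * 3/5) / (3/4 + 3/5) = 2/3.
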
class SemanticRecallPrecision(dspy.Signature):  # type: ignore[misc]
    """Compare a system's response to the ground truth to compute its recall and precision.

    If asked to reason, enumerate key ideas in each response, and whether they are present in the other response.

    Copied from https://github.com/stanfordnlp/dspy/blob/2957c5f998e0bc652017b6e3b1f8af34970b6f6b/dspy/evaluate/auto_evaluation.py#L4-L14
    """

    question: str = dspy.InputField()
    ground_truth: str = dspy.InputField()
    system_response: str = dspy.InputField()
    recall: float = dspy.OutputField(desc="fraction (out of 1.0) of ground truth covered by the system response")
    precision: float = dspy.OutputField(desc="fraction (out of 1.0) of system response covered by the ground truth")


class AugmentSemanticF1(dspy.Module):  # type: ignore[misc]
    """Compare a system's response to the ground truth to compute its recall and precision.

    Adapted from https://dspy.ai/api/evaluation/SemanticF1/
    """

    def __init__(self, threshold: float = 0.66) -> None:
        """Initialize the AugmentSemanticF1.

        Args:
            threshold: Threshold for the boolean output.
        """
        super().__init__()
        self.threshold = threshold
        self.module = dspy.ChainOfThought(SemanticRecallPrecision)

    def forward(
        self, example: dspy.Example, pred: dspy.Prediction, trace: list[dspy.Prediction] | None = None
    ) -> float | bool:
        """Compute the score for the given example and prediction.

        Uses SemanticF1 as the base metric, with ROUGE-1 as a repetition penalty.

        Args:
            example: Question and ground truth.
            pred: System response.
            trace: Predictions from previous iterations.

        Returns:
            The final score, or a boolean based on the threshold.
        """
        # Compute base scores using the existing semantic metric.
        scores = self.module(
            question=example.text, ground_truth=example.augmented_text, system_response=pred.augmented_text
        )
        base_score = f1_score(scores.precision, scores.recall)

        # Compute repetition penalty factor.
        penalty = repetition_factor(example.augmented_text, pred.augmented_text)
        # length_penalty = len(example.augmented_text) / len(pred.augmented_text)
        # Apply penalty to the base score.
        final_score = base_score * penalty  # * length_penalty
        # Return the final score, or a boolean based on the threshold if trace is provided.
        return final_score if trace is None else final_score >= self.threshold  # type: ignore[no-any-return]
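
# Illustrative arithmetic (assumed judge outputs, not from the original commit):
# if the judge returns precision=0.9 and recall=0.7, the base SemanticF1 is
# 2 * (0.9 * 0.7) / (0.9 + 0.7) = 0.7875; a repetition factor of 0.5 then yields
# final_score = 0.7875 * 0.5 ≈ 0.39, below the default 0.66 threshold.
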
class AugmentationSignature(dspy.Signature):  # type: ignore[misc]
    """Signature for text generation for the augmentation task."""

    text: str = dspy.InputField(desc="Text to augment. Your task is to paraphrase this text.")
    augmented_text: str = dspy.OutputField(desc="Augmented text. This should be in the same language as the input text.")


class DSPYIncrementalUtteranceEvolver:
    """Incremental evolutionary strategy for augmenting utterances using DSPy.

    For each ground-truth utterance, this module generates new utterances and evaluates
    them with an AutoIntent pipeline.

    Generations are scored with a modified SemanticF1 as the base metric, using ROUGE-1
    as a repetition penalty.
    """

    def __init__(
        self,
        model: str,
        api_base: str | None = None,
        temperature: float = 0.0,
        max_tokens: int = 1000,
        seed: int = 42,
        search_space: list[dict[str, Any]] | str | None = None,
    ) -> None:
        """Initialize the DSPYIncrementalUtteranceEvolver.

        Args:
            model: Model name. This should follow the naming schema from litellm:
                https://docs.litellm.ai/docs/providers
            api_base: API base URL. Some models require this.
            temperature: Sampling temperature. 0.0 is the default from dspy LM.
            max_tokens: Maximum number of tokens to generate. 1000 is the default from dspy LM.
            seed: Random seed for reproducibility.
            search_space: Search space for the pipeline.
        """
        self.search_space = search_space or DEFAULT_SEARCH_SPACE
        random.seed(seed)

        llm = dspy.LM(
            model,
            api_base=api_base,
            model_type="chat",
            temperature=temperature,
            max_tokens=max_tokens,
        )
        dspy.settings.configure(lm=llm)
        self.generator = dspy.ChainOfThoughtWithHint(AugmentationSignature)

    def augment(
        self,
        dataset: Dataset,
        split_name: str = Split.TEST,
        n_evolutions: int = 3,
        update_split: bool = True,
        mipro_init_params: dict[str, Any] | None = None,
        mipro_compile_params: dict[str, Any] | None = None,
        save_path: Path | str = "evolution_config",
    ) -> HFDataset:
        """Augment the dataset using the evolutionary strategy.

        Args:
            dataset: The dataset to augment.
            split_name: The name of the split to augment.
            n_evolutions: Number of evolutions to perform.
            update_split: Whether to update the split with the augmented data.
            mipro_init_params: Parameters for MIPROv2 initialization.
                Full list of params available at https://dspy.ai/deep-dive/optimizers/miprov2/#initialization-parameters
            mipro_compile_params: Parameters for MIPROv2 compilation.
                Full list of params available at https://dspy.ai/deep-dive/optimizers/miprov2/#compile-parameters
            save_path: Path to save the generated samples. Defaults to "evolution_config".

        Returns:
            The augmented dataset.
        """
        best_result = 0
        merge_dataset = copy.deepcopy(dataset)
        generated_samples = []
        original_split = dataset[split_name]
        if mipro_init_params is None:
            mipro_init_params = {}
        if mipro_compile_params is None:
            mipro_compile_params = {}

        if isinstance(save_path, str):
            save_path = Path(save_path)

        if not save_path.exists():
            save_path.mkdir(parents=True)

        dspy_dataset = [
            dspy.Example(
                text=sample[Dataset.utterance_feature],
                augmented_text=sample[Dataset.utterance_feature],  # Use original as reference
            ).with_inputs(
                "text",
            )
            for sample in original_split
        ]

        for i in range(n_evolutions):
            metric = AugmentSemanticF1()

            optimizer = dspy.MIPROv2(metric=metric, **mipro_init_params)

            optimized_module = optimizer.compile(self.generator, trainset=dspy_dataset, **mipro_compile_params)

            optimized_module.save((save_path / f"evolution_{i}").as_posix(), save_program=True)
            optimized_module.save(
                (save_path / f"evolution_{i}" / "generator_state.json").as_posix(), save_program=False
            )
            # Generate new samples
            new_samples = []
            for sample in original_split:
                utterance = sample[Dataset.utterance_feature]
                label = sample[Dataset.label_feature]
                prediction = optimized_module(text=utterance)
                new_samples.append({Dataset.label_feature: label, Dataset.utterance_feature: prediction.augmented_text})

            new_samples_dataset = HFDataset.from_list(new_samples)
            merge_dataset[split_name] = concatenate_datasets([merge_dataset[split_name], new_samples_dataset])
            generated_samples.append(new_samples_dataset)

            # Check if the new samples improve the model
            pipeline_optimizer = Pipeline.from_search_space(self.search_space)
            ctx = pipeline_optimizer.fit(merge_dataset)
            results = ctx.optimization_info.dump_evaluation_results()
            decision_metric = results["metrics"]["decision"][0]
            msg = f"Evolution {i} decision metric: {decision_metric}"
            logger.info(msg)

            if decision_metric > best_result:
                best_result = decision_metric
                msg = f"Evolution {i} is the best so far."
                logger.info(msg)
            else:
                break

        if update_split:
            dataset[split_name] = merge_dataset[split_name]

        return concatenate_datasets(generated_samples)
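
A minimal usage sketch of the new evolver (illustrative only; the model name, the `Dataset.from_json` loader, and the MIPROv2 `auto` preset are assumptions, not part of this commit):

    from autointent import Dataset

    # Hypothetical model name in litellm format; any supported provider works.
    evolver = DSPYIncrementalUtteranceEvolver(model="openai/gpt-4o-mini", seed=42)
    dataset = Dataset.from_json("intents.json")  # assumed loader; substitute your own dataset
    augmented = evolver.augment(
        dataset,
        n_evolutions=3,
        mipro_init_params={"auto": "light"},  # assumed MIPROv2 preset; see docs linked in the docstring
    )
    print(augmented)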

pyproject.toml

Lines changed: 10 additions & 0 deletions
@@ -100,6 +100,14 @@ ipykernel = "^6.29.5"
 tensorboardx = "^2.6.2.2"
 sphinx-multiversion = "^0.2.4"

+[tool.poetry.group.dspy]
+optional = true
+
+
+[tool.poetry.group.dspy.dependencies]
+dspy = "^2.6.5"
+
+
 [tool.ruff]
 line-length = 120
 indent-width = 4

@@ -194,6 +202,8 @@ module = [
 "torch.utils.tensorboard",
 "tensorboardX",
 "wandb",
+"dspy",
+"dspy.evaluate.auto_evaluation",
 ]
 ignore_missing_imports = true
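
The new dependency is declared as an optional Poetry group, so it is not installed by default. A minimal install sketch (the `--with dspy` flag targets the group added above; the pip extra is the one suggested by the module's own import error message):

    poetry install --with dspy          # development setup via Poetry
    pip install "autointent[dspy]"      # or via pip, per the import error hint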
