Skip to content

Commit bef8bd0

Browse files
committed
refactor
1 parent 09521a7 commit bef8bd0

File tree

1 file changed

+102
-65
lines changed

1 file changed

+102
-65
lines changed

autointent/generation/utterances/evolution/dspy_evolver.py

Lines changed: 102 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
1-
"""
2-
Evolutionary strategy to augmenting utterances.
3-
"""
1+
"""Evolutionary strategy to augmenting utterances."""
42

53
import copy
64
import logging
75
import random
86
from collections import Counter
97
from pathlib import Path
10-
from typing import Any
118

129
import dspy
1310
from datasets import Dataset as HFDataset
1411
from datasets import concatenate_datasets
15-
16-
# from dspy.evaluate import CompleteAndGrounded, SemanticF1, answer_exact_match
1712
from dspy.evaluate.auto_evaluation import f1_score
1813

1914
from autointent import Dataset, Pipeline
@@ -22,7 +17,7 @@
2217
logging.basicConfig(level=logging.INFO)
2318
logger = logging.getLogger(__name__)
2419

25-
SEARCH_SPACE = [
20+
DEFAULT_SEARCH_SPACE = [
2621
{
2722
"node_type": "scoring",
2823
"target_metric": "scoring_roc_auc",
@@ -74,17 +69,16 @@ def repetition_factor(true_text: str, augmented_text: str) -> float:
7469
overlap = sum(min(true_counts[token], aug_counts[token]) for token in true_counts.keys() & aug_counts.keys())
7570
precision = overlap / len(aug_tokens)
7671
recall = overlap / len(true_tokens)
77-
if precision + recall == 0:
78-
f1 = 0.0
79-
else:
80-
f1 = 2 * precision * recall / (precision + recall)
81-
return f1
72+
return 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)
8273

8374

8475
class SemanticRecallPrecision(dspy.Signature):
8576
"""
8677
Compare a system's response to the ground truth to compute its recall and precision.
78+
8779
If asked to reason, enumerate key ideas in each response, and whether they are present in the other response.
80+
81+
Copied from https://github.com/stanfordnlp/dspy/blob/2957c5f998e0bc652017b6e3b1f8af34970b6f6b/dspy/evaluate/auto_evaluation.py#L4-L14
8882
"""
8983

9084
# Copied from dspy
@@ -97,14 +91,37 @@ class SemanticRecallPrecision(dspy.Signature):
9791

9892

9993
class AugmentSemanticF1(dspy.Module):
100-
# adapted SemanticF1
101-
def __init__(self, threshold: float = 0.66, **kwargs: Any) -> None:
94+
"""Compare a system's response to the ground truth to compute its recall and precision.
95+
96+
Adapted from https://dspy.ai/api/evaluation/SemanticF1/
97+
"""
98+
99+
def __init__(self, threshold: float = 0.66) -> None:
100+
"""
101+
Initialize the AugmentSemanticF1.
102+
103+
Args:
104+
threshold: Threshold for the boolean output.
105+
"""
102106
self.threshold = threshold
103107
self.module = dspy.ChainOfThought(SemanticRecallPrecision)
104108

105109
def forward(
106110
self, example: dspy.Example, pred: dspy.Prediction, trace: list[dspy.Prediction] | None = None
107111
) -> float | bool:
112+
"""
113+
Compute the score for the given example and prediction.
114+
115+
Uses SemanticF1 as the base metric with a ROUGE-1 as repetition penalty.
116+
117+
Args:
118+
example: Question and ground truth.
119+
pred: System response.
120+
trace: Predictions from previous iterations.
121+
122+
Returns:
123+
The final score or a boolean based on the threshold.
124+
"""
108125
# Compute base scores using the existing semantic metric.
109126
scores = self.module(question=example.question, ground_truth=example.response, system_response=pred.response)
110127
base_score = f1_score(scores.precision, scores.recall)
@@ -119,87 +136,112 @@ def forward(
119136

120137

121138
class DSPYIncrementalUtteranceEvolver:
122-
"""Incremental evolutionary strategy to augmenting utterances using DSPy."""
139+
"""Incremental evolutionary strategy to augmenting utterances using DSPy.
140+
141+
Implements an evolutionary strategy for augmenting utterances using DSPy. This module augments the utterances.
142+
For ground truth utterances, it generates new utterances and evaluates them using the pipeline.
143+
144+
For scoring generations it uses a modified SemanticF1 as the base metric with ROUGE-1 as a repetition penalty.
145+
"""
123146

124147
def __init__(
125148
self,
126-
seed: int = 0,
149+
model: str,
150+
api_base: str | None = None,
151+
temperature: float = 0.0,
152+
max_tokens: int = 1000,
153+
seed: int = 42,
127154
search_space: str | None = None,
128155
) -> None:
129-
"""Initialize."""
130-
self.search_space = self._choose_search_space(search_space)
156+
"""
157+
Initialize the DSPYIncrementalUtteranceEvolver.
158+
159+
Args:
160+
model: Model name. This should follow naming schema from litellm.
161+
https://docs.litellm.ai/docs/providers
162+
api_base: API base URL. Some models require this.
163+
temperature: Sampling temperature. 0.0 is default from dspy LM.
164+
max_tokens: Maximum number of tokens to generate. 1000 is default from dspy LM.
165+
seed: Random seed for reproducibility.
166+
search_space: Search space for the pipeline.
167+
"""
168+
self.search_space = search_space or DEFAULT_SEARCH_SPACE
131169
random.seed(seed)
132170

133-
turbo = dspy.LM(
134-
"hosted_vllm/x5-airun-medium-coder-prod",
135-
api_base="http://mn-rtx01.x5.ru:8000/v1",
136-
# api_key="test",
171+
llm = dspy.LM(
172+
model,
173+
api_base=api_base,
137174
model_type="chat",
175+
temperature=temperature,
176+
max_tokens=max_tokens,
138177
)
139-
dspy.settings.configure(lm=turbo)
140-
# self.generator = dspy.ChainOfThought("text, n_examples -> augmented_texts: list[str]")
178+
dspy.settings.configure(lm=llm)
141179
# input should be question and response is augmented. question and response required for metric
142180
self.generator = dspy.ChainOfThought("question -> response: str")
143181

144-
def _choose_search_space(self, search_space: str | None) -> list[dict[str, Any]] | Path | str:
145-
if search_space is None:
146-
return SEARCH_SPACE
147-
return search_space
148-
149182
def augment(
150183
self,
151184
dataset: Dataset,
152185
split_name: str = Split.TEST,
153-
n_evolutions: int = 1,
186+
n_evolutions: int = 3,
154187
update_split: bool = True,
155-
batch_size: int = 4,
188+
mipro_init_params: dict | None = None,
189+
mipro_compile_params: dict | None = None,
190+
save_path: Path | str = "evolution_config",
156191
) -> HFDataset:
157192
"""
158-
Augment dataset split using DSPy with incremental optimization.
193+
Augment the dataset using the evolutionary strategy.
194+
195+
Args:
196+
dataset: The dataset to augment.
197+
split_name: The name of the split to augment.
198+
n_evolutions: Number of evolutions to perform.
199+
update_split: Whether to update the split with the augmented data.
200+
mipro_init_params: Parameters for the MIPROv2 augmentation.
201+
Full list of params available at https://dspy.ai/deep-dive/optimizers/miprov2/#initialization-parameters
202+
mipro_compile_params: Parameters for the MIPROv2 compilation.
203+
Full list of params available at https://dspy.ai/deep-dive/optimizers/miprov2/#compile-parameters
204+
save_path: Path to save the generated samples. Defaults to "evolution_config".
205+
206+
Returns:
207+
The augmented dataset.
159208
"""
160209
best_result = 0
161210
merge_dataset = copy.deepcopy(dataset)
162211
generated_samples = []
163212
original_split = dataset[split_name]
213+
if mipro_init_params is None:
214+
mipro_init_params = {}
215+
if mipro_compile_params is None:
216+
mipro_compile_params = {}
217+
218+
if isinstance(save_path, str):
219+
save_path = Path(save_path)
220+
221+
if not save_path.exists():
222+
save_path.mkdir(parents=True)
164223

165224
dspy_dataset = [
166225
dspy.Example(
167226
question=sample[Dataset.utterance_feature],
168-
# n_examples=1,
169227
response=sample[Dataset.utterance_feature], # Use original as reference
170228
).with_inputs(
171229
"question",
172-
# "n_examples"
173230
)
174231
for sample in original_split
175232
]
176233

177234
for i in range(n_evolutions):
178235
metric = AugmentSemanticF1()
179236

180-
optimizer = dspy.MIPROv2(
181-
metric=metric, # SemanticF1
182-
# auto="medium", # can be low, medium, high. this setting will override params in compile
183-
# num_threads=batch_size,
184-
# log_dir="logs",
185-
)
186-
optimized_module = optimizer.compile(
187-
self.generator,
188-
trainset=dspy_dataset,
189-
requires_permission_to_run=False,
190-
minibatch=False,
191-
# max_bootstrapped_demos=4,
192-
# max_labeled_demos=4,
193-
num_trials=5,
237+
optimizer = dspy.MIPROv2(metric=metric, **mipro_init_params)
238+
239+
optimized_module = optimizer.compile(self.generator, trainset=dspy_dataset, **mipro_compile_params)
240+
241+
optimized_module.save((save_path / f"evolution_{i}").as_posix(), save_program=True)
242+
optimized_module.save(
243+
(save_path / f"evolution_{i}" / "generator_state.json").as_posix(), save_program=False
194244
)
195-
# evaluate(optimized_module)
196-
# try:
197-
self.generator.save("generator/", save_program=True)
198-
# should be dir + file *.json or *.pkl
199-
self.generator.save("generator/generator_state.json", save_program=False)
200-
201-
optimized_module.save("optimized_module", save_program=True)
202-
optimized_module.save("optimized_module/optimized_module.json", save_program=False)
203245
# Generate new samples
204246
new_samples = []
205247
for sample in original_split:
@@ -219,22 +261,17 @@ def augment(
219261
ctx = pipeline_optimizer.fit(merge_dataset)
220262
results = ctx.optimization_info.dump_evaluation_results()
221263
decision_metric = results["metrics"]["decision"][0]
264+
msg = f"Evolution {i} decision metric: {decision_metric}"
265+
logger.info(msg)
222266

223267
if decision_metric > best_result:
224268
best_result = decision_metric
269+
msg = f"Evolution {i} is the best so far."
270+
logger.info(msg)
225271
else:
226272
break
227273

228274
if update_split:
229275
dataset[split_name] = merge_dataset[split_name]
230276

231277
return concatenate_datasets(generated_samples)
232-
233-
234-
if __name__ == "__main__":
235-
from autointent import Dataset
236-
237-
# Example usage
238-
dataset = Dataset.from_hub("AutoIntent/clinc150_subset")
239-
evolver = DSPYIncrementalUtteranceEvolver(seed=42, search_space=None)
240-
augmented_dataset = evolver.augment(dataset, split_name=Split.TEST, n_evolutions=2)

0 commit comments

Comments
 (0)