281 changes: 281 additions & 0 deletions autointent/generation/utterances/evolution/dspy_evolver.py
@@ -0,0 +1,281 @@
"""Evolutionary strategy to augmenting utterances."""

import copy
import logging
import random
from collections import Counter
from pathlib import Path
from typing import Any

try:
    import dspy
except ImportError:
    import_error = "dspy is not installed. Please install it with `pip install dspy` or `pip install autointent[dspy]`."
    raise ImportError(import_error) from None

from datasets import Dataset as HFDataset
from datasets import concatenate_datasets
from dspy.evaluate.auto_evaluation import f1_score

from autointent import Dataset, Pipeline
from autointent.custom_types import Split

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

DEFAULT_SEARCH_SPACE = [
    {
        "node_type": "scoring",
        "target_metric": "scoring_roc_auc",
        "metrics": ["scoring_accuracy"],
        "search_space": [
            {
                "module_name": "linear",
                "embedder_config": ["sentence-transformers/all-MiniLM-L6-v2"],
            }
        ],
    },
    {
        "node_type": "decision",
        "target_metric": "decision_accuracy",
        "search_space": [
            {"module_name": "argmax"},
        ],
    },
]


def repetition_factor(true_text: str, augmented_text: str) -> float:
"""Calculate the average ROUGE-1 F1 score between pairs of texts in true_texts and augmented_texts.

ROUGE-1 F1 is computed as:
F1 = 2 * (precision * recall) / (precision + recall)
where:
- precision = (overlap in unigrams) / (total unigrams in augmented text)
- recall = (overlap in unigrams) / (total unigrams in true text)

Args:
true_text: A ground truth text.
augmented_text: A list of augmented/generated text.

Returns:
float: The average ROUGE-1 F1 score across all pairs.

Raises:
ValueError: If the lengths of true_texts and augmented_texts differ.
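
    Example (a minimal illustration of the unigram-overlap computation):
        >>> repetition_factor("hello world", "hello world")
        1.0
        >>> repetition_factor("hello world", "hello there")
        0.5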
"""
    true_tokens = true_text.split()
    aug_tokens = augmented_text.split()
    if not true_tokens or not aug_tokens:
        return 0.0
    true_counts = Counter(true_tokens)
    aug_counts = Counter(aug_tokens)
    # Calculate the token overlap using the minimum count for common tokens
    overlap = sum(min(true_counts[token], aug_counts[token]) for token in true_counts.keys() & aug_counts.keys())
    precision = overlap / len(aug_tokens)
    recall = overlap / len(true_tokens)
    return 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)


class SemanticRecallPrecision(dspy.Signature): # type: ignore[misc]
"""Compare a system's response to the ground truth to compute its recall and precision.

If asked to reason, enumerate key ideas in each response, and whether they are present in the other response.

Copied from https://github.com/stanfordnlp/dspy/blob/2957c5f998e0bc652017b6e3b1f8af34970b6f6b/dspy/evaluate/auto_evaluation.py#L4-L14
"""

question: str = dspy.InputField()
ground_truth: str = dspy.InputField()
system_response: str = dspy.InputField()
recall: float = dspy.OutputField(desc="fraction (out of 1.0) of ground truth covered by the system response")
precision: float = dspy.OutputField(desc="fraction (out of 1.0) of system response covered by the ground truth")


class AugmentSemanticF1(dspy.Module): # type: ignore[misc]
"""Compare a system's response to the ground truth to compute its recall and precision.

Adapted from https://dspy.ai/api/evaluation/SemanticF1/
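
    Example (a sketch; assumes an LM has already been configured via ``dspy.settings.configure``):

        metric = AugmentSemanticF1(threshold=0.66)
        example = dspy.Example(text="book a flight", augmented_text="book a flight")
        pred = dspy.Prediction(augmented_text="reserve a plane ticket")
        score = metric(example, pred)  # float in [0, 1]; a bool if a trace is passed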
"""

    def __init__(self, threshold: float = 0.66) -> None:
        """Initialize the AugmentSemanticF1.

        Args:
            threshold: Threshold for the boolean output.
        """
        super().__init__()
        self.threshold = threshold
        self.module = dspy.ChainOfThought(SemanticRecallPrecision)

    def forward(
        self, example: dspy.Example, pred: dspy.Prediction, trace: list[dspy.Prediction] | None = None
    ) -> float | bool:
        """Compute the score for the given example and prediction.

        Uses SemanticF1 as the base metric with ROUGE-1 as a repetition penalty.

        Args:
            example: Question and ground truth.
            pred: System response.
            trace: Predictions from previous iterations.

        Returns:
            The final score, or a boolean based on the threshold if a trace is provided.
        """
        # Compute base scores using the existing semantic metric.
        scores = self.module(
            question=example.text, ground_truth=example.augmented_text, system_response=pred.augmented_text
        )
        base_score = f1_score(scores.precision, scores.recall)

        # Compute the repetition penalty factor.
        penalty = repetition_factor(example.augmented_text, pred.augmented_text)
        # Apply the penalty to the base score.
        final_score = base_score * penalty
        # Return the final score, or a boolean based on the threshold if trace is provided.
        return final_score if trace is None else final_score >= self.threshold  # type: ignore[no-any-return]


class AugmentationSignature(dspy.Signature): # type: ignore[misc]
"""Signature for text generation for augmentation task."""

text: str = dspy.InputField(desc="Text to augment. Your task to paraphrase this text.")
augmented_text: str = dspy.OutputField(desc="Augmented text. This should be on same language as text")


class DSPYIncrementalUtteranceEvolver:
"""Incremental evolutionary strategy to augmenting utterances using DSPy.

Implements an evolutionary strategy to augment utterances using DSPy. This module would augment the utterances.
For ground truth utterances, it would generate new utterances and evaluate them using the pipeline.

For scoring generations it would use modified SemanticF1 as the base metric with a ROUGE-1 as repetition penalty.
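
    Example (a minimal sketch; the model name is illustrative and follows the litellm naming schema):

        evolver = DSPYIncrementalUtteranceEvolver(model="openai/gpt-4o-mini")
        augmented = evolver.augment(dataset, n_evolutions=3)  # ``dataset`` is an autointent Dataset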
"""

    def __init__(
        self,
        model: str,
        api_base: str | None = None,
        temperature: float = 0.0,
        max_tokens: int = 1000,
        seed: int = 42,
        search_space: list[dict[str, Any]] | None = None,
    ) -> None:
        """Initialize the DSPYIncrementalUtteranceEvolver.

        Args:
            model: Model name. This should follow the naming schema from litellm:
                https://docs.litellm.ai/docs/providers
            api_base: API base URL. Some models require this.
            temperature: Sampling temperature. 0.0 is the default from dspy LM.
            max_tokens: Maximum number of tokens to generate. 1000 is the default from dspy LM.
            seed: Random seed for reproducibility.
            search_space: Search space for the pipeline. Defaults to DEFAULT_SEARCH_SPACE.
        """
        self.search_space = search_space or DEFAULT_SEARCH_SPACE
        random.seed(seed)

        llm = dspy.LM(
            model,
            api_base=api_base,
            model_type="chat",
            temperature=temperature,
            max_tokens=max_tokens,
        )
        dspy.settings.configure(lm=llm)
        self.generator = dspy.ChainOfThoughtWithHint(AugmentationSignature)

    def augment(
        self,
        dataset: Dataset,
        split_name: str = Split.TEST,
        n_evolutions: int = 3,
        update_split: bool = True,
        mipro_init_params: dict[str, Any] | None = None,
        mipro_compile_params: dict[str, Any] | None = None,
        save_path: Path | str = "evolution_config",
    ) -> HFDataset:
        """Augment the dataset using the evolutionary strategy.

        Args:
            dataset: The dataset to augment.
            split_name: The name of the split to augment.
            n_evolutions: Number of evolutions to perform.
            update_split: Whether to update the split with the augmented data.
            mipro_init_params: Parameters for MIPROv2 initialization.
                Full list of params available at https://dspy.ai/deep-dive/optimizers/miprov2/#initialization-parameters
            mipro_compile_params: Parameters for MIPROv2 compilation.
                Full list of params available at https://dspy.ai/deep-dive/optimizers/miprov2/#compile-parameters
            save_path: Path to save the generated samples. Defaults to "evolution_config".

        Returns:
            The augmented dataset.
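
        Example (illustrative parameter choices, not library defaults):

            augmented = evolver.augment(
                dataset,
                n_evolutions=2,
                mipro_init_params={"auto": "light"},
                mipro_compile_params={"requires_permission_to_run": False},
            )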
"""
        best_result = 0.0
        merge_dataset = copy.deepcopy(dataset)
        generated_samples = []
        original_split = dataset[split_name]
        if mipro_init_params is None:
            mipro_init_params = {}
        if mipro_compile_params is None:
            mipro_compile_params = {}

        if isinstance(save_path, str):
            save_path = Path(save_path)

        if not save_path.exists():
            save_path.mkdir(parents=True)

        dspy_dataset = [
            dspy.Example(
                text=sample[Dataset.utterance_feature],
                augmented_text=sample[Dataset.utterance_feature],  # Use original as reference
            ).with_inputs(
                "text",
            )
            for sample in original_split
        ]

        for i in range(n_evolutions):
            metric = AugmentSemanticF1()

            optimizer = dspy.MIPROv2(metric=metric, **mipro_init_params)

            optimized_module = optimizer.compile(self.generator, trainset=dspy_dataset, **mipro_compile_params)

            optimized_module.save((save_path / f"evolution_{i}").as_posix(), save_program=True)
            optimized_module.save(
                (save_path / f"evolution_{i}" / "generator_state.json").as_posix(), save_program=False
            )
            # Generate new samples
            new_samples = []
            for sample in original_split:
                utterance = sample[Dataset.utterance_feature]
                label = sample[Dataset.label_feature]
                prediction = optimized_module(text=utterance)
                new_samples.append(
                    {Dataset.label_feature: label, Dataset.utterance_feature: prediction.augmented_text}
                )

            new_samples_dataset = HFDataset.from_list(new_samples)
            merge_dataset[split_name] = concatenate_datasets([merge_dataset[split_name], new_samples_dataset])
            generated_samples.append(new_samples_dataset)

            # Check if the new samples improve the model
            pipeline_optimizer = Pipeline.from_search_space(self.search_space)
            ctx = pipeline_optimizer.fit(merge_dataset)
            results = ctx.optimization_info.dump_evaluation_results()
            decision_metric = results["metrics"]["decision"][0]
            msg = f"Evolution {i} decision metric: {decision_metric}"
            logger.info(msg)

            if decision_metric > best_result:
                best_result = decision_metric
                msg = f"Evolution {i} is the best so far."
                logger.info(msg)
            else:
                break

        if update_split:
            dataset[split_name] = merge_dataset[split_name]

        return concatenate_datasets(generated_samples)
10 changes: 10 additions & 0 deletions pyproject.toml
@@ -100,6 +100,14 @@
ipykernel = "^6.29.5"
tensorboardx = "^2.6.2.2"
sphinx-multiversion = "^0.2.4"

[tool.poetry.group.dspy]
optional = true


[tool.poetry.group.dspy.dependencies]
dspy = "^2.6.5"


[tool.ruff]
line-length = 120
indent-width = 4
@@ -194,6 +202,8 @@
module = [
"torch.utils.tensorboard",
"tensorboardX",
"wandb",
"dspy",
"dspy.evaluate.auto_evaluation",
]
ignore_missing_imports = true
