|
| 1 | +"""Evolutionary strategy to augmenting utterances.""" |
| 2 | + |
import copy
import logging
import random
from collections import Counter
from pathlib import Path
from typing import Any

try:
    import dspy
except ImportError:
    import_error = "dspy is not installed. Please install it with `pip install dspy` or `pip install autointent[dspy]`."
    raise ImportError(import_error) from None

from datasets import Dataset as HFDataset
from datasets import concatenate_datasets
from dspy.evaluate.auto_evaluation import f1_score

from autointent import Dataset, Pipeline
from autointent.custom_types import Split

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

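# A deliberately small evaluation pipeline: a linear scorer over MiniLM sentence
# embeddings plus an argmax decision node. It is used only to check whether the
# augmented utterances help the downstream classifier; a custom search space of
# the same structure can be passed to DSPYIncrementalUtteranceEvolver.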
DEFAULT_SEARCH_SPACE = [
    {
        "node_type": "scoring",
        "target_metric": "scoring_roc_auc",
        "metrics": ["scoring_accuracy"],
        "search_space": [
            {
                "module_name": "linear",
                "embedder_config": ["sentence-transformers/all-MiniLM-L6-v2"],
            }
        ],
    },
    {
        "node_type": "decision",
        "target_metric": "decision_accuracy",
        "search_space": [
            {"module_name": "argmax"},
        ],
    },
]


def repetition_factor(true_text: str, augmented_text: str) -> float:
    """Calculate the ROUGE-1 F1 score between a ground-truth text and an augmented text.

    ROUGE-1 F1 is computed as:
    F1 = 2 * (precision * recall) / (precision + recall)
    where:
    - precision = (overlap in unigrams) / (total unigrams in augmented text)
    - recall = (overlap in unigrams) / (total unigrams in true text)

    Args:
        true_text: The ground-truth text.
        augmented_text: The augmented/generated text.

    Returns:
        float: The ROUGE-1 F1 score, or 0.0 if either text is empty.
    """
    true_tokens = true_text.split()
    aug_tokens = augmented_text.split()
    if not true_tokens or not aug_tokens:
        return 0.0
    true_counts = Counter(true_tokens)
    aug_counts = Counter(aug_tokens)
    # Calculate the token overlap using the minimum count for common tokens
    overlap = sum(min(true_counts[token], aug_counts[token]) for token in true_counts.keys() & aug_counts.keys())
    precision = overlap / len(aug_tokens)
    recall = overlap / len(true_tokens)
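    # Worked example (illustrative values, not taken from the original code):
    # true_text="book a flight" (3 tokens) vs augmented_text="book a flight to paris"
    # (5 tokens) gives overlap=3, precision=3/5=0.6, recall=3/3=1.0,
    # so F1 = 2 * 0.6 * 1.0 / (0.6 + 1.0) = 0.75.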
    return 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)


class SemanticRecallPrecision(dspy.Signature):  # type: ignore[misc]
    """Compare a system's response to the ground truth to compute its recall and precision.

    If asked to reason, enumerate key ideas in each response, and whether they are present in the other response.

    Copied from https://github.com/stanfordnlp/dspy/blob/2957c5f998e0bc652017b6e3b1f8af34970b6f6b/dspy/evaluate/auto_evaluation.py#L4-L14
    """

    question: str = dspy.InputField()
    ground_truth: str = dspy.InputField()
    system_response: str = dspy.InputField()
    recall: float = dspy.OutputField(desc="fraction (out of 1.0) of ground truth covered by the system response")
    precision: float = dspy.OutputField(desc="fraction (out of 1.0) of system response covered by the ground truth")


class AugmentSemanticF1(dspy.Module):  # type: ignore[misc]
    """Score a system's response against the ground truth using semantic F1 with a repetition penalty.

    Adapted from https://dspy.ai/api/evaluation/SemanticF1/
    """

    def __init__(self, threshold: float = 0.66) -> None:
        """Initialize the AugmentSemanticF1.

        Args:
            threshold: Threshold for the boolean output.
        """
        self.threshold = threshold
        self.module = dspy.ChainOfThought(SemanticRecallPrecision)

    def forward(
        self, example: dspy.Example, pred: dspy.Prediction, trace: list[dspy.Prediction] | None = None
    ) -> float | bool:
        """Compute the score for the given example and prediction.

        Uses SemanticF1 as the base metric with ROUGE-1 as a repetition penalty.

        Args:
            example: Question and ground truth.
            pred: System response.
            trace: Optional DSPy trace; when provided, a boolean pass/fail is returned instead of the raw score.

        Returns:
            The final score, or a boolean based on the threshold.
        """
        # Compute base scores using the existing semantic metric.
        scores = self.module(
            question=example.text, ground_truth=example.augmented_text, system_response=pred.augmented_text
        )
        base_score = f1_score(scores.precision, scores.recall)

        # Compute repetition penalty factor.
        penalty = repetition_factor(example.augmented_text, pred.augmented_text)
        # length_penalty = len(example.augmented_text) / len(pred.augmented_text)
        # Apply penalty to the base score.
        final_score = base_score * penalty  # * length_penalty
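        # Illustrative numbers (assumed, not from the original code): a semantic F1
        # of 0.8 with a repetition factor of 0.75 gives a final score of 0.8 * 0.75 = 0.6.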
        # Return the final score, or a boolean based on the threshold if trace is provided.
        return final_score if trace is None else final_score >= self.threshold  # type: ignore[no-any-return]


class AugmentationSignature(dspy.Signature):  # type: ignore[misc]
    """Signature for text generation for the augmentation task."""

    text: str = dspy.InputField(desc="Text to augment. Your task is to paraphrase this text.")
    augmented_text: str = dspy.OutputField(desc="Augmented text. It should be in the same language as the input text.")


class DSPYIncrementalUtteranceEvolver:
    """Incremental evolutionary strategy for augmenting utterances using DSPy.

    Implements an evolutionary strategy that augments utterances with DSPy. For each ground-truth
    utterance it generates new utterances and evaluates their usefulness by fitting a pipeline
    on the augmented dataset.

    Generations are scored with a modified SemanticF1 metric that uses ROUGE-1 as a repetition penalty.
    """

    def __init__(
        self,
        model: str,
        api_base: str | None = None,
        temperature: float = 0.0,
        max_tokens: int = 1000,
        seed: int = 42,
        search_space: list[dict[str, Any]] | None = None,
    ) -> None:
        """Initialize the DSPYIncrementalUtteranceEvolver.

        Args:
            model: Model name. This should follow the naming scheme from litellm:
                https://docs.litellm.ai/docs/providers
            api_base: API base URL. Some models require this.
            temperature: Sampling temperature. 0.0 is the default from dspy LM.
            max_tokens: Maximum number of tokens to generate. 1000 is the default from dspy LM.
            seed: Random seed for reproducibility.
            search_space: Search space for the pipeline. Defaults to ``DEFAULT_SEARCH_SPACE``.
        """
        self.search_space = search_space or DEFAULT_SEARCH_SPACE
        random.seed(seed)

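        # `model` follows the litellm naming scheme; identifiers such as
        # "openai/gpt-4o-mini" or "ollama_chat/llama3" are examples only,
        # not values prescribed by this module.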
        llm = dspy.LM(
            model,
            api_base=api_base,
            model_type="chat",
            temperature=temperature,
            max_tokens=max_tokens,
        )
        dspy.settings.configure(lm=llm)
        self.generator = dspy.ChainOfThoughtWithHint(AugmentationSignature)

    def augment(
        self,
        dataset: Dataset,
        split_name: str = Split.TEST,
        n_evolutions: int = 3,
        update_split: bool = True,
        mipro_init_params: dict[str, Any] | None = None,
        mipro_compile_params: dict[str, Any] | None = None,
        save_path: Path | str = "evolution_config",
    ) -> HFDataset:
        """Augment the dataset using the evolutionary strategy.

        Args:
            dataset: The dataset to augment.
            split_name: The name of the split to augment.
            n_evolutions: Number of evolutions to perform.
            update_split: Whether to update the split with the augmented data.
            mipro_init_params: Initialization parameters for the MIPROv2 optimizer.
                Full list of params available at https://dspy.ai/deep-dive/optimizers/miprov2/#initialization-parameters
            mipro_compile_params: Parameters for MIPROv2 compilation.
                Full list of params available at https://dspy.ai/deep-dive/optimizers/miprov2/#compile-parameters
            save_path: Path to save the generated samples. Defaults to "evolution_config".

        Returns:
            The newly generated samples, concatenated across evolutions, as a Hugging Face dataset.
        """
        best_result = 0
        merge_dataset = copy.deepcopy(dataset)
        generated_samples = []
        original_split = dataset[split_name]
        if mipro_init_params is None:
            mipro_init_params = {}
        if mipro_compile_params is None:
            mipro_compile_params = {}

        if isinstance(save_path, str):
            save_path = Path(save_path)

        if not save_path.exists():
            save_path.mkdir(parents=True)

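        # The original utterance doubles as the reference answer; only `text` is marked
        # as an input, so the metric scores generations against the source utterance.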
        dspy_dataset = [
            dspy.Example(
                text=sample[Dataset.utterance_feature],
                augmented_text=sample[Dataset.utterance_feature],  # Use original as reference
            ).with_inputs(
                "text",
            )
            for sample in original_split
        ]

        for i in range(n_evolutions):
            metric = AugmentSemanticF1()

            optimizer = dspy.MIPROv2(metric=metric, **mipro_init_params)

            optimized_module = optimizer.compile(self.generator, trainset=dspy_dataset, **mipro_compile_params)

            optimized_module.save((save_path / f"evolution_{i}").as_posix(), save_program=True)
            optimized_module.save(
                (save_path / f"evolution_{i}" / "generator_state.json").as_posix(), save_program=False
            )
            # Generate new samples
            new_samples = []
            for sample in original_split:
                utterance = sample[Dataset.utterance_feature]
                label = sample[Dataset.label_feature]
                prediction = optimized_module(text=utterance)
                new_samples.append({Dataset.label_feature: label, Dataset.utterance_feature: prediction.augmented_text})

            new_samples_dataset = HFDataset.from_list(new_samples)
            merge_dataset[split_name] = concatenate_datasets([merge_dataset[split_name], new_samples_dataset])
            generated_samples.append(new_samples_dataset)

            # Check if the new samples improve the model
            pipeline_optimizer = Pipeline.from_search_space(self.search_space)
            ctx = pipeline_optimizer.fit(merge_dataset)
            results = ctx.optimization_info.dump_evaluation_results()
            decision_metric = results["metrics"]["decision"][0]
            msg = f"Evolution {i} decision metric: {decision_metric}"
            logger.info(msg)

            if decision_metric > best_result:
                best_result = decision_metric
                msg = f"Evolution {i} is the best so far."
                logger.info(msg)
            else:
                break

        if update_split:
            dataset[split_name] = merge_dataset[split_name]

        return concatenate_datasets(generated_samples)
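

# Minimal usage sketch (illustrative only): the dataset id, model identifier and
# parameter values are assumptions rather than values prescribed by this module,
# and `Dataset.from_hub` / `Split.TRAIN` are assumed to be available in autointent.
if __name__ == "__main__":
    source_dataset = Dataset.from_hub("AutoIntent/clinc150_subset")  # hypothetical dataset id
    evolver = DSPYIncrementalUtteranceEvolver(
        model="openai/gpt-4o-mini",  # any litellm-compatible model name
        temperature=0.7,
        seed=42,
    )
    generated = evolver.augment(
        source_dataset,
        split_name=Split.TRAIN,  # assumes the dataset has a train split
        n_evolutions=2,
        save_path="evolution_config",
    )
    logger.info("Generated %d augmented utterances", len(generated))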