1- """
2- Evolutionary strategy to augmenting utterances.
3- """
1+ """Evolutionary strategy for augmenting utterances."""
42
53import copy
64import logging
75import random
86from collections import Counter
97from pathlib import Path
10- from typing import Any
118
129import dspy
1310from datasets import Dataset as HFDataset
1411from datasets import concatenate_datasets
15-
16- # from dspy.evaluate import CompleteAndGrounded, SemanticF1, answer_exact_match
1712from dspy .evaluate .auto_evaluation import f1_score
1813
1914from autointent import Dataset , Pipeline
2217logging .basicConfig (level = logging .INFO )
2318logger = logging .getLogger (__name__ )
2419
25- SEARCH_SPACE = [
20+ DEFAULT_SEARCH_SPACE = [
2621 {
2722 "node_type" : "scoring" ,
2823 "target_metric" : "scoring_roc_auc" ,
@@ -74,17 +69,16 @@ def repetition_factor(true_text: str, augmented_text: str) -> float:
7469 overlap = sum (min (true_counts [token ], aug_counts [token ]) for token in true_counts .keys () & aug_counts .keys ())
7570 precision = overlap / len (aug_tokens )
7671 recall = overlap / len (true_tokens )
77- if precision + recall == 0 :
78- f1 = 0.0
79- else :
80- f1 = 2 * precision * recall / (precision + recall )
81- return f1
72+ return 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall )
8273
8374
8475class SemanticRecallPrecision (dspy .Signature ):
8576 """
8677 Compare a system's response to the ground truth to compute its recall and precision.
78+
8779 If asked to reason, enumerate key ideas in each response, and whether they are present in the other response.
80+
81+ Copied from https://github.com/stanfordnlp/dspy/blob/2957c5f998e0bc652017b6e3b1f8af34970b6f6b/dspy/evaluate/auto_evaluation.py#L4-L14
8882 """
8983
9084 # Copied from dspy
@@ -97,14 +91,37 @@ class SemanticRecallPrecision(dspy.Signature):
9791
9892
9993class AugmentSemanticF1 (dspy .Module ):
100- # adapted SemanticF1
101- def __init__ (self , threshold : float = 0.66 , ** kwargs : Any ) -> None :
94+ """Compare a system's response to the ground truth to compute its recall and precision.
95+
96+ Adapted from https://dspy.ai/api/evaluation/SemanticF1/
97+ """
98+
99+ def __init__ (self , threshold : float = 0.66 ) -> None :
100+ """
101+ Initialize the AugmentSemanticF1.
102+
103+ Args:
104+ threshold: Threshold for the boolean output.
105+ """
102106 self .threshold = threshold
103107 self .module = dspy .ChainOfThought (SemanticRecallPrecision )
104108
105109 def forward (
106110 self , example : dspy .Example , pred : dspy .Prediction , trace : list [dspy .Prediction ] | None = None
107111 ) -> float | bool :
112+ """
113+ Compute the score for the given example and prediction.
114+
115+ Uses SemanticF1 as the base metric with ROUGE-1 as a repetition penalty.
116+
117+ Args:
118+ example: Question and ground truth.
119+ pred: System response.
120+ trace: Predictions from previous iterations.
121+
122+ Returns:
123+ The final score or a boolean based on the threshold.
124+ """
108125 # Compute base scores using the existing semantic metric.
109126 scores = self .module (question = example .question , ground_truth = example .response , system_response = pred .response )
110127 base_score = f1_score (scores .precision , scores .recall )
@@ -119,87 +136,112 @@ def forward(
119136
120137
121138class DSPYIncrementalUtteranceEvolver :
122- """Incremental evolutionary strategy to augmenting utterances using DSPy."""
139+ """Incremental evolutionary strategy for augmenting utterances using DSPy.
140+
141+ Implements an evolutionary strategy for augmenting utterances using DSPy. For each ground-truth
142+ utterance, this module generates new utterances and evaluates them using the pipeline.
143+
144+ For scoring generations, it uses a modified SemanticF1 as the base metric with ROUGE-1 as a repetition penalty.
145+ """
123146
124147 def __init__ (
125148 self ,
126- seed : int = 0 ,
149+ model : str ,
150+ api_base : str | None = None ,
151+ temperature : float = 0.0 ,
152+ max_tokens : int = 1000 ,
153+ seed : int = 42 ,
127154 search_space : str | None = None ,
128155 ) -> None :
129- """Initialize."""
130- self .search_space = self ._choose_search_space (search_space )
156+ """
157+ Initialize the DSPYIncrementalUtteranceEvolver.
158+
159+ Args:
160+ model: Model name. This should follow the naming scheme from litellm.
161+ https://docs.litellm.ai/docs/providers
162+ api_base: API base URL. Some models require this.
163+ temperature: Sampling temperature. 0.0 is default from dspy LM.
164+ max_tokens: Maximum number of tokens to generate. 1000 is default from dspy LM.
165+ seed: Random seed for reproducibility.
166+ search_space: Search space for the pipeline.
167+ """
168+ self .search_space = search_space or DEFAULT_SEARCH_SPACE
131169 random .seed (seed )
132170
133- turbo = dspy .LM (
134- "hosted_vllm/x5-airun-medium-coder-prod" ,
135- api_base = "http://mn-rtx01.x5.ru:8000/v1" ,
136- # api_key="test",
171+ llm = dspy .LM (
172+ model ,
173+ api_base = api_base ,
137174 model_type = "chat" ,
175+ temperature = temperature ,
176+ max_tokens = max_tokens ,
138177 )
139- dspy .settings .configure (lm = turbo )
140- # self.generator = dspy.ChainOfThought("text, n_examples -> augmented_texts: list[str]")
178+ dspy .settings .configure (lm = llm )
141179 # input should be question and response is augmented. question and response required for metric
142180 self .generator = dspy .ChainOfThought ("question -> response: str" )
143181
144- def _choose_search_space (self , search_space : str | None ) -> list [dict [str , Any ]] | Path | str :
145- if search_space is None :
146- return SEARCH_SPACE
147- return search_space
148-
149182 def augment (
150183 self ,
151184 dataset : Dataset ,
152185 split_name : str = Split .TEST ,
153- n_evolutions : int = 1 ,
186+ n_evolutions : int = 3 ,
154187 update_split : bool = True ,
155- batch_size : int = 4 ,
188+ mipro_init_params : dict | None = None ,
189+ mipro_compile_params : dict | None = None ,
190+ save_path : Path | str = "evolution_config" ,
156191 ) -> HFDataset :
157192 """
158- Augment dataset split using DSPy with incremental optimization.
193+ Augment the dataset using the evolutionary strategy.
194+
195+ Args:
196+ dataset: The dataset to augment.
197+ split_name: The name of the split to augment.
198+ n_evolutions: Number of evolutions to perform.
199+ update_split: Whether to update the split with the augmented data.
200+ mipro_init_params: Parameters for the MIPROv2 augmentation.
201+ Full list of params available at https://dspy.ai/deep-dive/optimizers/miprov2/#initialization-parameters
202+ mipro_compile_params: Parameters for the MIPROv2 compilation.
203+ Full list of params available at https://dspy.ai/deep-dive/optimizers/miprov2/#compile-parameters
204+ save_path: Path to save the generated samples. Defaults to "evolution_config".
205+
206+ Returns:
207+ The augmented dataset.
159208 """
160209 best_result = 0
161210 merge_dataset = copy .deepcopy (dataset )
162211 generated_samples = []
163212 original_split = dataset [split_name ]
213+ if mipro_init_params is None :
214+ mipro_init_params = {}
215+ if mipro_compile_params is None :
216+ mipro_compile_params = {}
217+
218+ if isinstance (save_path , str ):
219+ save_path = Path (save_path )
220+
221+ if not save_path .exists ():
222+ save_path .mkdir (parents = True )
164223
165224 dspy_dataset = [
166225 dspy .Example (
167226 question = sample [Dataset .utterance_feature ],
168- # n_examples=1,
169227 response = sample [Dataset .utterance_feature ], # Use original as reference
170228 ).with_inputs (
171229 "question" ,
172- # "n_examples"
173230 )
174231 for sample in original_split
175232 ]
176233
177234 for i in range (n_evolutions ):
178235 metric = AugmentSemanticF1 ()
179236
180- optimizer = dspy .MIPROv2 (
181- metric = metric , # SemanticF1
182- # auto="medium", # can be low, medium, high. this setting will override params in compile
183- # num_threads=batch_size,
184- # log_dir="logs",
185- )
186- optimized_module = optimizer .compile (
187- self .generator ,
188- trainset = dspy_dataset ,
189- requires_permission_to_run = False ,
190- minibatch = False ,
191- # max_bootstrapped_demos=4,
192- # max_labeled_demos=4,
193- num_trials = 5 ,
237+ optimizer = dspy .MIPROv2 (metric = metric , ** mipro_init_params )
238+
239+ optimized_module = optimizer .compile (self .generator , trainset = dspy_dataset , ** mipro_compile_params )
240+
241+ optimized_module .save ((save_path / f"evolution_{ i } " ).as_posix (), save_program = True )
242+ optimized_module .save (
243+ (save_path / f"evolution_{ i } " / "generator_state.json" ).as_posix (), save_program = False
194244 )
195- # evaluate(optimized_module)
196- # try:
197- self .generator .save ("generator/" , save_program = True )
198- # should be dir + file *.json or *.pkl
199- self .generator .save ("generator/generator_state.json" , save_program = False )
200-
201- optimized_module .save ("optimized_module" , save_program = True )
202- optimized_module .save ("optimized_module/optimized_module.json" , save_program = False )
203245 # Generate new samples
204246 new_samples = []
205247 for sample in original_split :
@@ -219,22 +261,17 @@ def augment(
219261 ctx = pipeline_optimizer .fit (merge_dataset )
220262 results = ctx .optimization_info .dump_evaluation_results ()
221263 decision_metric = results ["metrics" ]["decision" ][0 ]
264+ msg = f"Evolution { i } decision metric: { decision_metric } "
265+ logger .info (msg )
222266
223267 if decision_metric > best_result :
224268 best_result = decision_metric
269+ msg = f"Evolution { i } is the best so far."
270+ logger .info (msg )
225271 else :
226272 break
227273
228274 if update_split :
229275 dataset [split_name ] = merge_dataset [split_name ]
230276
231277 return concatenate_datasets (generated_samples )
232-
233-
234- if __name__ == "__main__" :
235- from autointent import Dataset
236-
237- # Example usage
238- dataset = Dataset .from_hub ("AutoIntent/clinc150_subset" )
239- evolver = DSPYIncrementalUtteranceEvolver (seed = 42 , search_space = None )
240- augmented_dataset = evolver .augment (dataset , split_name = Split .TEST , n_evolutions = 2 )
0 commit comments