Skip to content

Commit 09521a7

Browse files
committed
some working version
1 parent e5e3ead commit 09521a7

File tree

1 file changed

+111
-56
lines changed

1 file changed

+111
-56
lines changed

autointent/generation/utterances/evolution/dspy_evolver.py

Lines changed: 111 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,26 @@
11
"""
22
Evolutionary strategy for augmenting utterances.
3-
4-
Deeply inspired by DeepEval evolutions.
53
"""
4+
65
import copy
7-
import os
6+
import logging
87
import random
8+
from collections import Counter
99
from pathlib import Path
1010
from typing import Any
1111

1212
import dspy
13-
from datasets import Dataset as HFDataset, concatenate_datasets
14-
from dspy.evaluate import SemanticF1
15-
import logging
13+
from datasets import Dataset as HFDataset
14+
from datasets import concatenate_datasets
15+
16+
# from dspy.evaluate import CompleteAndGrounded, SemanticF1, answer_exact_match
17+
from dspy.evaluate.auto_evaluation import f1_score
1618

1719
from autointent import Dataset, Pipeline
1820
from autointent.custom_types import Split
1921

22+
logging.basicConfig(level=logging.INFO)
2023
logger = logging.getLogger(__name__)
21-
logging.basicConfig(level=logging.DEBUG)
2224

2325
SEARCH_SPACE = [
2426
{
@@ -42,31 +44,80 @@
4244
]
4345

4446

45-
# Define a DSPy signature for text augmentation.
46-
class TextAugmentSignature(dspy.Signature):
47-
text: str = dspy.InputField()
48-
# n_examples: int = dspy.InputField()
49-
augmented_texts: list[str] = dspy.OutputField(
50-
desc="List of augmented texts that preserve the original meaning but use varied phrasing."
51-
)
47+
def repetition_factor(true_text: str, augmented_text: str) -> float:
48+
"""
49+
Calculate the ROUGE-1 F1 score between a single true text and a single augmented text.
50+
51+
ROUGE-1 F1 is computed as:
52+
F1 = 2 * (precision * recall) / (precision + recall)
53+
where:
54+
- precision = (overlap in unigrams) / (total unigrams in augmented text)
55+
- recall = (overlap in unigrams) / (total unigrams in true text)
56+
57+
Args:
58+
true_text: A ground truth text.
59+
augmented_text: An augmented/generated text.
60+
61+
Returns:
62+
float: The ROUGE-1 F1 score for the pair (0.0 if either text is empty).
63+
64+
Note:
65+
Returns 0.0 when either text has no tokens; no exception is raised.
66+
"""
67+
true_tokens = true_text.split()
68+
aug_tokens = augmented_text.split()
69+
if not true_tokens or not aug_tokens:
70+
return 0.0
71+
true_counts = Counter(true_tokens)
72+
aug_counts = Counter(aug_tokens)
73+
# Calculate the token overlap using the minimum count for common tokens
74+
overlap = sum(min(true_counts[token], aug_counts[token]) for token in true_counts.keys() & aug_counts.keys())
75+
precision = overlap / len(aug_tokens)
76+
recall = overlap / len(true_tokens)
77+
if precision + recall == 0:
78+
f1 = 0.0
79+
else:
80+
f1 = 2 * precision * recall / (precision + recall)
81+
return f1
82+
83+
84+
class SemanticRecallPrecision(dspy.Signature):
85+
"""
86+
Compare a system's response to the ground truth to compute its recall and precision.
87+
If asked to reason, enumerate key ideas in each response, and whether they are present in the other response.
88+
"""
89+
90+
# Copied from dspy
91+
92+
question: str = dspy.InputField()
93+
ground_truth: str = dspy.InputField()
94+
system_response: str = dspy.InputField()
95+
recall: float = dspy.OutputField(desc="fraction (out of 1.0) of ground truth covered by the system response")
96+
precision: float = dspy.OutputField(desc="fraction (out of 1.0) of system response covered by the ground truth")
97+
98+
99+
class AugmentSemanticF1(dspy.Module):
100+
# adapted SemanticF1
101+
def __init__(self, threshold: float = 0.66, **kwargs: Any) -> None:
102+
self.threshold = threshold
103+
self.module = dspy.ChainOfThought(SemanticRecallPrecision)
104+
105+
def forward(
106+
self, example: dspy.Example, pred: dspy.Prediction, trace: list[dspy.Prediction] | None = None
107+
) -> float | bool:
108+
# Compute base scores using the existing semantic metric.
109+
scores = self.module(question=example.question, ground_truth=example.response, system_response=pred.response)
110+
base_score = f1_score(scores.precision, scores.recall)
111+
112+
# Similarity factor: higher token overlap with the ground truth gives a HIGHER value. NOTE(review): multiplying the base score by this rewards repeating the original rather than penalizing it — confirm intent.
113+
penalty = repetition_factor(example.response, pred.response)
114+
115+
# Apply penalty to the base score.
116+
final_score = base_score * penalty
117+
# Return the final score, or a boolean based on the threshold if trace is provided.
118+
return final_score if trace is None else final_score >= self.threshold
52119

53120

54-
# # Define a DSPy module that implements text augmentation.
55-
# class TextAugmenter(dspy.Module):
56-
# def __init__(self) -> None:
57-
# # Here, we use a ChainOfThought module with the defined signature.
58-
# # The module is responsible for "thinking through" and generating multiple text variants.
59-
# super().__init__()
60-
# self.generator = dspy.ChainOfThought("text, n_examples -> augmented_texts")
61-
#
62-
# def forward(self, text: str, n_examples: int) -> dspy.Prediction:
63-
# # Invoke the underlying generator with the input text and desired number of examples.
64-
# return self.generator(text=text, n_examples=n_examples)
65-
66-
67-
os.environ['MISTRAL_API_KEY'] = ""
68-
os.environ["OPENROUTER_API_KEY"] = ""
69-
70121
class DSPYIncrementalUtteranceEvolver:
71122
"""Incremental evolutionary strategy to augmenting utterances using DSPy."""
72123

@@ -78,15 +129,17 @@ def __init__(
78129
"""Initialize."""
79130
self.search_space = self._choose_search_space(search_space)
80131
random.seed(seed)
81-
# full list of providers
132+
82133
turbo = dspy.LM(
83-
...,
84-
model_type='chat'
134+
"hosted_vllm/x5-airun-medium-coder-prod",
135+
api_base="http://mn-rtx01.x5.ru:8000/v1",
136+
# api_key="test",
137+
model_type="chat",
85138
)
86139
dspy.settings.configure(lm=turbo)
87140
# self.generator = dspy.ChainOfThought("text, n_examples -> augmented_texts: list[str]")
88141
# input should be question and response is augmented. question and response required for metric
89-
self.generator = dspy.ChainOfThought("question -> response: list[str]")
142+
self.generator = dspy.ChainOfThought("question -> response: str")
90143

91144
def _choose_search_space(self, search_space: str | None) -> list[dict[str, Any]] | Path | str:
92145
if search_space is None:
@@ -113,43 +166,48 @@ def augment(
113166
dspy.Example(
114167
question=sample[Dataset.utterance_feature],
115168
# n_examples=1,
116-
response=sample[Dataset.utterance_feature] # Use original as reference
169+
response=sample[Dataset.utterance_feature], # Use original as reference
117170
).with_inputs(
118171
"question",
119172
# "n_examples"
120173
)
121174
for sample in original_split
122175
]
123176

124-
for _ in range(n_evolutions):
125-
metric = SemanticF1()
177+
for i in range(n_evolutions):
178+
metric = AugmentSemanticF1()
126179

127180
optimizer = dspy.MIPROv2(
128-
metric=metric,
129-
auto="medium",
130-
num_threads=batch_size,
131-
log_dir="logs",
181+
metric=metric, # SemanticF1
182+
# auto="medium", # can be low, medium, high. this setting will override params in compile
183+
# num_threads=batch_size,
184+
# log_dir="logs",
132185
)
133186
optimized_module = optimizer.compile(
134187
self.generator,
135188
trainset=dspy_dataset,
136189
requires_permission_to_run=False,
137-
max_bootstrapped_demos=4,
138-
max_labeled_demos=4
190+
minibatch=False,
191+
# max_bootstrapped_demos=4,
192+
# max_labeled_demos=4,
193+
num_trials=5,
139194
)
195+
# evaluate(optimized_module)
196+
# try:
197+
self.generator.save("generator/", save_program=True)
198+
# should be dir + file *.json or *.pkl
199+
self.generator.save("generator/generator_state.json", save_program=False)
200+
201+
optimized_module.save("optimized_module", save_program=True)
202+
optimized_module.save("optimized_module/optimized_module.json", save_program=False)
203+
# Generate new samples. NOTE(review): the generator signature now returns `response: str` (was `list[str]`), so `for ut in prediction.response` below iterates single characters, not utterances — confirm.
140204
new_samples = []
141205
for sample in original_split:
142206
utterance = sample[Dataset.utterance_feature]
143207
label = sample[Dataset.label_feature]
144-
prediction = optimized_module(text=utterance)
208+
prediction = optimized_module(question=utterance)
145209
new_samples.extend(
146-
[
147-
{
148-
Dataset.label_feature: label,
149-
Dataset.utterance_feature: ut
150-
}
151-
for ut in prediction.response
152-
]
210+
[{Dataset.label_feature: label, Dataset.utterance_feature: ut} for ut in prediction.response]
153211
)
154212

155213
new_samples_dataset = HFDataset.from_list(new_samples)
@@ -178,8 +236,5 @@ def augment(
178236

179237
# Example usage
180238
dataset = Dataset.from_hub("AutoIntent/clinc150_subset")
181-
evolver = DSPYIncrementalUtteranceEvolver(
182-
seed=42,
183-
search_space=None
184-
)
185-
augmented_dataset = evolver.augment(dataset, split_name=Split.TEST, n_evolutions=5)
239+
evolver = DSPYIncrementalUtteranceEvolver(seed=42, search_space=None)
240+
augmented_dataset = evolver.augment(dataset, split_name=Split.TEST, n_evolutions=2)

0 commit comments

Comments
 (0)