Skip to content

Commit ed981ce

Browse files
authored
add DSPY tutorial (#161)
* add tutorial * fix docs * fix docs
1 parent ae2a831 commit ed981ce

File tree

3 files changed

+130
-32
lines changed

3 files changed

+130
-32
lines changed

autointent/generation/utterances/evolution/dspy_evolver.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ def repetition_factor(true_text: str, augmented_text: str) -> float:
6464
Raises:
6565
ValueError: If the lengths of true_texts and augmented_texts differ.
6666
"""
67-
true_tokens = true_text.split()
68-
aug_tokens = augmented_text.split()
67+
true_tokens = "".join(c for c in true_text.lower() if c.isalnum() or c.isspace()).split()
68+
aug_tokens = "".join(c for c in augmented_text.lower() if c.isalnum() or c.isspace()).split()
6969
if not true_tokens or not aug_tokens:
7070
return 0.0
7171
true_counts = Counter(true_tokens)
@@ -82,7 +82,7 @@ class SemanticRecallPrecision(dspy.Signature): # type: ignore[misc]
8282
8383
If asked to reason, enumerate key ideas in each response, and whether they are present in the other response.
8484
85-
Copied from https://github.com/stanfordnlp/dspy/blob/2957c5f998e0bc652017b6e3b1f8af34970b6f6b/dspy/evaluate/auto_evaluation.py#L4-L14
85+
Copied from `dspy <https://github.com/stanfordnlp/dspy/blob/2957c5f998e0bc652017b6e3b1f8af34970b6f6b/dspy/evaluate/auto_evaluation.py#L4-L14>`_
8686
"""
8787

8888
question: str = dspy.InputField()
@@ -95,7 +95,7 @@ class SemanticRecallPrecision(dspy.Signature): # type: ignore[misc]
9595
class AugmentSemanticF1(dspy.Module): # type: ignore[misc]
9696
"""Compare a system's response to the ground truth to compute its recall and precision.
9797
98-
Adapted from https://dspy.ai/api/evaluation/SemanticF1/
98+
Adapted from `dspy SemanticF1 <https://dspy.ai/api/evaluation/SemanticF1/>`_
9999
"""
100100

101101
def __init__(self, threshold: float = 0.66) -> None:
@@ -151,6 +151,15 @@ class DSPYIncrementalUtteranceEvolver:
151151
For ground truth utterances, it would generate new utterances and evaluate them using the pipeline.
152152
153153
For scoring generations it would use modified SemanticF1 as the base metric with a ROUGE-1 as repetition penalty.
154+
155+
Args:
156+
model: Model name. This should follow naming schema from `litellm providers <https://docs.litellm.ai/docs/providers>`_.
157+
api_base: API base URL. Some models require this.
158+
temperature: Sampling temperature. 0.0 is default from dspy LM.
159+
max_tokens: Maximum number of tokens to generate. 1000 is default from dspy LM.
160+
seed: Random seed for reproducibility.
161+
search_space: Search space for the pipeline.
162+
154163
"""
155164

156165
def __init__(
@@ -162,18 +171,8 @@ def __init__(
162171
seed: int = 42,
163172
search_space: str | None = None,
164173
) -> None:
165-
"""Initialize the DSPYIncrementalUtteranceEvolver.
166-
167-
Args:
168-
model: Model name. This should follow naming schema from litellm.
169-
https://docs.litellm.ai/docs/providers
170-
api_base: API base URL. Some models require this.
171-
temperature: Sampling temperature. 0.0 is default from dspy LM.
172-
max_tokens: Maximum number of tokens to generate. 1000 is default from dspy LM.
173-
seed: Random seed for reproducibility.
174-
search_space: Search space for the pipeline.
175-
"""
176-
self.search_space = search_space or DEFAULT_SEARCH_SPACE
174+
"""Initialize the DSPYIncrementalUtteranceEvolver."""
175+
self._search_space = search_space or DEFAULT_SEARCH_SPACE
177176
random.seed(seed)
178177

179178
llm = dspy.LM(
@@ -184,17 +183,17 @@ def __init__(
184183
max_tokens=max_tokens,
185184
)
186185
dspy.settings.configure(lm=llm)
187-
self.generator = dspy.ChainOfThoughtWithHint(AugmentationSignature)
186+
self._generator = dspy.ChainOfThoughtWithHint(AugmentationSignature)
188187

189-
def augment(
188+
def augment( # noqa: C901
190189
self,
191190
dataset: Dataset,
192191
split_name: str = Split.TEST,
193192
n_evolutions: int = 3,
194193
update_split: bool = True,
195194
mipro_init_params: dict[str, Any] | None = None,
196195
mipro_compile_params: dict[str, Any] | None = None,
197-
save_path: Path | str = "evolution_config",
196+
save_path: Path | str | None = None,
198197
) -> HFDataset:
199198
"""Augment the dataset using the evolutionary strategy.
200199
@@ -204,10 +203,10 @@ def augment(
204203
n_evolutions: Number of evolutions to perform.
205204
update_split: Whether to update the split with the augmented data.
206205
mipro_init_params: Parameters for the MIPROv2 augmentation.
207-
Full list of params available at https://dspy.ai/deep-dive/optimizers/miprov2/#initialization-parameters
206+
`Full list of parameters <https://dspy.ai/deep-dive/optimizers/miprov2/#initialization-parameters>`_
208207
mipro_compile_params: Parameters for the MIPROv2 compilation.
209-
Full list of params available at https://dspy.ai/deep-dive/optimizers/miprov2/#compile-parameters
210-
save_path: Path to save the generated samples. Defaults to "evolution_config".
208+
`Full list of params available <https://dspy.ai/deep-dive/optimizers/miprov2/#compile-parameters>`_
209+
save_path: Path to save the prompt of LLM. If None is provided, it will not be saved.
211210
212211
Returns:
213212
The augmented dataset.
@@ -221,11 +220,12 @@ def augment(
221220
if mipro_compile_params is None:
222221
mipro_compile_params = {}
223222

224-
if isinstance(save_path, str):
225-
save_path = Path(save_path)
223+
if save_path is not None:
224+
if isinstance(save_path, str):
225+
save_path = Path(save_path)
226226

227-
if not save_path.exists():
228-
save_path.mkdir(parents=True)
227+
if not save_path.exists():
228+
save_path.mkdir(parents=True)
229229

230230
dspy_dataset = [
231231
dspy.Example(
@@ -242,12 +242,13 @@ def augment(
242242

243243
optimizer = dspy.MIPROv2(metric=metric, **mipro_init_params)
244244

245-
optimized_module = optimizer.compile(self.generator, trainset=dspy_dataset, **mipro_compile_params)
245+
optimized_module = optimizer.compile(self._generator, trainset=dspy_dataset, **mipro_compile_params)
246246

247-
optimized_module.save((save_path / f"evolution_{i}").as_posix(), save_program=True)
248-
optimized_module.save(
249-
(save_path / f"evolution_{i}" / "generator_state.json").as_posix(), save_program=False
250-
)
247+
if save_path is not None:
248+
optimized_module.save((save_path / f"evolution_{i}").as_posix(), save_program=True)
249+
optimized_module.save(
250+
(save_path / f"evolution_{i}" / "generator_state.json").as_posix(), save_program=False
251+
)
251252
# Generate new samples
252253
new_samples = []
253254
for sample in original_split:
@@ -261,7 +262,7 @@ def augment(
261262
generated_samples.append(new_samples_dataset)
262263

263264
# Check if the new samples improve the model
264-
pipeline_optimizer = Pipeline.from_search_space(self.search_space)
265+
pipeline_optimizer = Pipeline.from_search_space(self._search_space)
265266
ctx = pipeline_optimizer.fit(merge_dataset)
266267
results = ctx.optimization_info.dump_evaluation_results()
267268
decision_metric = results["metrics"]["decision"][0]
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
.. _evolutionary_strategy_augmentation:
2+
3+
DSPy Augmentation
4+
#################
5+
6+
This tutorial covers the implementation and usage of an evolutionary strategy to augment utterances using DSPy. It explains how DSPy is used, how the module functions, and how the scoring metric works.
7+
8+
.. contents:: Table of Contents
9+
:depth: 2
10+
11+
-------------
12+
What is DSPy?
13+
-------------
14+
15+
DSPy is a framework for optimizing and evaluating language models. It provides tools for defining signatures, optimizing modules, and measuring evaluation metrics. This module leverages DSPy to generate augmented utterances using an evolutionary approach.
16+
17+
---------------------
18+
How This Module Works
19+
---------------------
20+
21+
This module applies an incremental evolutionary strategy for augmenting utterances. It generates new utterances based on a given dataset and refines them using an iterative process. The generated utterances are evaluated using a scoring mechanism that includes:
22+
23+
- **SemanticF1**: Measures how well the generated utterance matches the ground truth.
24+
- **ROUGE-1 penalty**: Discourages excessive repetition.
25+
- **Pipeline Decision Metric**: Assesses whether the augmented utterances improve model performance.
26+
27+
The augmentation process runs for a specified number of evolutions, saving intermediate models and optimizing the results.
28+
29+
------------
30+
Installation
31+
------------
32+
33+
Ensure you have the required dependencies installed:
34+
35+
.. code-block:: bash
36+
37+
pip install "autointent[dspy]"
38+
39+
--------------
40+
Scoring Metric
41+
--------------
42+
43+
The scoring metric consists of:
44+
45+
1. **SemanticF1 Score**:
46+
- Computes precision and recall between system-generated utterances and the ground truth using an LLM.
47+
- Uses DSPy’s `SemanticRecallPrecision` module.
48+
49+
2. **Repetition Factor (ROUGE-1 Penalty)**:
50+
- Measures overlap of words between the generated and ground truth utterances.
51+
- Ensures diversity in augmentation.
52+
53+
3. **Final Score Calculation**:
54+
- `Final Score = SemanticF1 * Repetition Factor`
55+
- A higher score means better augmentation.
56+
57+
-------------
58+
Usage Example
59+
-------------
60+
61+
Before running the following code, refer to the `LiteLLM documentation <https://docs.litellm.ai/docs/providers>`_ for proper model configuration.
62+
63+
.. code-block:: python
64+
65+
import os
66+
os.environ["OPENAI_API_KEY"] = "your-api-key"
67+
68+
from autointent import Dataset
69+
from autointent.custom_types import Split
from autointent.generation.utterances.evolution.dspy_evolver import DSPYIncrementalUtteranceEvolver
70+
71+
dataset = Dataset.from_hub("AutoIntent/clinc150_subset")
72+
evolver = DSPYIncrementalUtteranceEvolver(
73+
"openai/gpt-4o-mini"
74+
)
75+
76+
augmented_dataset = evolver.augment(
77+
dataset,
78+
split_name=Split.TEST,
79+
n_evolutions=1,
80+
mipro_init_params={
81+
"auto": "light",
82+
},
83+
mipro_compile_params={
84+
"minibatch": False,
85+
},
86+
)
87+
88+
augmented_dataset.to_csv("clinc150_dspy_augment.csv")

docs/source/user_guides.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,12 @@ User Guides
88
user_guides/index_basic_usage
99
user_guides/index_advanced_usage
1010
user_guides/index_cli_usage
11+
12+
Data augmentation tutorials
13+
---------------------------
14+
15+
.. toctree::
16+
:maxdepth: 1
17+
18+
augmentation_tutorials/dspy_augmentation
19+

0 commit comments

Comments
 (0)