Skip to content

Commit 7cd6c9c

Browse files
authored
Feat/prompt management (#109)
* refactor templates storing * make base class * update evolver * move `schemas.py` * update cli * minor bug fix * fix typing * refactor chat template for basic utterance generator * update utterance synthesizer * update `Dataset.from_hub` method * add debug messages to cli endpoint * update cli for evolver * fix shared classvar issue * add tests for basic chat template * configure explicit import from `generation.utterances` submodule * add tests for augmentation * commit to trigger gh actions * fix one doctest
1 parent 599794b commit 7cd6c9c

26 files changed

+483
-450
lines changed

autointent/_dataset/_dataset.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,14 @@ def from_hub(cls, repo_id: str) -> "Dataset":
100100
:param repo_id: ID of the Hugging Face repository.
101101
:return: Initialized Dataset object.
102102
"""
103-
splits, intents = load_dataset(repo_id), []
103+
from ._reader import DictReader
104+
105+
splits = load_dataset(repo_id)
106+
mapping = dict(**splits)
104107
if Split.INTENTS in get_dataset_config_names(repo_id):
105-
intents = load_dataset(repo_id, Split.INTENTS)[Split.INTENTS].to_list()
106-
return cls(
107-
splits.items(),
108-
intents=[Intent.model_validate(intent) for intent in intents],
109-
)
108+
mapping["intents"] = load_dataset(repo_id, Split.INTENTS)[Split.INTENTS].to_list()
109+
110+
return DictReader().read(mapping)
110111

111112
def to_multilabel(self) -> "Dataset":
112113
"""
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from .basic import SynthesizerChatTemplate, UtteranceGenerator
2+
from .evolution import AbstractEvolution, ConcreteEvolution, EvolutionChatTemplate, ReasoningEvolution, UtteranceEvolver
3+
from .generator import Generator
4+
5+
__all__ = [
6+
"AbstractEvolution",
7+
"ConcreteEvolution",
8+
"EvolutionChatTemplate",
9+
"Generator",
10+
"ReasoningEvolution",
11+
"SynthesizerChatTemplate",
12+
"UtteranceEvolver",
13+
"UtteranceGenerator",
14+
]
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .chat_template import SynthesizerChatTemplate
2+
from .utterance_generator import UtteranceGenerator
3+
4+
__all__ = ["SynthesizerChatTemplate", "UtteranceGenerator"]
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""Chat template for basic utterance synthesis from few-shot examples."""
2+
3+
import random
4+
from abc import ABC, abstractmethod
5+
from copy import deepcopy
6+
from typing import ClassVar
7+
8+
from autointent import Dataset
9+
from autointent.generation.utterances.schemas import Message, Role
10+
from autointent.schemas import Intent
11+
12+
13+
class BaseSynthesizer(ABC):
    """Abstract interface for chat templates used to synthesize utterances.

    Subclasses are callables that turn an intent and a requested number of
    examples into a list of chat messages.
    """

    @abstractmethod
    def __call__(self, intent_data: Intent, n_examples: int) -> list[Message]:
        """Generate examples for this intent.

        :param intent_data: Intent for which examples are requested.
        :param n_examples: Number of examples requested.
        :return: List of chat messages.
        """
19+
20+
21+
class SynthesizerChatTemplate(BaseSynthesizer):
    """Chat template for generating additional examples for a given intent class.

    The template consists of a fixed few-shot conversation (stored once on the
    class) followed by a final user message built from the target intent and
    utterances taken from ``dataset[split]``.
    """

    # Few-shot conversation shared by all instances. It is never mutated directly:
    # each instance works on its own deep copy (see ``__init__``).
    __messages: ClassVar[list[Message]] = [
        Message(
            role=Role.USER,
            content=(
                "You will be provided with a set of example utterances and the name "
                "of the common topic (intent name) of these utterances. "
                "Your task is to generate more examples that fit within the same intent name.\n\n"
                "Note:\n"
                "- You can generate similar utterances with only slot values changed\n"
                "- You can generate completely different utterance from the same intent name\n"
                "- Intent name can be missed, then you should infer from example utterances only\n"
                "- Example utterances can be missed, then you should infer from intent name only\n"
                "{extra_instructions}\n\n"
                "Intent name: ordering_pizza\n\n"
                "Example Utterances:\n"
                "1. I want to order a large pepperoni pizza.\n"
                "2. Can I get a medium cheese pizza with extra olives?\n"
                "3. Please deliver a small veggie pizza to my address.\n\n"
                "Please generate 3 more examples for the provided intent name."
            ),
        ),
        Message(
            role=Role.ASSISTANT,
            content=(
                "1. I'd like to order a large margherita pizza.\n"
                "2. Can you deliver a medium Hawaiian pizza with extra pineapple?\n"
                "3. Please send a small BBQ chicken pizza to my home."
            ),
        ),
        Message(
            role=Role.USER,
            content=(
                "Intent name: booking a hotel\n\n"
                "Example Utterances:\n"
                "1. I need to book a room for two nights in New York.\n\n"
                "Please generate 2 more examples for the provided intent name."
            ),
        ),
        Message(
            role=Role.ASSISTANT,
            content=(
                "1. Can you reserve a deluxe room for my trip to Tokyo?\n"
                "2. I need to book a hotel room with a mountain view in Denver."
            ),
        ),
        Message(
            role=Role.USER,
            content=(
                "Intent name:\n\n"
                "Example Utterances:\n"
                "1. What is the weather like today?\n\n"
                "Please generate 2 more examples for the provided intent class."
            ),
        ),
        Message(
            role=Role.ASSISTANT,
            content=(
                "1. Can you tell me the forecast for tomorrow?\n"
                "2. Is it going to rain this weekend?"
            ),
        ),
        Message(
            role=Role.USER,
            content=(
                "Intent name: Scheduling a Meeting\n\n"
                "Example Utterances:\n\n"
                "Please generate 3 more examples for the provided intent class."
            ),
        ),
        Message(
            role=Role.ASSISTANT,
            content=(
                "1. I need to schedule a meeting for next Tuesday.\n"
                "2. Can you set up a conference call for tomorrow afternoon?\n"
                "3. Please arrange a meeting with the marketing team next week."
            ),
        ),
    ]

    def __init__(
        self,
        dataset: Dataset,
        split: str,
        extra_instructions: str | None = None,
        max_sample_utterances: int | None = None,
    ) -> None:
        """Initialize.

        :param dataset: Dataset from which example utterances are taken.
        :param split: Name of the dataset split to take example utterances from.
        :param extra_instructions: Text substituted into the ``{extra_instructions}``
            placeholder of the first few-shot message; ``None`` is treated as an empty string.
        :param max_sample_utterances: Upper bound on the number of example utterances
            included in the prompt; ``None`` means all matching utterances are used.
        """
        if extra_instructions is None:
            extra_instructions = ""

        # Work on a private copy so that filling in the placeholder below
        # does not leak into the class-level template shared by other instances.
        self._messages = deepcopy(self.__messages)

        msg = self._messages[0]
        msg["content"] = msg["content"].format(extra_instructions=extra_instructions)

        self.dataset = dataset
        self.split = split
        self.max_sample_utterances = max_sample_utterances

    def __call__(self, intent_data: Intent, n_examples: int) -> list[Message]:
        """Generate additional examples for the provided intent class.

        :param intent_data: Intent whose ``id`` selects the example utterances and
            whose ``name`` is put into the prompt.
        :param n_examples: Number of new examples to request in the prompt.
        :return: The few-shot messages followed by the final user request.
        """
        filtered_split = self.dataset[self.split].filter(lambda sample: sample[Dataset.label_feature] == intent_data.id)
        sample_utterances = filtered_split[Dataset.utterance_feature]
        if self.max_sample_utterances is not None:
            # ``random.sample`` raises ValueError when ``k`` exceeds the population size,
            # so intents with fewer utterances than the limit simply use all of them.
            n_to_sample = min(self.max_sample_utterances, len(sample_utterances))
            sample_utterances = random.sample(sample_utterances, k=n_to_sample)
        # NOTE(review): utterances are interpolated as a Python list repr, whereas the
        # few-shot examples above use a numbered list -- consider aligning the formats.
        return [
            *self._messages,
            Message(
                role=Role.USER,
                content=f"Intent name: {intent_data.name}\n\n"
                f"Example Utterances:\n{sample_utterances}\n\n"
                f"Please generate {n_examples} more examples for the provided intent class.\n",
            ),
        ]

autointent/generation/utterances/basic/chat_template.yaml

Lines changed: 0 additions & 119 deletions
This file was deleted.

autointent/generation/utterances/basic/cli.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
"""CLI for basic utterance generator."""
22

3+
import logging
34
from argparse import ArgumentParser
45

56
from autointent import load_dataset
6-
from autointent.generation.utterances.basic.utterance_generator import LengthType, StyleType, UtteranceGenerator
7+
from autointent.generation.utterances.basic.utterance_generator import UtteranceGenerator
78
from autointent.generation.utterances.generator import Generator
89

10+
from .chat_template import SynthesizerChatTemplate
11+
12+
logging.basicConfig(level="INFO")
13+
logger = logging.getLogger(__name__)
14+
915

1016
def main() -> None:
1117
"""CLI endpoint."""
@@ -28,6 +34,7 @@ def main() -> None:
2834
default=None,
2935
help="Local path where to save result",
3036
)
37+
parser.add_argument("--split", type=str, default="train")
3138
parser.add_argument("--private", action="store_true", help="Publish privately if --output-repo option is used")
3239
parser.add_argument(
3340
"--n-generations",
@@ -41,37 +48,19 @@ def main() -> None:
4148
default=5,
4249
help="Number of utterances to use as an example for augmentation",
4350
)
44-
parser.add_argument(
45-
"--custom-instruction",
46-
type=str,
47-
action="append",
48-
help="Add extra instructions to default prompt."
49-
"You can use this argument multiple times to add multiple instructions",
50-
)
51-
parser.add_argument(
52-
"--length",
53-
choices=LengthType.__args__, # type: ignore[attr-defined]
54-
default="none",
55-
help="How to extend the prompt with length instruction",
56-
)
57-
parser.add_argument(
58-
"--style",
59-
choices=StyleType.__args__, # type: ignore[attr-defined]
60-
default="none",
61-
help="How to extend the prompt with style instruction",
62-
)
63-
parser.add_argument(
64-
"--same-punctuation",
65-
action="store_true",
66-
help="Whether to extend the prompt with punctuation instruction",
67-
)
6851
args = parser.parse_args()
6952

7053
dataset = load_dataset(args.input_path)
71-
generator = UtteranceGenerator(
72-
Generator(), args.custom_instruction or [], args.length, args.style, args.same_punctuation
73-
)
74-
generator.augment(dataset, n_generations=args.n_generations, max_sample_utterances=args.n_sample_utterances)
54+
template = SynthesizerChatTemplate(dataset, args.split, max_sample_utterances=args.n_sample_utterances)
55+
generator = UtteranceGenerator(Generator(), template)
56+
57+
n_before = len(dataset[args.split])
58+
new_samples = generator.augment(dataset, split_name=args.split, n_generations=args.n_generations)
59+
n_after = len(dataset[args.split])
60+
61+
logger.info("# samples before %s", n_before)
62+
logger.info("# samples generated %s", len(new_samples))
63+
logger.info("# samples after %s", n_after)
7564

7665
dataset.to_json(args.output_path)
7766

autointent/generation/utterances/basic/extra_instructions.json

Lines changed: 0 additions & 14 deletions
This file was deleted.

0 commit comments

Comments
 (0)