Skip to content

Commit 7338d7e

Browse files
authored
Refactor/move to our dataset class (#100)
* refactor basic utterance generator * make `load_dataset` utility public * polish `load_dataset` utility * move basic utterance generator to `Dataset` * refactor cli for basic utterance generator * refactor evolutions module * some bug fix in basic utterance generation * some bug fix in evolutionary augmentations * refactor `Generator` and fix codestyle * fix typing
1 parent ae87666 commit 7338d7e

File tree

15 files changed

+277
-197
lines changed

15 files changed

+277
-197
lines changed

autointent/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ._embedder import Embedder
44
from ._dataset import Dataset
55
from ._hash import Hasher
6-
from .context import Context
6+
from .context import Context, load_dataset
77
from ._pipeline import Pipeline
88

9-
__all__ = ["Context", "Dataset", "Embedder", "Hasher", "Pipeline"]
9+
__all__ = ["Context", "Dataset", "Embedder", "Hasher", "Pipeline", "load_dataset"]

autointent/_dataset/_dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,10 @@ def to_json(self, filepath: str | Path) -> None:
144144
145145
:param filepath: The path to the file where the JSON data will be saved.
146146
"""
147-
with Path(filepath).open("w") as file:
147+
path = Path(filepath)
148+
if not path.parent.exists():
149+
path.parent.mkdir(parents=True)
150+
with path.open("w") as file:
148151
json.dump(self.to_dict(), file, indent=4, ensure_ascii=False)
149152

150153
def push_to_hub(self, repo_id: str, private: bool = False) -> None:

autointent/context/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Core utilities for auto ML features."""
22

33
from ._context import Context
4+
from ._utils import load_dataset
45

5-
__all__ = ["Context"]
6+
__all__ = ["Context", "load_dataset"]

autointent/context/_context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
VectorIndexConfig,
1717
)
1818

19-
from ._utils import NumpyEncoder, load_data
19+
from ._utils import NumpyEncoder, load_dataset
2020
from .data_handler import DataHandler
2121
from .optimization_info import OptimizationInfo
2222
from .vector_index_client import VectorIndexClient
@@ -81,7 +81,7 @@ def configure_data(self, config: DataConfig) -> None:
8181
:param config: Configuration for the data handling process.
8282
"""
8383
self.data_handler = DataHandler(
84-
dataset=load_data(config.train_path),
84+
dataset=load_dataset(config.train_path),
8585
random_seed=self.seed,
8686
force_multilabel=config.force_multilabel,
8787
)

autointent/context/_utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ def default(self, obj: Any) -> str | int | float | list[Any] | Any: # noqa: ANN
4040
return super().default(obj)
4141

4242

43-
def load_data(filepath: str | Path) -> Dataset:
43+
def load_dataset(path: str | Path) -> Dataset:
4444
"""
45-
Load data from a specified path or use default sample data.
45+
Load data from a specified path or use default sample data or load from hugging face hub.
4646
4747
This function loads a dataset from a JSON file or retrieves sample data
4848
included with the `autointent` package for default multiclass or multilabel
@@ -53,10 +53,10 @@ def load_data(filepath: str | Path) -> Dataset:
5353
- "default-multilabel": Loads sample multilabel dataset.
5454
:return: A `Dataset` object containing the loaded data.
5555
"""
56-
if filepath == "default-multiclass":
56+
if path == "default-multiclass":
5757
return Dataset.from_hub("AutoIntent/clinc150_subset")
58-
if filepath == "default-multilabel":
58+
if path == "default-multilabel":
5959
return Dataset.from_hub("AutoIntent/clinc150_subset").to_multilabel()
60-
if not Path(filepath).exists():
61-
return Dataset.from_hub(str(filepath))
62-
return Dataset.from_json(filepath)
60+
if not Path(path).exists():
61+
return Dataset.from_hub(str(path))
62+
return Dataset.from_json(path)

autointent/generation/intents/__init__.py

Whitespace-only changes.

autointent/generation/utterances/basic/generate.py renamed to autointent/generation/utterances/basic/cli.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,46 @@
1-
import json
2-
import os
1+
"""CLI for basic utterance generator."""
2+
33
from argparse import ArgumentParser
4-
from typing import Any
54

5+
from autointent import load_dataset
66
from autointent.generation.utterances.basic.utterance_generator import LengthType, StyleType, UtteranceGenerator
77
from autointent.generation.utterances.generator import Generator
88

99

10-
def read_json_dataset(file_path: os.PathLike):
11-
with open(file_path) as file:
12-
return json.load(file)
13-
14-
15-
def save_json_dataset(file_path: os.PathLike, intents: list[dict[str, Any]]):
16-
dirname = os.path.dirname(file_path)
17-
if not os.path.exists(dirname):
18-
os.makedirs(dirname)
19-
with open(file_path, "w") as file:
20-
json.dump(intents, file, indent=4, ensure_ascii=False)
21-
22-
23-
def main():
10+
def main() -> None:
11+
"""ClI endpoint."""
2412
parser = ArgumentParser()
2513
parser.add_argument(
2614
"--input-path",
2715
type=str,
2816
required=True,
29-
help="Path to json with intent records",
17+
help="Path to json or hugging face repo with dataset",
3018
)
3119
parser.add_argument(
3220
"--output-path",
3321
type=str,
3422
required=True,
35-
help="Where to save result",
23+
help="Local path where to save result",
3624
)
3725
parser.add_argument(
38-
"--n-shots",
26+
"--output-repo",
27+
type=str,
28+
default=None,
29+
help="Local path where to save result",
30+
)
31+
parser.add_argument("--private", action="store_true", help="Publish privately if --output-repo option is used")
32+
parser.add_argument(
33+
"--n-generations",
3934
type=int,
40-
required=True,
35+
default=5,
4136
help="Number of utterances to generate for each intent",
4237
)
38+
parser.add_argument(
39+
"--n-sample-utterances",
40+
type=int,
41+
default=5,
42+
help="Number of utterances to use as an example for augmentation",
43+
)
4344
parser.add_argument(
4445
"--custom-instruction",
4546
type=str,
@@ -49,13 +50,13 @@ def main():
4950
)
5051
parser.add_argument(
5152
"--length",
52-
choices=LengthType.__args__,
53+
choices=LengthType.__args__, # type: ignore[attr-defined]
5354
default="none",
5455
help="How to extend the prompt with length instruction",
5556
)
5657
parser.add_argument(
5758
"--style",
58-
choices=StyleType.__args__,
59+
choices=StyleType.__args__, # type: ignore[attr-defined]
5960
default="none",
6061
help="How to extend the prompt with style instruction",
6162
)
@@ -66,13 +67,16 @@ def main():
6667
)
6768
args = parser.parse_args()
6869

69-
intents = read_json_dataset(args.input_path)
70+
dataset = load_dataset(args.input_path)
71+
generator = UtteranceGenerator(
72+
Generator(), args.custom_instruction or [], args.length, args.style, args.same_punctuation
73+
)
74+
generator.augment(dataset, n_generations=args.n_generations, max_sample_utterances=args.n_sample_utterances)
7075

71-
generator = UtteranceGenerator(Generator(), args.custom_instruction, args.length, args.style, args.same_punctuation)
72-
for intent_record in intents:
73-
generator(intent_record, args.n_shots, inplace=True)
76+
dataset.to_json(args.output_path)
7477

75-
save_json_dataset(args.output_path, intents)
78+
if args.output_repo is not None:
79+
dataset.push_to_hub(args.output_repo, private=args.private)
7680

7781

7882
if __name__ == "__main__":

autointent/generation/utterances/basic/utterance_generator.py

Lines changed: 76 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,116 @@
1+
"""Basic generation of new utterances from existing ones."""
2+
13
import importlib.resources as ires
24
import json
5+
import random
36
from typing import Any, Literal
47

58
import yaml
9+
from datasets import Dataset as HFDataset
10+
from datasets import concatenate_datasets
611

12+
from autointent import Dataset
13+
from autointent.custom_types import Split
714
from autointent.generation.utterances.generator import Generator
8-
from autointent.generation.utterances.utils import safe_format
15+
from autointent.generation.utterances.utils import safe_format # type: ignore[attr-defined]
16+
from autointent.schemas import Sample
917

1018
LengthType = Literal["none", "same", "longer", "shorter"]
1119
StyleType = Literal["none", "formal", "informal", "playful"]
1220

1321

1422
class UtteranceGenerator:
23+
"""
24+
Basic generation of new utterances from existing ones.
25+
26+
This augmentation method simply prompts LLM to look at existing examples
27+
and generate similar. Additionaly it can consider some aspects of style,
28+
punctuation and length of the desired generations.
29+
"""
30+
1531
def __init__(
1632
self,
1733
generator: Generator,
1834
custom_instruction: list[str],
1935
length: LengthType,
2036
style: StyleType,
2137
same_punctuation: bool,
22-
):
38+
) -> None:
39+
"""Initialize."""
2340
self.generator = generator
24-
prompt_template_yaml = load_prompt()
25-
self.prompt_template_yaml = add_extra_instructions(
41+
prompt_template_yaml = _load_prompt()
42+
self.prompt_template_yaml = _add_extra_instructions(
2643
prompt_template_yaml,
2744
custom_instruction,
2845
length,
2946
style,
3047
same_punctuation,
3148
)
3249

33-
def _generate(self, intent_name: str, example_utterances: list[str], n_examples: int) -> list[str]:
50+
def __call__(self, intent_name: str, example_utterances: list[str], n_generations: int) -> list[str]:
51+
"""Generate new utterances."""
3452
messages_yaml = safe_format(
3553
self.prompt_template_yaml,
3654
intent_name=intent_name,
37-
example_utterances=format_utterances(example_utterances),
38-
n_examples=n_examples,
55+
example_utterances=_format_utterances(example_utterances),
56+
n_examples=n_generations,
3957
)
4058
messages = yaml.safe_load(messages_yaml)
4159
response_text = self.generator.get_chat_completion(messages)
42-
return extract_utterances(response_text)
60+
return _extract_utterances(response_text)
4361

44-
def __call__(self, intent_record: dict[str, Any], n_examples: int, inplace: bool = True) -> list[str]:
45-
intent_name = intent_record.get("intent_name", "")
46-
example_utterances = intent_record.get("sample_utterances", [])
47-
res_utterances = self._generate(intent_name, example_utterances, n_examples)
48-
if inplace:
49-
intent_record["sample_utterances"] = intent_record.get("sample_utterances", []) + res_utterances
50-
return res_utterances
51-
52-
53-
def load_prompt():
54-
with ires.files("autointent.generation.basic").joinpath("chat_template.yaml").open() as file:
62+
def augment(
63+
self,
64+
dataset: Dataset,
65+
split_name: str = Split.TRAIN,
66+
n_generations: int = 5,
67+
max_sample_utterances: int = 5,
68+
update_split: bool = True,
69+
) -> list[Sample]:
70+
"""
71+
Augment some split of dataset.
72+
73+
Note that for now it supports only single-label datasets.
74+
"""
75+
original_split = dataset[split_name]
76+
new_samples = []
77+
for intent in dataset.intents:
78+
filtered_split = original_split.filter(lambda sample, id=intent.id: sample[Dataset.label_feature] == id)
79+
sample_utterances = filtered_split[Dataset.utterance_feature]
80+
if max_sample_utterances is not None:
81+
sample_utterances = random.sample(sample_utterances, k=max_sample_utterances)
82+
generated_utterances = self(
83+
intent_name=intent.name or "",
84+
example_utterances=sample_utterances,
85+
n_generations=n_generations,
86+
)
87+
new_samples.extend(
88+
[{Dataset.label_feature: intent.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
89+
)
90+
if update_split:
91+
generated_split = HFDataset.from_list(new_samples)
92+
dataset[split_name] = concatenate_datasets([original_split, generated_split])
93+
return [Sample(**sample) for sample in new_samples]
94+
95+
96+
def _load_prompt() -> str:
97+
with ires.files("autointent.generation.utterances.basic").joinpath("chat_template.yaml").open() as file:
5598
return file.read()
5699

57100

58-
def load_extra_instructions():
59-
with ires.files("autointent.generation.basic").joinpath("extra_instructions.json").open() as file:
60-
return json.load(file)
101+
def _load_extra_instructions() -> dict[str, Any]:
102+
with ires.files("autointent.generation.utterances.basic").joinpath("extra_instructions.json").open() as file:
103+
return json.load(file) # type: ignore[no-any-return]
61104

62105

63-
def add_extra_instructions(
106+
def _add_extra_instructions(
64107
prompt_template_yaml: str,
65108
custom_instruction: list[str],
66109
length: LengthType,
67110
style: StyleType,
68111
same_punctuation: bool,
69112
) -> str:
70-
instructions = load_extra_instructions()
113+
instructions = _load_extra_instructions()
71114

72115
extra_instructions = []
73116
if length != "none":
@@ -80,40 +123,29 @@ def add_extra_instructions(
80123
extra_instructions.extend(custom_instruction)
81124

82125
parsed_extra_instructions = "\n ".join([f"- {s}" for s in extra_instructions])
83-
return safe_format(prompt_template_yaml, extra_instructions=parsed_extra_instructions)
126+
return safe_format(prompt_template_yaml, extra_instructions=parsed_extra_instructions) # type: ignore[no-any-return]
84127

85128

86-
def format_utterances(utterances: list[str]) -> str:
129+
def _format_utterances(utterances: list[str]) -> str:
87130
"""
88-
Return
89-
---
90-
str of the following format:
131+
Convert given utterances into string that is ready to insert into prompt.
91132
92-
```
133+
Given list of utterances, the output string is returned in the following format:
134+
.. code-block::
93135
1. I want to order a large pepperoni pizza.
94136
2. Can I get a medium cheese pizza with extra olives?
95137
3. Please deliver a small veggie pizza to my address.
96-
```
97138
98-
Note
99-
---
100-
tab is inserted before each line because of how yaml processes multi-line fields
139+
Note that tab is inserted before each line because of how yaml processes multi-line fields.
101140
"""
102141
return "\n ".join(f"{i}. {ut}" for i, ut in enumerate(utterances))
103142

104143

105-
def extract_utterances(response_text: str) -> list[str]:
144+
def _extract_utterances(response_text: str) -> list[str]:
106145
"""
107-
Input
108-
---
109-
str of the following format:
110-
111-
```
112-
1. I want to order a large pepperoni pizza.
113-
2. Can I get a medium cheese pizza with extra olives?
114-
3. Please deliver a small veggie pizza to my address.
115-
```
146+
Parse LLM output.
116147
148+
Inverse function to :py:func:`_format_utterances`.
117149
"""
118150
raw_utterances = response_text.split("\n")
119151
# remove enumeration
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
I want you to act as a rewriter.
22
You will be provided with an utterance and the topic (name of intent class) of the utterance.
3-
You MUST complicate the utterance using the following method:
3+
You need to complicate the utterance using the following method:

0 commit comments

Comments
 (0)