1+ """Basic generation of new utterances from existing ones."""
2+
13import importlib .resources as ires
24import json
5+ import random
36from typing import Any , Literal
47
58import yaml
9+ from datasets import Dataset as HFDataset
10+ from datasets import concatenate_datasets
611
12+ from autointent import Dataset
13+ from autointent .custom_types import Split
714from autointent .generation .utterances .generator import Generator
8- from autointent .generation .utterances .utils import safe_format
15+ from autointent .generation .utterances .utils import safe_format # type: ignore[attr-defined]
16+ from autointent .schemas import Sample
917
1018LengthType = Literal ["none" , "same" , "longer" , "shorter" ]
1119StyleType = Literal ["none" , "formal" , "informal" , "playful" ]
1220
1321
class UtteranceGenerator:
    """
    Basic generation of new utterances from existing ones.

    This augmentation method simply prompts LLM to look at existing examples
    and generate similar. Additionally it can consider some aspects of style,
    punctuation and length of the desired generations.
    """

    def __init__(
        self,
        generator: Generator,
        custom_instruction: list[str],
        length: LengthType,
        style: StyleType,
        same_punctuation: bool,
    ) -> None:
        """
        Initialize.

        :param generator: LLM wrapper used to obtain chat completions.
        :param custom_instruction: Extra free-form instructions appended to the prompt.
        :param length: Desired length of generations relative to the examples.
        :param style: Desired style of generations.
        :param same_punctuation: Whether generations should mimic the examples' punctuation.
        """
        self.generator = generator
        prompt_template_yaml = _load_prompt()
        self.prompt_template_yaml = _add_extra_instructions(
            prompt_template_yaml,
            custom_instruction,
            length,
            style,
            same_punctuation,
        )

    def __call__(self, intent_name: str, example_utterances: list[str], n_generations: int) -> list[str]:
        """
        Generate new utterances for a single intent.

        :param intent_name: Name of the intent the utterances belong to.
        :param example_utterances: Existing utterances used as few-shot examples.
        :param n_generations: Number of new utterances to request from the LLM.
        :return: Utterances parsed from the LLM response.
        """
        messages_yaml = safe_format(
            self.prompt_template_yaml,
            intent_name=intent_name,
            example_utterances=_format_utterances(example_utterances),
            n_examples=n_generations,
        )
        messages = yaml.safe_load(messages_yaml)
        response_text = self.generator.get_chat_completion(messages)
        return _extract_utterances(response_text)

    def augment(
        self,
        dataset: Dataset,
        split_name: str = Split.TRAIN,
        n_generations: int = 5,
        max_sample_utterances: int | None = 5,
        update_split: bool = True,
    ) -> list[Sample]:
        """
        Augment some split of dataset.

        Note that for now it supports only single-label datasets.

        :param dataset: Dataset to augment.
        :param split_name: Split to draw examples from (and optionally extend in place).
        :param n_generations: Number of utterances to generate per intent.
        :param max_sample_utterances: Cap on how many example utterances are shown to the
            LLM per intent; ``None`` means use all available ones.
        :param update_split: If True, append the generated samples to the split.
        :return: The newly generated samples.
        """
        original_split = dataset[split_name]
        new_samples = []
        for intent in dataset.intents:
            # bind intent.id as a default arg so the lambda doesn't late-bind the loop var
            filtered_split = original_split.filter(
                lambda sample, intent_id=intent.id: sample[Dataset.label_feature] == intent_id
            )
            sample_utterances = filtered_split[Dataset.utterance_feature]
            if max_sample_utterances is not None:
                # cap k at the population size: random.sample raises ValueError when
                # an intent has fewer utterances than max_sample_utterances
                k = min(max_sample_utterances, len(sample_utterances))
                sample_utterances = random.sample(sample_utterances, k=k)
            generated_utterances = self(
                intent_name=intent.name or "",
                example_utterances=sample_utterances,
                n_generations=n_generations,
            )
            new_samples.extend(
                [{Dataset.label_feature: intent.id, Dataset.utterance_feature: ut} for ut in generated_utterances]
            )
        if update_split and new_samples:
            # skip concatenation when nothing was generated: an empty from_list()
            # dataset has no schema and would corrupt the split's features
            generated_split = HFDataset.from_list(new_samples)
            dataset[split_name] = concatenate_datasets([original_split, generated_split])
        return [Sample(**sample) for sample in new_samples]
94+
95+
def _load_prompt() -> str:
    """Read the chat prompt template (YAML) bundled with the package."""
    resource = ires.files("autointent.generation.utterances.basic").joinpath("chat_template.yaml")
    return resource.read_text()
5699
57100
def _load_extra_instructions() -> dict[str, Any]:
    """Load the mapping of extra prompt instructions (JSON) bundled with the package."""
    resource = ires.files("autointent.generation.utterances.basic").joinpath("extra_instructions.json")
    with resource.open() as file:
        return json.load(file)  # type: ignore[no-any-return]
61104
62105
63- def add_extra_instructions (
106+ def _add_extra_instructions (
64107 prompt_template_yaml : str ,
65108 custom_instruction : list [str ],
66109 length : LengthType ,
67110 style : StyleType ,
68111 same_punctuation : bool ,
69112) -> str :
70- instructions = load_extra_instructions ()
113+ instructions = _load_extra_instructions ()
71114
72115 extra_instructions = []
73116 if length != "none" :
@@ -80,40 +123,29 @@ def add_extra_instructions(
80123 extra_instructions .extend (custom_instruction )
81124
82125 parsed_extra_instructions = "\n " .join ([f"- { s } " for s in extra_instructions ])
83- return safe_format (prompt_template_yaml , extra_instructions = parsed_extra_instructions )
126+ return safe_format (prompt_template_yaml , extra_instructions = parsed_extra_instructions ) # type: ignore[no-any-return]
84127
85128
86- def format_utterances (utterances : list [str ]) -> str :
129+ def _format_utterances (utterances : list [str ]) -> str :
87130 """
88- Return
89- ---
90- str of the following format:
131+ Convert given utterances into string that is ready to insert into prompt.
91132
92- ```
133+ Given list of utterances, the output string is returned in the following format:
134+ .. code-block::
93135 1. I want to order a large pepperoni pizza.
94136 2. Can I get a medium cheese pizza with extra olives?
95137 3. Please deliver a small veggie pizza to my address.
96- ```
97138
98- Note
99- ---
100- tab is inserted before each line because of how yaml processes multi-line fields
139+ Note that tab is inserted before each line because of how yaml processes multi-line fields.
101140 """
102141 return "\n " .join (f"{ i } . { ut } " for i , ut in enumerate (utterances ))
103142
104143
105- def extract_utterances (response_text : str ) -> list [str ]:
144+ def _extract_utterances (response_text : str ) -> list [str ]:
106145 """
107- Input
108- ---
109- str of the following format:
110-
111- ```
112- 1. I want to order a large pepperoni pizza.
113- 2. Can I get a medium cheese pizza with extra olives?
114- 3. Please deliver a small veggie pizza to my address.
115- ```
146+ Parse LLM output.
116147
148+ Inverse function to :py:func:`_format_utterances`.
117149 """
118150 raw_utterances = response_text .split ("\n " )
119151 # remove enumeration
0 commit comments