Skip to content

Commit 6b07e3b

Browse files
authored
Feat/augment multilabel datasets (#139)
* refactor chat templates code
* fix typing
* fix typing
* add multilabel support
* add oos support
* big fix
1 parent 0694ebd commit 6b07e3b

File tree

9 files changed

+272
-240
lines changed

9 files changed

+272
-240
lines changed

autointent/generation/utterances/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .basic import SynthesizerChatTemplate, UtteranceGenerator
1+
from .basic import EnglishSynthesizerTemplate, RussianSynthesizerTemplate, UtteranceGenerator
22
from .evolution import (
33
AbstractEvolution,
44
ConcreteEvolution,
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .chat_template import SynthesizerChatTemplate
1+
from .chat_templates import EnglishSynthesizerTemplate, RussianSynthesizerTemplate
22
from .utterance_generator import UtteranceGenerator
33

4-
__all__ = ["SynthesizerChatTemplate", "UtteranceGenerator"]
4+
__all__ = ["EnglishSynthesizerTemplate", "RussianSynthesizerTemplate", "UtteranceGenerator"]

autointent/generation/utterances/basic/chat_template.py

Lines changed: 0 additions & 229 deletions
This file was deleted.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from ._base import BaseChatTemplate, BaseSynthesizerTemplate
2+
from ._synthesizer_en import EnglishSynthesizerTemplate
3+
from ._synthesizer_ru import RussianSynthesizerTemplate
4+
5+
__all__ = ["BaseChatTemplate", "BaseSynthesizerTemplate", "EnglishSynthesizerTemplate", "RussianSynthesizerTemplate"]
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Base class for chat template for class-wise augmentation."""
2+
3+
import random
4+
from abc import ABC, abstractmethod
5+
from copy import deepcopy
6+
from typing import ClassVar
7+
8+
from autointent import Dataset
9+
from autointent.generation.utterances.schemas import Message, Role
10+
from autointent.schemas import Intent
11+
12+
13+
class BaseChatTemplate(ABC):
    """Abstract interface for callables that turn an intent into a list of chat messages."""

    @abstractmethod
    def __call__(self, intent_data: Intent, n_examples: int) -> list[Message]:
        """Build the chat messages for the given intent.

        Args:
            intent_data: Intent the messages are built for.
            n_examples: Number of examples associated with the request.

        Returns:
            The messages making up the chat prompt.
        """
19+
20+
21+
class BaseSynthesizerTemplate(BaseChatTemplate):
    """Base chat template for generating additional examples for a given intent.

    Subclasses provide the textual pieces of the prompt via class attributes:

    - ``_MESSAGES_TEMPLATE``: leading messages of the conversation. When it is non-empty,
      the content of its first message is formatted with the ``extra_instructions`` keyword.
    - ``_INTENT_NAME_LABEL``: label placed before the intent name in the final message.
    - ``_EXAMPLE_UTTERANCES_LABEL``: label placed before the numbered sample utterances.
    - ``_GENERATE_INSTRUCTION``: closing instruction, formatted with the ``n_examples`` keyword.
    """

    _MESSAGES_TEMPLATE: ClassVar[list[Message]]
    _INTENT_NAME_LABEL: ClassVar[str]
    _EXAMPLE_UTTERANCES_LABEL: ClassVar[str]
    _GENERATE_INSTRUCTION: ClassVar[str]

    def __init__(
        self,
        dataset: Dataset,
        split: str,
        extra_instructions: str | None = None,
        max_sample_utterances: int | None = None,
    ) -> None:
        """Initialize the chat template with dataset, split, and optional instructions.

        Args:
            dataset: Dataset the sample utterances are taken from.
            split: Name of the dataset split to read utterances from.
            extra_instructions: Text substituted into the first template message;
                ``None`` is treated as an empty string.
            max_sample_utterances: Upper bound on the number of sample utterances shown
                in the prompt; ``None`` disables the limit.

        Raises:
            ValueError: If ``max_sample_utterances`` is negative.
        """
        # Reject a negative limit here: otherwise every later call would fail inside
        # ``random.sample`` with an error that does not mention this parameter.
        if max_sample_utterances is not None and max_sample_utterances < 0:
            msg = f"max_sample_utterances must be non-negative, got {max_sample_utterances}"
            raise ValueError(msg)

        # Work on a private copy so that formatting never touches the class-level template.
        self._messages = deepcopy(self._MESSAGES_TEMPLATE)
        if self._messages:
            self._messages[0]["content"] = self._messages[0]["content"].format(
                extra_instructions="" if extra_instructions is None else extra_instructions,
            )

        self.dataset = dataset
        self.split = split
        self.max_sample_utterances = max_sample_utterances

    def __call__(self, intent_data: Intent, n_examples: int) -> list[Message]:
        """Generate a list of messages to request additional examples for the given intent.

        Args:
            intent_data: Intent to build the prompt for.
            n_examples: Number of examples to ask for in the final instruction.

        Returns:
            The prepared template messages followed by one user message that names the
            intent, lists its sample utterances and states the generation instruction.
        """
        intent_id = intent_data.id
        is_multilabel = self.dataset.multilabel

        def _belongs_to_intent(sample: dict) -> bool:
            """Tell whether a labelled sample belongs to the requested intent."""
            label = sample[Dataset.label_feature]
            if label is None:
                # Samples without a label can never match an intent.
                return False
            if is_multilabel:
                return bool(label[intent_id] == 1)
            return bool(label == intent_id)

        # A single pass replaces the former "drop unlabelled, then match intent" double filter.
        matching_split = self.dataset[self.split].filter(_belongs_to_intent)
        # Materialise the column so that ``len`` and ``random.sample`` operate on a plain list.
        sample_utterances = list(matching_split[Dataset.utterance_feature])

        limit = self.max_sample_utterances
        if limit is not None and len(sample_utterances) > limit:
            sample_utterances = random.sample(sample_utterances, k=limit)

        return [
            *self._messages,
            self._create_final_message(intent_data, n_examples, sample_utterances),
        ]

    def _create_final_message(self, intent_data: Intent, n_examples: int, sample_utterances: list[str]) -> Message:
        """Compose the closing user message.

        Args:
            intent_data: Intent whose name is shown in the message.
            n_examples: Value substituted into the generation instruction.
            sample_utterances: Utterances to list, numbered from 1; may be empty.

        Returns:
            A user-role message with the assembled content.
        """
        parts = [f"{self._INTENT_NAME_LABEL}: {intent_data.name}\n\n{self._EXAMPLE_UTTERANCES_LABEL}:\n"]

        if sample_utterances:
            numbered_utterances = "\n".join(f"{i}. {utt}" for i, utt in enumerate(sample_utterances, start=1))
            parts.append(numbered_utterances + "\n\n")

        parts.append(self._GENERATE_INSTRUCTION.format(n_examples=n_examples))
        return Message(role=Role.USER, content="".join(parts))

0 commit comments

Comments
 (0)