Skip to content

Commit d2ac6e1

Browse files
authored
Refactor augmentations (#147)
* update chat template structure * update chat template structure * refactor chat templates * lint * fix imports
1 parent 6b07e3b commit d2ac6e1

File tree

19 files changed

+51
-117
lines changed

19 files changed

+51
-117
lines changed

autointent/_dataset/_validation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _validate_classes(self, splits: list[list[Sample]]) -> int:
101101
)
102102
raise ValueError(message)
103103
if not n_classes[0]:
104-
message = "Number of classes is zero or undefined. " "Ensure at least one class is present in the splits."
104+
message = "Number of classes is zero or undefined. Ensure at least one class is present in the splits."
105105
raise ValueError(message)
106106
return n_classes[0]
107107

@@ -120,8 +120,7 @@ def _validate_intents(self, n_classes: int) -> "DatasetReader":
120120
intent_ids = [intent.id for intent in self.intents]
121121
if intent_ids != list(range(len(self.intents))):
122122
message = (
123-
f"Invalid intent IDs. Expected sequential IDs from 0 to {len(self.intents) - 1}, "
124-
f"but got {intent_ids}."
123+
f"Invalid intent IDs. Expected sequential IDs from 0 to {len(self.intents) - 1}, but got {intent_ids}."
125124
)
126125
raise ValueError(message)
127126
return self

autointent/context/data_handler/_data_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def _split_cv(self) -> None:
236236
random_seed=self.random_seed,
237237
allow_oos_in_train=True,
238238
)
239-
self.dataset[f"{Split.TRAIN}_{self.config.n_folds-1}"] = self.dataset.pop(Split.TRAIN)
239+
self.dataset[f"{Split.TRAIN}_{self.config.n_folds - 1}"] = self.dataset.pop(Split.TRAIN)
240240

241241
def _split_validation_from_train(self, size: float) -> None:
242242
if Split.TRAIN in self.dataset:

autointent/generation/utterances/basic/chat_templates/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ def __call__(self, intent_data: Intent, n_examples: int) -> list[Message]:
6666
]
6767

6868
def _create_final_message(self, intent_data: Intent, n_examples: int, sample_utterances: list[str]) -> Message:
69-
content = f"{self._INTENT_NAME_LABEL}: {intent_data.name}\n\n" f"{self._EXAMPLE_UTTERANCES_LABEL}:\n"
69+
content = f"{self._INTENT_NAME_LABEL}: {intent_data.name}\n\n{self._EXAMPLE_UTTERANCES_LABEL}:\n"
7070

7171
if sample_utterances:
72-
numbered_utterances = "\n".join(f"{i+1}. {utt}" for i, utt in enumerate(sample_utterances))
72+
numbered_utterances = "\n".join(f"{i + 1}. {utt}" for i, utt in enumerate(sample_utterances))
7373
content += numbered_utterances + "\n\n"
7474

7575
content += self._GENERATE_INSTRUCTION.format(n_examples=n_examples)

autointent/generation/utterances/basic/chat_templates/_synthesizer_en.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class EnglishSynthesizerTemplate(BaseSynthesizerTemplate):
7070
),
7171
Message(
7272
role=Role.ASSISTANT,
73-
content=("1. Can you tell me the forecast for tomorrow?\n" "2. Is it going to rain this weekend?"),
73+
content="1. Can you tell me the forecast for tomorrow?\n2. Is it going to rain this weekend?",
7474
),
7575
Message(
7676
role=Role.USER,

autointent/generation/utterances/basic/chat_templates/_synthesizer_ru.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class RussianSynthesizerTemplate(BaseSynthesizerTemplate):
5353
),
5454
Message(
5555
role=Role.ASSISTANT,
56-
content=("1. Забронируйте люкс в Санкт-Петербурге на выходные\n" "2. Ищу номер с видом на море в Сочи"),
56+
content=("1. Забронируйте люкс в Санкт-Петербурге на выходные\n2. Ищу номер с видом на море в Сочи"),
5757
),
5858
Message(
5959
role=Role.USER,
@@ -66,7 +66,7 @@ class RussianSynthesizerTemplate(BaseSynthesizerTemplate):
6666
),
6767
Message(
6868
role=Role.ASSISTANT,
69-
content=("1. Какой прогноз на завтра?\n" "2. Будет ли дождь в субботу?"),
69+
content=("1. Какой прогноз на завтра?\n2. Будет ли дождь в субботу?"),
7070
),
7171
Message(
7272
role=Role.USER,

autointent/generation/utterances/basic/utterance_generator.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
"""Basic generation of new utterances from existing ones."""
22

33
import asyncio
4-
from collections.abc import Callable
54

65
from datasets import Dataset as HFDataset
76
from datasets import concatenate_datasets
87

98
from autointent import Dataset
109
from autointent.custom_types import Split
10+
from autointent.generation.utterances.basic.chat_templates import BaseSynthesizerTemplate
1111
from autointent.generation.utterances.generator import Generator
12-
from autointent.generation.utterances.schemas import Message
1312
from autointent.schemas import Intent, Sample
1413

1514

@@ -22,9 +21,7 @@ class UtteranceGenerator:
2221
punctuation, and length of the desired generations.
2322
"""
2423

25-
def __init__(
26-
self, generator: Generator, prompt_maker: Callable[[Intent, int], list[Message]], async_mode: bool = False
27-
) -> None:
24+
def __init__(self, generator: Generator, prompt_maker: BaseSynthesizerTemplate, async_mode: bool = False) -> None:
2825
"""Initialize."""
2926
self.generator = generator
3027
self.prompt_maker = prompt_maker

autointent/generation/utterances/evolution/chat_templates/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77
from .informal import InformalEvolution
88
from .reasoning import ReasoningEvolution
99

10+
EVOLUTION_NAMES = [evolution.name for evolution in EvolutionChatTemplate.__subclasses__()]
11+
12+
EVOLUTION_MAPPING = {evolution.name: evolution() for evolution in EvolutionChatTemplate.__subclasses__()}
13+
1014
__all__ = [
15+
"EVOLUTION_MAPPING",
16+
"EVOLUTION_NAMES",
1117
"AbstractEvolution",
1218
"ConcreteEvolution",
1319
"EvolutionChatTemplate",

autointent/generation/utterances/evolution/chat_templates/abstract.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
from typing import ClassVar
44

55
from autointent.generation.utterances.schemas import Message, Role
6-
from autointent.schemas import Intent
76

87
from .base import EvolutionChatTemplate
98

109

1110
class AbstractEvolution(EvolutionChatTemplate):
1211
"""Chat template for evolution augmentation via abstraction."""
1312

13+
name = "abstract"
1414
_messages: ClassVar[list[Message]] = [
1515
Message(
1616
role=Role.USER,
@@ -36,10 +36,3 @@ class AbstractEvolution(EvolutionChatTemplate):
3636
),
3737
Message(role=Role.ASSISTANT, content="I'm having trouble with my laptop."),
3838
]
39-
40-
def __call__(self, utterance: str, intent_data: Intent) -> list[Message]:
41-
"""Make chat to complete."""
42-
return [
43-
*self._messages,
44-
Message(role=Role.USER, content=f"Intent name: {intent_data.name or ''}\nUtterance: {utterance}"),
45-
]
Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
"""Base class for chat templates for evolution augmentation."""
22

3-
from abc import ABC, abstractmethod
3+
from typing import ClassVar
44

5-
from autointent.generation.utterances.schemas import Message
5+
from autointent.generation.utterances.schemas import Message, Role
66
from autointent.schemas import Intent
77

88

9-
class EvolutionChatTemplate(ABC):
9+
class EvolutionChatTemplate:
1010
"""Base class for chat templates for evolution augmentation."""
1111

12-
@abstractmethod
12+
_messages: ClassVar[list[Message]]
13+
name: str
14+
1315
def __call__(self, utterance: str, intent_data: Intent) -> list[Message]:
1416
"""Make a chat to complete by LLM."""
17+
invoke_message = Message(
18+
role=Role.USER,
19+
content=f"Intent name: {intent_data.name or ''}\nUtterance: {utterance}",
20+
)
21+
return [*self._messages, invoke_message]

autointent/generation/utterances/evolution/chat_templates/concrete.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
from typing import ClassVar
44

55
from autointent.generation.utterances.schemas import Message, Role
6-
from autointent.schemas import Intent
76

87
from .base import EvolutionChatTemplate
98

109

1110
class ConcreteEvolution(EvolutionChatTemplate):
1211
"""Chat template for evolution augmentation via concretizing."""
1312

13+
name = "concrete"
14+
1415
_messages: ClassVar[list[Message]] = [
1516
Message(
1617
role=Role.USER,
@@ -29,14 +30,7 @@ class ConcreteEvolution(EvolutionChatTemplate):
2930
Message(role=Role.ASSISTANT, content="I want to reserve a table for 4 persons at 9 pm."),
3031
Message(
3132
role=Role.USER,
32-
content=("Intent name: requesting technical support\n" "Utterance: I'm having trouble with my laptop."),
33+
content="Intent name: requesting technical support\nUtterance: I'm having trouble with my laptop.",
3334
),
3435
Message(role=Role.ASSISTANT, content="My laptop is constantly rebooting and overheating."),
3536
]
36-
37-
def __call__(self, utterance: str, intent_data: Intent) -> list[Message]:
38-
"""Make chat to complete."""
39-
return [
40-
*self._messages,
41-
Message(role=Role.USER, content=f"Intent name: {intent_data.name or ''}\nUtterance: {utterance}"),
42-
]

0 commit comments

Comments
 (0)