|
| 1 | +"""Base class for chat template for class-wise augmentation.""" |
| 2 | + |
| 3 | +import random |
| 4 | +from abc import ABC, abstractmethod |
| 5 | +from copy import deepcopy |
| 6 | +from typing import ClassVar |
| 7 | + |
| 8 | +from autointent import Dataset |
| 9 | +from autointent.generation.utterances.schemas import Message, Role |
| 10 | +from autointent.schemas import Intent |
| 11 | + |
| 12 | + |
| 13 | +class BaseChatTemplate(ABC): |
| 14 | + """Base class.""" |
| 15 | + |
| 16 | + @abstractmethod |
| 17 | + def __call__(self, intent_data: Intent, n_examples: int) -> list[Message]: |
| 18 | + """Generate examples for this intent.""" |
| 19 | + |
| 20 | + |
| 21 | +class BaseSynthesizerTemplate(BaseChatTemplate): |
| 22 | + """Base chat template for generating additional examples for a given intent.""" |
| 23 | + |
| 24 | + _MESSAGES_TEMPLATE: ClassVar[list[Message]] |
| 25 | + _INTENT_NAME_LABEL: ClassVar[str] |
| 26 | + _EXAMPLE_UTTERANCES_LABEL: ClassVar[str] |
| 27 | + _GENERATE_INSTRUCTION: ClassVar[str] |
| 28 | + |
| 29 | + def __init__( |
| 30 | + self, |
| 31 | + dataset: Dataset, |
| 32 | + split: str, |
| 33 | + extra_instructions: str | None = None, |
| 34 | + max_sample_utterances: int | None = None, |
| 35 | + ) -> None: |
| 36 | + """Initialize the chat template with dataset, split, and optional instructions.""" |
| 37 | + if extra_instructions is None: |
| 38 | + extra_instructions = "" |
| 39 | + |
| 40 | + self._messages = deepcopy(self._MESSAGES_TEMPLATE) |
| 41 | + |
| 42 | + if self._messages: |
| 43 | + self._messages[0]["content"] = self._messages[0]["content"].format(extra_instructions=extra_instructions) |
| 44 | + |
| 45 | + self.dataset = dataset |
| 46 | + self.split = split |
| 47 | + self.max_sample_utterances = max_sample_utterances |
| 48 | + |
| 49 | + def __call__(self, intent_data: Intent, n_examples: int) -> list[Message]: |
| 50 | + """Generate a list of messages to request additional examples for the given intent.""" |
| 51 | + in_domain_samples = self.dataset[self.split].filter(lambda sample: sample[Dataset.label_feature] is not None) |
| 52 | + if self.dataset.multilabel: |
| 53 | + filter_fn = lambda sample: sample[Dataset.label_feature][intent_data.id] == 1 # noqa: E731 |
| 54 | + else: |
| 55 | + filter_fn = lambda sample: sample[Dataset.label_feature] == intent_data.id # noqa: E731 |
| 56 | + |
| 57 | + filtered_split = in_domain_samples.filter(filter_fn) |
| 58 | + sample_utterances = filtered_split[Dataset.utterance_feature] |
| 59 | + |
| 60 | + if self.max_sample_utterances is not None and len(sample_utterances) > self.max_sample_utterances: |
| 61 | + sample_utterances = random.sample(sample_utterances, k=self.max_sample_utterances) |
| 62 | + |
| 63 | + return [ |
| 64 | + *self._messages, |
| 65 | + self._create_final_message(intent_data, n_examples, sample_utterances), |
| 66 | + ] |
| 67 | + |
| 68 | + def _create_final_message(self, intent_data: Intent, n_examples: int, sample_utterances: list[str]) -> Message: |
| 69 | + content = f"{self._INTENT_NAME_LABEL}: {intent_data.name}\n\n" f"{self._EXAMPLE_UTTERANCES_LABEL}:\n" |
| 70 | + |
| 71 | + if sample_utterances: |
| 72 | + numbered_utterances = "\n".join(f"{i+1}. {utt}" for i, utt in enumerate(sample_utterances)) |
| 73 | + content += numbered_utterances + "\n\n" |
| 74 | + |
| 75 | + content += self._GENERATE_INSTRUCTION.format(n_examples=n_examples) |
| 76 | + return Message(role=Role.USER, content=content) |
0 commit comments