Skip to content

Commit b1e27e1

Browse files
authored
update description generator (#237)
* update description generator * fix test * fix typing * fix tests
1 parent 66c1a80 commit b1e27e1

File tree

5 files changed

+91
-144
lines changed

5 files changed

+91
-144
lines changed

autointent/generation/_generator.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
from pathlib import Path
77
from textwrap import dedent
8-
from typing import Any, ClassVar, Literal, TypedDict, TypeVar
8+
from typing import Any, Literal, TypedDict, TypeVar
99

1010
import openai
1111
from dotenv import load_dotenv
@@ -57,27 +57,21 @@ class Generator:
5757

5858
_dump_data_filename = "init_params.json"
5959

60-
_default_generation_params: ClassVar[dict[str, Any]] = {
61-
"max_tokens": 150,
62-
"n": 1,
63-
"stop": None,
64-
"temperature": 0.7,
65-
}
66-
"""Default generation parameters for API requests."""
67-
6860
def __init__(
6961
self,
7062
base_url: str | None = None,
7163
model_name: str | None = None,
7264
use_cache: bool = True,
73-
**generation_params: Any, # noqa: ANN401
65+
client_params: dict[str, Any] | None = None,
66+
**generation_params: dict[str, Any],
7467
) -> None:
7568
"""Initialize the Generator with API configuration.
7669
7770
Args:
7871
base_url: OpenAI API compatible server URL.
7972
model_name: Name of the language model to use.
8073
use_cache: Whether to use caching for structured outputs.
74+
client_params: Additional parameters for client.
8175
**generation_params: Additional generation parameters to override defaults passed to OpenAI completions API.
8276
"""
8377
base_url = base_url or os.getenv("OPENAI_BASE_URL")
@@ -91,27 +85,23 @@ def __init__(
9185
self.base_url = base_url
9286
self.use_cache = use_cache
9387

94-
self.client = openai.OpenAI(base_url=base_url)
95-
self.async_client = openai.AsyncOpenAI(base_url=base_url)
88+
self.client = openai.OpenAI(base_url=base_url, **(client_params or {}))
89+
self.async_client = openai.AsyncOpenAI(base_url=base_url, **(client_params or {}))
90+
self.generation_params = generation_params
9691
self.cache = StructuredOutputCache(use_cache=use_cache)
9792

98-
self.generation_params = {
99-
**self._default_generation_params,
100-
**generation_params,
101-
} # https://stackoverflow.com/a/65539348
102-
10393
def get_chat_completion(self, messages: list[Message]) -> str:
10494
"""Prompt LLM and return its answer.
10595
10696
Args:
10797
messages: List of messages to send to the model.
10898
"""
10999
response = self.client.chat.completions.create(
110-
messages=messages, # type: ignore[arg-type]
100+
messages=messages, # type: ignore[call-overload]
111101
model=self.model_name,
112102
**self.generation_params,
113103
)
114-
return response.choices[0].message.content # type: ignore[return-value]
104+
return response.choices[0].message.content # type: ignore[no-any-return]
115105

116106
async def get_chat_completion_async(self, messages: list[Message]) -> str:
117107
"""Prompt LLM and return its answer asynchronously.
@@ -120,11 +110,15 @@ async def get_chat_completion_async(self, messages: list[Message]) -> str:
120110
messages: List of messages to send to the model.
121111
"""
122112
response = await self.async_client.chat.completions.create(
123-
messages=messages, # type: ignore[arg-type]
113+
messages=messages, # type: ignore[call-overload]
124114
model=self.model_name,
125115
**self.generation_params,
126116
)
127-
return response.choices[0].message.content # type: ignore[return-value]
117+
118+
if response is None or not response.choices:
119+
msg = "No response received from the model."
120+
raise RuntimeError(msg)
121+
return response.choices[0].message.content # type: ignore[no-any-return]
128122

129123
def _create_retry_messages(self, error_message: str, raw: str | None) -> list[Message]:
130124
"""Create a follow-up message for retry with error details and schema."""
@@ -168,7 +162,7 @@ async def _get_structured_output_openai_async(
168162
model=self.model_name,
169163
messages=messages, # type: ignore[arg-type]
170164
response_format=output_model,
171-
**self.generation_params,
165+
**self.generation_params, # type: ignore[arg-type]
172166
)
173167
raw = response.choices[0].message.content
174168
res = response.choices[0].message.parsed
@@ -194,12 +188,12 @@ async def _get_structured_output_vllm_async(
194188
json_schema = output_model.model_json_schema()
195189
response = await self.async_client.chat.completions.create(
196190
model=self.model_name,
197-
messages=messages, # type: ignore[arg-type]
191+
messages=messages, # type: ignore[call-overload]
198192
extra_body={"guided_json": json_schema},
199193
**self.generation_params,
200194
)
201195
raw = response.choices[0].message.content
202-
res = output_model.model_validate_json(raw) # type: ignore[arg-type]
196+
res = output_model.model_validate_json(raw)
203197
except (ValidationError, ValueError) as e:
204198
msg = f"Failed to obtain structured output for model {self.model_name} and messages {messages}: {e!s}"
205199
logger.warning(msg)
@@ -252,6 +246,10 @@ async def get_structured_output_async(
252246
current_messages.extend(self._create_retry_messages(error, raw))
253247

254248
if res is None:
249+
msg = (
250+
f"Failed to generate valid structured output after {max_retries + 1} attempts.\n"
251+
f"Messages: {current_messages}"
252+
)
255253
logger.exception(msg)
256254
raise RetriesExceededError(max_retries=max_retries, messages=current_messages)
257255

@@ -281,7 +279,7 @@ def _get_structured_output_openai_sync(
281279
model=self.model_name,
282280
messages=messages, # type: ignore[arg-type]
283281
response_format=output_model,
284-
**self.generation_params,
282+
**self.generation_params, # type: ignore[arg-type]
285283
)
286284
raw = response.choices[0].message.content
287285
res = response.choices[0].message.parsed
@@ -307,12 +305,12 @@ def _get_structured_output_vllm_sync(
307305
json_schema = output_model.model_json_schema()
308306
response = self.client.chat.completions.create(
309307
model=self.model_name,
310-
messages=messages, # type: ignore[arg-type]
308+
messages=messages, # type: ignore[call-overload]
311309
extra_body={"guided_json": json_schema},
312310
**self.generation_params,
313311
)
314312
raw = response.choices[0].message.content
315-
res = output_model.model_validate_json(raw) # type: ignore[arg-type]
313+
res = output_model.model_validate_json(raw)
316314
except (ValidationError, ValueError) as e:
317315
msg = f"Failed to obtain structured output for model {self.model_name} and messages {messages}: {e!s}"
318316
logger.warning(msg)
@@ -365,6 +363,7 @@ def get_structured_output_sync(
365363
current_messages.extend(self._create_retry_messages(error, raw))
366364

367365
if res is None:
366+
msg = "Structured output returned None but no error was caught."
368367
logger.exception(msg)
369368
raise RetriesExceededError(max_retries=max_retries, messages=current_messages)
370369

autointent/generation/chat_templates/_intent_descriptions.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,77 +2,63 @@
22

33
from pydantic import BaseModel, field_validator
44

5-
PROMPT_DESCRIPTION = """
5+
from autointent.generation.chat_templates import Message, Role
6+
7+
PROMPT_DESCRIPTION_SYSTEM = """
68
Your task is to write a description of the intent.
79
8-
You are given the name of the intent, user intentions related to it, and
9-
regular expressions that match user utterances. The description should be:
10+
You are given the name of the intent and user intentions related to it. The description should be:
1011
1) In declarative form.
1112
2) No more than one sentence.
12-
3) In the language in which the utterances or regular expressions are written.
13+
3) In the language in which the utterances are written.
1314
1415
Remember:
1516
- Respond with just the description, no extra details.
16-
- Keep in mind that either the names, user queries, or regex patterns may not be provided.
17+
- Keep in mind that either the names or user queries may not be provided.
1718
1819
For example:
1920
21+
Input:
2022
name:
2123
activate_my_card
2224
user utterances:
23-
Please help me with my card. It won't activate.
24-
I tried but am unable to activate my card.
25-
I want to start using my card.
26-
regex patterns:
27-
(activate.*card)|(start.*using.*card)
28-
description:
25+
- Please help me with my card. It won't activate.
26+
- I tried but am unable to activate my card.
27+
- I want to start using my card.
28+
29+
Output:
2930
User wants to activate his card.
3031
32+
Input:
3133
name:
3234
beneficiary_not_allowed
3335
user utterances:
3436
35-
regex patterns:
36-
(not.*allowed.*beneficiary)|(cannot.*add.*beneficiary)
37-
description:
37+
Output:
3838
User wants to know why his beneficiary is not allowed.
39-
40-
name:
41-
vacation_registration
42-
user utterances:
43-
как оформить отпуск
44-
в какие даты надо оформить отпуск
45-
как запланировать отпуск
46-
regex patterns:
47-
48-
description:
49-
Пользователь спрашивает про оформление отпуска.
50-
39+
"""
40+
PROMPT_DESCRIPTION_USER = """
5141
name:
5242
{intent_name}
5343
user utterances:
5444
{user_utterances}
55-
regex patterns:
56-
{regex_patterns}
57-
description:
58-
5945
"""
6046

6147

6248
class PromptDescription(BaseModel):
6349
"""Prompt description configuration."""
6450

65-
text: str = PROMPT_DESCRIPTION
51+
system_text: str = PROMPT_DESCRIPTION_SYSTEM
52+
user_text: str = PROMPT_DESCRIPTION_USER
6653
"""
6754
The template for the prompt to generate descriptions for intents.
6855
Should include placeholders for {intent_name} and {user_utterances}.
6956
- `{intent_name}` will be replaced with the name of the intent.
7057
- `{user_utterances}` will be replaced with the user utterances related to the intent.
71-
- (optionally) `{regex_patterns}` will be replaced with the regular expressions that match user utterances.
7258
"""
7359

7460
@classmethod
75-
@field_validator("text")
61+
@field_validator("user_text")
7662
def check_valid_prompt(cls, value: str) -> str:
7763
"""Validate the prompt description template.
7864
@@ -89,3 +75,13 @@ def check_valid_prompt(cls, value: str) -> str:
8975
)
9076
raise ValueError(text_error)
9177
return value
78+
79+
def to_messages(self, intent_name: str | None, utterances: list[str]) -> list[Message]:
80+
user_message_content = self.user_text.format(
81+
intent_name=intent_name,
82+
user_utterances="\n - ".join(utterances),
83+
)
84+
return [
85+
Message(role=Role.SYSTEM, content=self.system_text),
86+
Message(role=Role.USER, content=user_message_content),
87+
]

autointent/generation/intents/_description_generation.py

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99
import random
1010
from collections import defaultdict
1111

12-
from openai import AsyncOpenAI
13-
1412
from autointent import Dataset
13+
from autointent.generation import Generator
1514
from autointent.generation.chat_templates import PromptDescription
1615
from autointent.schemas import Intent, Sample
1716

@@ -41,55 +40,36 @@ def group_utterances_by_label(samples: list[Sample]) -> dict[int, list[str]]:
4140

4241

4342
async def create_intent_description(
44-
client: AsyncOpenAI,
43+
client: Generator,
4544
intent_name: str | None,
4645
utterances: list[str],
47-
regex_patterns: list[str],
4846
prompt: PromptDescription,
49-
model_name: str,
5047
) -> str:
5148
"""Generate a description for a specific intent using an OpenAI model.
5249
5350
Args:
5451
client: Generator instance used for model communication.
5552
intent_name: Name of the intent to describe (empty string if None).
5653
utterances: Example utterances related to the intent.
57-
regex_patterns: Regular expression patterns associated with the intent.
5854
prompt: Template for model prompt with placeholders for intent_name,
5955
and user_utterances.
60-
model_name: Identifier of the OpenAI model to use.
6156
6257
Raises:
6358
TypeError: If the model response is not a string.
6459
"""
6560
intent_name = intent_name if intent_name is not None else ""
6661
utterances = random.sample(utterances, min(5, len(utterances)))
67-
regex_patterns = random.sample(regex_patterns, min(3, len(regex_patterns)))
6862

69-
content = prompt.text.format(
70-
intent_name=intent_name,
71-
user_utterances="\n".join(utterances),
72-
regex_patterns="\n".join(regex_patterns),
73-
)
74-
chat_completion = await client.chat.completions.create(
75-
messages=[{"role": "user", "content": content}],
76-
model=model_name,
77-
temperature=0.2,
63+
return await client.get_chat_completion_async(
64+
messages=prompt.to_messages(intent_name, utterances),
7865
)
79-
result = chat_completion.choices[0].message.content
80-
81-
if not isinstance(result, str):
82-
error_text = f"Unexpected response type: expected str, got {type(result).__name__}"
83-
raise TypeError(error_text)
84-
return result
8566

8667

8768
async def generate_intent_descriptions(
88-
client: AsyncOpenAI,
69+
client: Generator,
8970
intent_utterances: dict[int, list[str]],
9071
intents: list[Intent],
9172
prompt: PromptDescription,
92-
model_name: str,
9373
) -> list[Intent]:
9474
"""Generate descriptions for multiple intents using an OpenAI model.
9575
@@ -99,22 +79,18 @@ async def generate_intent_descriptions(
9979
intents: List of intents needing descriptions.
10080
prompt: Template for model prompt with placeholders for intent_name,
10181
and user_utterances.
102-
model_name: Name of the OpenAI model to use.
10382
"""
10483
tasks = []
10584
for intent in intents:
10685
if intent.description is not None:
10786
continue
10887
utterances = intent_utterances.get(intent.id, [])
109-
regex_patterns = intent.regex_full_match + intent.regex_partial_match
11088
task = asyncio.create_task(
11189
create_intent_description(
11290
client=client,
11391
intent_name=intent.name,
11492
utterances=utterances,
115-
regex_patterns=regex_patterns,
11693
prompt=prompt,
117-
model_name=model_name,
11894
),
11995
)
12096
tasks.append((intent, task))
@@ -127,8 +103,7 @@ async def generate_intent_descriptions(
127103

128104
def generate_descriptions(
129105
dataset: Dataset,
130-
client: AsyncOpenAI,
131-
model_name: str,
106+
client: Generator,
132107
prompt: PromptDescription | None = None,
133108
) -> Dataset:
134109
"""Add LLM-generated text descriptions to dataset's intents.
@@ -138,7 +113,6 @@ def generate_descriptions(
138113
client: Generator instance for generating descriptions.
139114
prompt: Template for model prompt with placeholders for intent_name,
140115
and user_utterances.
141-
model_name: OpenAI model identifier for generating descriptions.
142116
143117
See :ref:`intent_description_generation` tutorial.
144118
"""
@@ -149,6 +123,6 @@ def generate_descriptions(
149123
if prompt is None:
150124
prompt = PromptDescription()
151125
dataset.intents = asyncio.run(
152-
generate_intent_descriptions(client, intent_utterances, dataset.intents, prompt, model_name),
126+
generate_intent_descriptions(client, intent_utterances, dataset.intents, prompt),
153127
)
154128
return dataset

0 commit comments

Comments
 (0)