Skip to content

Commit 4dc4659

Browse files
committed
fix typing
1 parent 91b43b0 commit 4dc4659

File tree

3 files changed

+16
-17
lines changed

3 files changed

+16
-17
lines changed

autointent/generation/_generator.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ def get_chat_completion(self, messages: list[Message]) -> str:
6767
messages: List of messages to send to the model.
6868
"""
6969
response = self.client.chat.completions.create(
70-
messages=messages, # type: ignore[arg-type]
70+
messages=messages, # type: ignore[call-overload]
7171
model=self.model_name,
7272
**self.generation_params,
7373
)
74-
return response.choices[0].message.content # type: ignore[return-value]
74+
return response.choices[0].message.content # type: ignore[no-any-return]
7575

7676
async def get_chat_completion_async(self, messages: list[Message]) -> str:
7777
"""Prompt LLM and return its answer asynchronously.
@@ -80,11 +80,15 @@ async def get_chat_completion_async(self, messages: list[Message]) -> str:
8080
messages: List of messages to send to the model.
8181
"""
8282
response = await self.async_client.chat.completions.create(
83-
messages=messages, # type: ignore[arg-type]
83+
messages=messages, # type: ignore[call-overload]
8484
model=self.model_name,
8585
**self.generation_params,
8686
)
87-
return response.choices[0].message.content # type: ignore[return-value]
87+
88+
if response is None or not response.choices:
89+
msg = "No response received from the model."
90+
raise RuntimeError(msg)
91+
return response.choices[0].message.content # type: ignore[no-any-return]
8892

8993
def _create_retry_messages(self, error_message: str, raw: str | None) -> list[Message]:
9094
"""Create a follow-up message for retry with error details and schema."""
@@ -128,7 +132,7 @@ async def _get_structured_output_openai_async(
128132
model=self.model_name,
129133
messages=messages, # type: ignore[arg-type]
130134
response_format=output_model,
131-
**self.generation_params,
135+
**self.generation_params, # type: ignore[arg-type]
132136
)
133137
raw = response.choices[0].message.content
134138
res = response.choices[0].message.parsed
@@ -154,7 +158,7 @@ async def _get_structured_output_vllm_async(
154158
json_schema = output_model.model_json_schema()
155159
response = await self.async_client.chat.completions.create(
156160
model=self.model_name,
157-
messages=messages, # type: ignore[arg-type]
161+
messages=messages, # type: ignore[call-overload]
158162
extra_body={"guided_json": json_schema},
159163
**self.generation_params,
160164
)
@@ -245,7 +249,7 @@ def _get_structured_output_openai_sync(
245249
model=self.model_name,
246250
messages=messages, # type: ignore[arg-type]
247251
response_format=output_model,
248-
**self.generation_params,
252+
**self.generation_params, # type: ignore[arg-type]
249253
)
250254
raw = response.choices[0].message.content
251255
res = response.choices[0].message.parsed
@@ -271,7 +275,7 @@ def _get_structured_output_vllm_sync(
271275
json_schema = output_model.model_json_schema()
272276
response = self.client.chat.completions.create(
273277
model=self.model_name,
274-
messages=messages, # type: ignore[arg-type]
278+
messages=messages, # type: ignore[call-overload]
275279
extra_body={"guided_json": json_schema},
276280
**self.generation_params,
277281
)

autointent/generation/intents/_description_generation.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,10 @@ async def create_intent_description(
6060
intent_name = intent_name if intent_name is not None else ""
6161
utterances = random.sample(utterances, min(5, len(utterances)))
6262

63-
result = await client.get_chat_completion_async(
63+
return await client.get_chat_completion_async(
6464
messages=prompt.to_messages(intent_name, utterances),
6565
)
6666

67-
if not isinstance(result, str):
68-
error_text = f"Unexpected response type: expected str, got {type(result).__name__}"
69-
raise TypeError(error_text)
70-
return result
71-
7267

7368
async def generate_intent_descriptions(
7469
client: Generator,

tests/generation/intents/test_description_generation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import defaultdict
2-
from unittest.mock import AsyncMock, Mock, patch
2+
from unittest.mock import AsyncMock, patch
33

44
import pytest
55

@@ -243,8 +243,8 @@ async def test_generate_intent_descriptions_empty_utterances_patterns():
243243
mock_create.assert_called_once_with(
244244
messages=[
245245
{
246-
"role": "system",
247-
"content": prompt.system_text,
246+
"role": "system",
247+
"content": prompt.system_text,
248248
},
249249
{
250250
"role": "user",

0 commit comments

Comments
 (0)