Skip to content

Commit 9dfa81f

Browse files
mypy fix
1 parent c18698e commit 9dfa81f

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

autointent/generation/utterances/_adversarial/human_utterance_generator.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from .critic_human_like import CriticHumanLike
1818

1919
logger = logging.getLogger(__name__)
20+
21+
2022
class HumanUtteranceGenerator:
2123
"""Generator of human-like utterances.
2224
@@ -110,17 +112,17 @@ async def augment_async(
110112
for sample in original_split:
111113
class_to_samples[sample["label"]].append(sample["utterance"])
112114

113-
114-
async def generate_one(intent_id: str, intent_name: str) -> list[dict]:
115+
async def generate_one(intent_id: str, intent_name: str) -> list[dict[str, str]]:
115116
if intent_name is None:
116117
logger.warning("Intent with id %s has no name! Skipping it...", intent_id)
117-
generated = []
118+
return []
119+
generated: list[dict[str, str]] = []
118120
attempts = 0
119121
seed_utterances = class_to_samples[intent_id]
120122
while len(generated) < n_final_per_class and attempts < n_final_per_class * 3:
121123
attempts += 1
122124
seed_examples = random.sample(seed_utterances, k=min(3, len(seed_utterances)))
123-
rejected = []
125+
rejected: list[str] = []
124126

125127
for _ in range(3):
126128
prompt = self._build_adversarial_prompt(intent_name, seed_examples, rejected)
@@ -130,15 +132,14 @@ async def generate_one(intent_id: str, intent_name: str) -> list[dict]:
130132
break
131133
rejected.append(utterance)
132134
return generated
135+
133136
tasks = [
134-
partial(generate_one, intent_id, intent_name)
135-
for intent_id, intent_name in id_to_name.items()
136-
if class_to_samples.get(intent_id) and intent_name is not None
137+
partial(generate_one, str(intent_id), intent_name)
138+
for intent_id, intent_name in id_to_name.items()
139+
if class_to_samples.get(intent_id) and intent_name is not None
137140
]
138141

139-
results = await aiometer.run_all(
140-
tasks, max_at_once=5, max_per_second=10
141-
)
142+
results = await aiometer.run_all(tasks, max_at_once=5, max_per_second=10)
142143

143144
for result in results:
144145
new_samples.extend(result)

0 commit comments

Comments
 (0)