Skip to content

Commit d15f1e2

Browse files
authored
Feat/llm description scorer (#236)
* decompose `DescriptionScorer` into two scorers
* first version of LLM description scorer
* add concurrency
* tweak prompt a little bit
* fix imports
* fix typing
* update docstring of LLM description scorer
* refactor description scorer tests
* fix stupid bugs
* update optimization test config
* implement generator dumper
* fix type errors
* skip LLMScorer test if openai is not available
* fix presets test
* fix config test
* update inference test
1 parent 05fee79 commit d15f1e2

28 files changed

+1003
-392
lines changed

autointent/_dump_tools/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
from .main import Dumper
2-
from .unit_dumpers import PydanticModelDumper
32

4-
__all__ = ["Dumper", "PydanticModelDumper"]
3+
__all__ = ["Dumper"]
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""Separate file to fix circular import error."""
2+
3+
from pathlib import Path
4+
from typing import Any
5+
6+
from autointent.generation import Generator
7+
8+
from .base import BaseObjectDumper
9+
10+
11+
class GeneratorDumper(BaseObjectDumper[Generator]):
    """Dumper for ``Generator`` objects that delegates to ``Generator.dump`` / ``Generator.load``."""

    # NOTE(review): presumably the sub-path name the owning ``Dumper`` uses for
    # objects of this kind -- confirm against ``BaseObjectDumper``.
    dir_or_file_name = "generators"

    @staticmethod
    def dump(obj: Generator, path: Path, exists_ok: bool) -> None:
        """Persist ``obj`` at ``path`` via ``Generator.dump``.

        Args:
            obj: Generator instance to save.
            path: Destination passed through to ``Generator.dump``.
            exists_ok: Forwarded to ``Generator.dump`` as its ``exist_ok`` argument.
        """
        obj.dump(path, exist_ok=exists_ok)

    @staticmethod
    def load(path: Path, **kwargs: Any) -> Generator:  # noqa: ANN401, ARG004
        """Restore a generator from ``path`` via ``Generator.load``.

        Args:
            path: Location previously written by :meth:`dump`.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The generator rebuilt by ``Generator.load``.
        """
        return Generator.load(path)

    @classmethod
    def check_isinstance(cls, obj: Any) -> bool:  # noqa: ANN401
        """Return whether ``obj`` is an instance of ``Generator``."""
        return isinstance(obj, Generator)

autointent/_dump_tools/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from autointent.schemas import TagsList
1111

1212
from .base import BaseObjectDumper, ModuleAttributes, ModuleSimpleAttributes
13+
from .generator_dumper import GeneratorDumper
1314
from .unit_dumpers import (
1415
ArraysDumper,
1516
CatBoostDumper,
@@ -46,6 +47,7 @@ class Dumper:
4647
HFTokenizerDumper,
4748
TorchModelDumper,
4849
CatBoostDumper,
50+
GeneratorDumper,
4951
]
5052

5153
@staticmethod

autointent/_presets/heavy.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@ search_space:
1313
k:
1414
low: 1
1515
high: 20
16-
- module_name: description
16+
- module_name: description_bi
17+
temperature:
18+
low: 0.01
19+
high: 10
20+
log: true
21+
- module_name: description_cross
1722
temperature:
1823
low: 0.01
1924
high: 10

autointent/generation/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
See :ref:`data-aug-tuts`.
44
"""
55

6-
from ._generator import Generator
6+
from ._cache import StructuredOutputCache
7+
from ._generator import Generator, RetriesExceededError
78

8-
__all__ = ["Generator"]
9+
__all__ = ["Generator", "RetriesExceededError"]

autointent/generation/_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dotenv import load_dotenv
1010
from pydantic import BaseModel, ValidationError
1111

12-
from autointent._dump_tools import PydanticModelDumper
12+
from autointent._dump_tools.unit_dumpers import PydanticModelDumper
1313
from autointent._hash import Hasher
1414
from autointent.generation.chat_templates import Message
1515

autointent/generation/_generator.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Wrapper class for accessing OpenAI API."""
22

3+
import json
34
import logging
45
import os
6+
from pathlib import Path
57
from textwrap import dedent
6-
from typing import Any, ClassVar, Literal, TypeVar
8+
from typing import Any, ClassVar, Literal, TypedDict, TypeVar
79

810
import openai
911
from dotenv import load_dotenv
@@ -22,6 +24,27 @@
2224
"""Type variable for Pydantic models used in structured output generation."""
2325

2426

27+
class GeneratorDumpData(TypedDict):
28+
use_cache: bool
29+
model_name: str
30+
base_url: str | None
31+
generation_params: dict[str, Any]
32+
33+
34+
class RetriesExceededError(RuntimeError):
    """Exception raised when LLM call fails after all retry attempts.

    Attributes:
        max_retries: Maximum number of retry attempts that were made.
        messages: Messages that were sent to the LLM.
    """

    def __init__(self, max_retries: int, messages: list[Message]) -> None:
        """Initialize the error with retry count and messages.

        Args:
            max_retries: Maximum number of retry attempts that were made
            messages: Messages that were sent to the LLM
        """
        # Keep the raw values so handlers can inspect them without parsing the text.
        self.max_retries = max_retries
        self.messages = messages
        # The reported attempt count is ``max_retries + 1``.
        msg = f"LLM call failed after {max_retries + 1} attempts. Messages: {messages}"
        super().__init__(msg)
46+
47+
2548
class Generator:
2649
"""Wrapper class for accessing OpenAI API.
2750
@@ -32,6 +55,8 @@ class Generator:
3255
**generation_params: kwargs that will be sent with a request to the endpoint.
3356
"""
3457

58+
_dump_data_filename = "init_params.json"
59+
3560
_default_generation_params: ClassVar[dict[str, Any]] = {
3661
"max_tokens": 150,
3762
"n": 1,
@@ -57,17 +82,23 @@ def __init__(
5782
"""
5883
base_url = base_url or os.getenv("OPENAI_BASE_URL")
5984
model_name = model_name or os.getenv("OPENAI_MODEL_NAME")
85+
6086
if model_name is None:
6187
msg = "Specify model_name arg or OPENAI_MODEL_NAME environment variable"
6288
raise ValueError(msg)
89+
6390
self.model_name = model_name
91+
self.base_url = base_url
92+
self.use_cache = use_cache
93+
6494
self.client = openai.OpenAI(base_url=base_url)
6595
self.async_client = openai.AsyncOpenAI(base_url=base_url)
96+
self.cache = StructuredOutputCache(use_cache=use_cache)
97+
6698
self.generation_params = {
6799
**self._default_generation_params,
68100
**generation_params,
69101
} # https://stackoverflow.com/a/65539348
70-
self.cache = StructuredOutputCache(use_cache=use_cache)
71102

72103
def get_chat_completion(self, messages: list[Message]) -> str:
73104
"""Prompt LLM and return its answer.
@@ -221,12 +252,8 @@ async def get_structured_output_async(
221252
current_messages.extend(self._create_retry_messages(error, raw))
222253

223254
if res is None:
224-
msg = (
225-
f"Failed to generate valid structured output after {max_retries + 1} attempts.\n"
226-
f"Messages: {current_messages}"
227-
)
228255
logger.exception(msg)
229-
raise RuntimeError(msg)
256+
raise RetriesExceededError(max_retries=max_retries, messages=current_messages)
230257

231258
# Cache the successful result
232259
self.cache.set(messages, output_model, backend, self.generation_params, res)
@@ -338,14 +365,32 @@ def get_structured_output_sync(
338365
current_messages.extend(self._create_retry_messages(error, raw))
339366

340367
if res is None:
341-
msg = (
342-
f"Failed to generate valid structured output after {max_retries + 1} attempts.\n"
343-
f"Messages: {current_messages}"
344-
)
345368
logger.exception(msg)
346-
raise RuntimeError(msg)
369+
raise RetriesExceededError(max_retries=max_retries, messages=current_messages)
347370

348371
# Cache the successful result
349372
self.cache.set(messages, output_model, backend, self.generation_params, res)
350373

351374
return res
375+
376+
def dump(self, path: Path, exist_ok: bool = True) -> None:
    """Write this generator's configuration to a JSON file inside ``path``.

    The file is named by ``self._dump_data_filename`` and holds ``base_url``,
    ``generation_params``, ``model_name`` and ``use_cache`` as read from the
    instance.

    Args:
        path: Directory to write into; created (with parents) if missing.
        exist_ok: Passed to ``Path.mkdir``; when ``False`` an already existing
            directory raises ``FileExistsError``.
    """
    data: GeneratorDumpData = {
        "base_url": self.base_url,
        "generation_params": self.generation_params,
        "model_name": self.model_name,
        "use_cache": self.use_cache,
    }

    path.mkdir(exist_ok=exist_ok, parents=True)

    with (path / self._dump_data_filename).open("w", encoding="utf-8") as file:
        json.dump(data, file, indent=4, ensure_ascii=False)
388+
389+
@classmethod
def load(cls, path: Path) -> "Generator":
    """Rebuild a generator from the JSON file found inside ``path``.

    Args:
        path: Directory containing the file named by ``cls._dump_data_filename``.

    Returns:
        A new instance built by calling ``cls`` with the stored values.
    """
    with (path / cls._dump_data_filename).open(encoding="utf-8") as file:
        data: GeneratorDumpData = json.load(file)

    # Generation params are passed as loose keyword arguments, so take them
    # out of the payload and unpack them next to the remaining entries.
    generation_params = data.pop("generation_params")  # type: ignore[misc]

    return cls(**data, **generation_params)

autointent/modules/__init__.py

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
from .scoring import (
1616
BERTLoRAScorer,
1717
BertScorer,
18+
BiEncoderDescriptionScorer,
1819
CatBoostScorer,
1920
CNNScorer,
20-
DescriptionScorer,
21+
CrossEncoderDescriptionScorer,
2122
DNNCScorer,
2223
KNNScorer,
2324
LinearScorer,
25+
LLMDescriptionScorer,
2426
MLKnnScorer,
2527
PTuningScorer,
2628
RerankScorer,
@@ -47,7 +49,9 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
4749
DNNCScorer,
4850
KNNScorer,
4951
LinearScorer,
50-
DescriptionScorer,
52+
BiEncoderDescriptionScorer,
53+
CrossEncoderDescriptionScorer,
54+
LLMDescriptionScorer,
5155
RerankScorer,
5256
SklearnScorer,
5357
MLKnnScorer,
@@ -62,28 +66,3 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
6266
DECISION_MODULES: dict[str, type[BaseDecision]] = _create_modules_dict(
6367
[ArgmaxDecision, JinoosDecision, ThresholdDecision, TunableDecision, AdaptiveDecision],
6468
)
65-
66-
67-
__all__ = [
68-
"AdaptiveDecision",
69-
"ArgmaxDecision",
70-
"BaseDecision",
71-
"BaseEmbedding",
72-
"BaseModule",
73-
"BaseRegex",
74-
"BaseScorer",
75-
"CatBoostScorer",
76-
"DNNCScorer",
77-
"DescriptionScorer",
78-
"JinoosDecision",
79-
"KNNScorer",
80-
"LinearScorer",
81-
"LogregAimedEmbedding",
82-
"MLKnnScorer",
83-
"RerankScorer",
84-
"RetrievalAimedEmbedding",
85-
"SimpleRegex",
86-
"SklearnScorer",
87-
"ThresholdDecision",
88-
"TunableDecision",
89-
]

autointent/modules/scoring/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ._bert import BertScorer
22
from ._catboost import CatBoostScorer
3-
from ._description import DescriptionScorer
3+
from ._description import BiEncoderDescriptionScorer, CrossEncoderDescriptionScorer, LLMDescriptionScorer
44
from ._dnnc import DNNCScorer
55
from ._knn import KNNScorer, RerankScorer
66
from ._linear import LinearScorer
@@ -13,11 +13,13 @@
1313
__all__ = [
1414
"BERTLoRAScorer",
1515
"BertScorer",
16+
"BiEncoderDescriptionScorer",
1617
"CNNScorer",
1718
"CatBoostScorer",
19+
"CrossEncoderDescriptionScorer",
1820
"DNNCScorer",
19-
"DescriptionScorer",
2021
"KNNScorer",
22+
"LLMDescriptionScorer",
2123
"LinearScorer",
2224
"MLKnnScorer",
2325
"PTuningScorer",
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
from .description import DescriptionScorer
1+
from .bi_encoder import BiEncoderDescriptionScorer
2+
from .cross_encoder import CrossEncoderDescriptionScorer
3+
from .llm_encoder import LLMDescriptionScorer
24

3-
__all__ = ["DescriptionScorer"]
5+
__all__ = ["BiEncoderDescriptionScorer", "CrossEncoderDescriptionScorer", "LLMDescriptionScorer"]

0 commit comments

Comments
 (0)