Commit a0bb255

fix some typing errors

1 parent 7176660

14 files changed: +82 -41 lines changed

src/autointent/_wrappers/embedder/base.py

Lines changed: 4 additions & 1 deletion
@@ -38,7 +38,10 @@ def embed(
 
     @abstractmethod
     def embed(
-        self, utterances: list[str], task_type: TaskTypeEnum | None = None, return_tensors: bool = False
+        self,
+        utterances: list[str],
+        task_type: TaskTypeEnum | None = None,
+        return_tensors: bool = False,
     ) -> npt.NDArray[np.float32] | torch.Tensor:
         """Calculate embeddings for a list of utterances.
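
The wrapper layer (next file) pairs this union return type with `@overload` stubs so callers get a precise type out of the `return_tensors` flag. A minimal, self-contained sketch of that pattern, with an invented `ToyEmbedder` standing in for the real classes:

```python
from typing import Literal, overload

import numpy as np
import numpy.typing as npt
import torch


class ToyEmbedder:
    @overload
    def embed(self, utterances: list[str], return_tensors: Literal[True]) -> torch.Tensor: ...

    @overload
    def embed(self, utterances: list[str], return_tensors: Literal[False] = ...) -> npt.NDArray[np.float32]: ...

    def embed(
        self,
        utterances: list[str],
        return_tensors: bool = False,
    ) -> npt.NDArray[np.float32] | torch.Tensor:
        # Placeholder embeddings; a real backend would run a model here.
        vectors = np.zeros((len(utterances), 8), dtype=np.float32)
        if return_tensors:
            return torch.from_numpy(vectors)
        return vectors
```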

src/autointent/_wrappers/embedder/embedder.py

Lines changed: 9 additions & 12 deletions
@@ -13,7 +13,6 @@
 import numpy as np
 import numpy.typing as npt
 import torch
-from typing_extensions import assert_never
 
 from autointent.configs import EmbedderFineTuningConfig, TaskTypeEnum
 from autointent.configs._embedder import EmbedderConfig, OpenaiEmbeddingConfig, SentenceTransformerEmbeddingConfig
@@ -54,10 +53,9 @@ def _load_model(self) -> BaseEmbeddingBackend:
         if isinstance(self.config, OpenaiEmbeddingConfig):
             return OpenaiEmbeddingBackend(self.config)
         # Check if it's exactly the abstract base config (not a subclass)
-        if type(self.config) is EmbedderConfig:
-            msg = f"Cannot instantiate abstract EmbedderConfig: {self.config.__repr__()}"
-            raise TypeError(msg)
-        assert_never(self.config)
+
+        msg = f"Cannot instantiate abstract EmbedderConfig: {self.config.__repr__()}"
+        raise TypeError(msg)
 
     def _get_hash(self) -> int:
         """Compute a hash value for the Embedder.
@@ -149,13 +147,9 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
             instance._backend = SentenceTransformerEmbeddingBackend.load(backend_path)  # noqa: SLF001
         elif isinstance(config, OpenaiEmbeddingConfig):
             instance._backend = OpenaiEmbeddingBackend.load(backend_path)  # noqa: SLF001
-        # Check if it's exactly the abstract base config (not a subclass)
-        elif type(config) is EmbedderConfig:
-            # Handle abstract base config case
+        else:
             msg = f"Cannot load abstract EmbedderConfig: {config.__repr__()}"
             raise TypeError(msg)
-        else:
-            assert_never(config)
 
         return instance
 
@@ -170,7 +164,10 @@ def embed(
     ) -> npt.NDArray[np.float32]: ...
 
     def embed(
-        self, utterances: list[str], task_type: TaskTypeEnum | None = None, return_tensors: bool = False
+        self,
+        utterances: list[str],
+        task_type: TaskTypeEnum | None = None,
+        return_tensors: bool = False,
     ) -> npt.NDArray[np.float32] | torch.Tensor:
         """Calculate embeddings for a list of utterances.
@@ -182,7 +179,7 @@ def embed(
         Returns:
             A numpy array or PyTorch tensor of embeddings.
         """
-        return self._backend.embed(utterances, task_type, return_tensors=return_tensors)
+        return self._backend.embed(utterances=utterances, task_type=task_type, return_tensors=return_tensors)
 
     def similarity(
         self, embeddings1: npt.NDArray[np.float32], embeddings2: npt.NDArray[np.float32]
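
The `assert_never` removals above are the heart of the typing fix: `assert_never` only type-checks when the union being dispatched on is closed, so that handling every member narrows the value to `Never`. `EmbedderConfig` is an open, subclassable base class, so `isinstance` checks can never exhaust it, and an explicit runtime `TypeError` is the honest fallback. A self-contained sketch of the distinction (all names invented for illustration):

```python
from typing import Literal

from typing_extensions import assert_never

Kind = Literal["sentence_transformer", "openai"]


def describe(kind: Kind) -> str:
    if kind == "sentence_transformer":
        return "local model"
    if kind == "openai":
        return "remote API"
    # Every Literal member is handled, so `kind` narrows to Never here:
    # assert_never type-checks and is unreachable at runtime.
    assert_never(kind)


class Config:  # stand-in for an open base class like EmbedderConfig
    pass


class OpenaiConfig(Config):
    pass


def load_backend(config: Config) -> str:
    if isinstance(config, OpenaiConfig):
        return "openai backend"
    # `config` is still typed as Config (any subclass may arrive here), so
    # assert_never(config) would be rejected by the type checker; raising
    # at runtime is the right fallback for an open hierarchy.
    raise TypeError(f"Cannot handle config: {config!r}")
```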

src/autointent/_wrappers/embedder/openai.py

Lines changed: 19 additions & 8 deletions
@@ -2,13 +2,14 @@
 import logging
 from functools import partial
 from pathlib import Path
-from typing import Literal, overload
+from typing import Literal, TypedDict, cast, overload
 
 import aiometer
 import numpy as np
 import numpy.typing as npt
 import openai
 import torch
+from typing_extensions import NotRequired
 
 from autointent._hash import Hasher
 from autointent.configs import TaskTypeEnum
@@ -20,6 +21,12 @@
 logger = logging.getLogger(__name__)
 
 
+class EmbeddingsCreateKwargs(TypedDict):
+    input: list[str]
+    model: str
+    dimensions: NotRequired[int]
+
+
 class OpenaiEmbeddingBackend(BaseEmbeddingBackend):
     """OpenAI-based embedding backend implementation."""
 
@@ -30,9 +37,10 @@ def __init__(self, config: OpenaiEmbeddingConfig) -> None:
             config: Configuration for OpenAI embeddings.
         """
         self.config = config
-        self._client = None
-        self._async_client = None
-        self._event_loop = None
+        self._client: openai.OpenAI | None = None
+        self._async_client: openai.AsyncOpenAI | None = None
+        self._event_loop: asyncio.AbstractEventLoop | None = None
+
         if config.max_concurrent is not None:
             self._init_event_loop()
 
@@ -124,7 +132,7 @@ def embed(
         embeddings_path = get_embeddings_path(hasher.hexdigest())
         if embeddings_path.exists():
             logger.debug("loading embeddings from %s", str(embeddings_path))
-            embeddings_np = np.load(embeddings_path).astype(np.float32)
+            embeddings_np = cast(npt.NDArray[np.float32], np.load(embeddings_path))
             if return_tensors:
                 return torch.from_numpy(embeddings_np)
             return embeddings_np
@@ -162,7 +170,7 @@ def _process_embeddings_sync(self, utterances: list[str]) -> np.ndarray:
             batch = utterances[i : i + self.config.batch_size]
 
             # Prepare API call parameters
-            kwargs = {
+            kwargs: EmbeddingsCreateKwargs = {
                 "input": batch,
                 "model": self.config.model_name,
             }
@@ -198,6 +206,9 @@ def _process_embeddings_async(self, utterances: list[str]) -> np.ndarray:
             max_at_once=self.config.max_concurrent,
             max_per_second=self.config.max_per_second,
         )
+        if self._event_loop is None:
+            msg = "Event loop is not initialized"
+            raise RuntimeError(msg)
         batch_results = self._event_loop.run_until_complete(task)
 
         # Flatten results
@@ -210,7 +221,7 @@ async def _process_batch_async(self, batch: list[str]) -> list[list[float]]:
         client = self._get_async_client()
 
         # Prepare API call parameters
-        kwargs = {
+        kwargs: EmbeddingsCreateKwargs = {
             "input": batch,
             "model": self.config.model_name,
         }
@@ -246,7 +257,7 @@ def similarity(
 
         # Calculate cosine similarity
         similarity_matrix = np.dot(normalized1, normalized2.T)
-        return similarity_matrix.astype(np.float32)
+        return cast(npt.NDArray[np.float32], similarity_matrix)
 
     def dump(self, path: Path) -> None:
         """Save the backend state to disk.

src/autointent/_wrappers/vector_index/vector_index.py

Lines changed: 2 additions & 2 deletions
@@ -73,7 +73,7 @@ def add(self, texts: list[str], labels: ListOfLabels) -> None:
             msg = f"Texts and labels lengths mismatch: {len(texts)=} != {len(labels)=}"
             raise ValueError(msg)
 
-        logger.debug("Adding embeddings to vector index %s", self.embedder.config.model_name)
+        logger.debug("Adding embeddings to vector index")
         embeddings = self.embedder.embed(texts, TaskTypeEnum.passage)
 
         if not hasattr(self, "index"):
@@ -85,7 +85,7 @@ def add(self, texts: list[str], labels: ListOfLabels) -> None:
 
     def clear_ram(self) -> None:
         """Clear the vector index from RAM."""
-        logger.debug("Clearing vector index %s from RAM", self.embedder.config.model_name)
+        logger.debug("Clearing vector index from RAM")
         self.embedder.clear_ram()
         self.index.clear_ram()

src/autointent/configs/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
     SentenceTransformerEmbeddingConfig,
     TaskTypeEnum,
     get_default_embedder_config,
+    initialize_embedder_config,
 )
 from ._inference_node import InferenceNodeConfig
 from ._optimization import DataConfig, HPOConfig, LoggingConfig
@@ -40,4 +41,5 @@
     "VocabConfig",
     "get_default_embedder_config",
     "get_default_vector_index_config",
+    "initialize_embedder_config",
 ]

src/autointent/configs/_embedder.py

Lines changed: 13 additions & 2 deletions
@@ -1,4 +1,5 @@
 from enum import Enum
+from typing import Any
 
 from pydantic import BaseModel, Field
 
@@ -98,5 +99,15 @@ class OpenaiEmbeddingConfig(EmbedderConfig):
     )
 
 
-def get_default_embedder_config() -> EmbedderConfig:
-    return SentenceTransformerEmbeddingConfig()
+def get_default_embedder_config(**kwargs: Any) -> EmbedderConfig:  # noqa: ANN401
+    return SentenceTransformerEmbeddingConfig(**kwargs)
+
+
+def initialize_embedder_config(values: dict[str, Any] | str | EmbedderConfig | None) -> EmbedderConfig:
+    if values is None:
+        return get_default_embedder_config()
+    if isinstance(values, EmbedderConfig):
+        return values.model_copy(deep=True)
+    if isinstance(values, str):
+        return get_default_embedder_config(model_name=values)
+    return get_default_embedder_config(**values)
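
A usage sketch for the new helper, based solely on the branches shown above (the model name is an illustrative placeholder):

```python
from autointent.configs import SentenceTransformerEmbeddingConfig, initialize_embedder_config

default_cfg = initialize_embedder_config(None)  # None -> fresh default config
from_name = initialize_embedder_config("my-embedding-model")  # str becomes model_name
from_dict = initialize_embedder_config({"model_name": "my-embedding-model"})  # dict of fields

existing = SentenceTransformerEmbeddingConfig()
copied = initialize_embedder_config(existing)  # deep copy, never the caller's object
assert copied is not existing
```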

src/autointent/modules/scoring/_catboost/catboost_scorer.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 from pydantic import PositiveInt
 
 from autointent import Context, Embedder
-from autointent.configs import EmbedderConfig, TaskTypeEnum
+from autointent.configs import EmbedderConfig, TaskTypeEnum, initialize_embedder_config
 from autointent.custom_types import FloatFromZeroToOne, ListOfLabels
 from autointent.modules.base import BaseScorer
 
@@ -113,7 +113,7 @@ def __init__(
             msg = "Only catbooost text features will be used, `use_embedding_features` is ignored."
             logger.warning(msg)
 
-        self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
+        self.embedder_config = initialize_embedder_config(embedder_config)
         self.loss_function = loss_function
         self.verbose = verbose
         self.catboost_kwargs = catboost_kwargs or {}

src/autointent/modules/scoring/_description/bi_encoder.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 from pydantic import PositiveFloat
 
 from autointent import Context, Embedder
-from autointent.configs import EmbedderConfig, TaskTypeEnum
+from autointent.configs import EmbedderConfig, TaskTypeEnum, initialize_embedder_config
 
 from .base import BaseDescriptionScorer
 
@@ -64,7 +64,7 @@ def __init__(
         multilabel: bool = False,
     ) -> None:
         super().__init__(temperature=temperature, multilabel=multilabel)
-        self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
+        self.embedder_config = initialize_embedder_config(embedder_config)
         self._embedder: Embedder | None = None
         self._description_vectors: NDArray[Any] | None = None

src/autointent/modules/scoring/_dnnc/dnnc.py

Lines changed: 8 additions & 2 deletions
@@ -9,7 +9,13 @@
 from pydantic import PositiveInt
 
 from autointent import Context, Ranker, VectorIndex
-from autointent.configs import CrossEncoderConfig, EmbedderConfig, VectorIndexConfig, get_default_vector_index_config
+from autointent.configs import (
+    CrossEncoderConfig,
+    EmbedderConfig,
+    VectorIndexConfig,
+    get_default_vector_index_config,
+    initialize_embedder_config,
+)
 from autointent.custom_types import Document, ListOfLabels
 from autointent.modules.base import BaseScorer
 
@@ -67,7 +73,7 @@ def __init__(
         vector_index_config: VectorIndexConfig | None = None,
     ) -> None:
         self.cross_encoder_config = CrossEncoderConfig.from_search_config(cross_encoder_config)
-        self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
+        self.embedder_config = initialize_embedder_config(embedder_config)
         self.k = k
         self.vector_index_config = vector_index_config or get_default_vector_index_config()

src/autointent/modules/scoring/_gcn/gcn_scorer.py

Lines changed: 3 additions & 2 deletions
@@ -11,6 +11,7 @@
     EmbedderConfig,
     TaskTypeEnum,
     TorchTrainingConfig,
+    initialize_embedder_config,
 )
 from autointent.custom_types import ListOfLabels
 from autointent.modules.scoring._gcn.gcn_model import TextMLGCN
@@ -63,8 +64,8 @@ def __init__(  # noqa: PLR0913
     ) -> None:
         if gcn_hidden_dims is None:
             gcn_hidden_dims = [1024]
-        self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
-        self.label_embedder_config = EmbedderConfig.from_search_config(label_embedder_config)
+        self.embedder_config = initialize_embedder_config(embedder_config)
+        self.label_embedder_config = initialize_embedder_config(label_embedder_config)
         self.gcn_hidden_dims = gcn_hidden_dims
         self.p_reweight = p_reweight
         self.tau_threshold = tau_threshold
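
Downstream, the four scorer changes above are the same move: route the `embedder_config` constructor argument through `initialize_embedder_config`. Sketched with an invented `ToyScorer` (the real scorers take additional required arguments), any such module now accepts the same flexible forms:

```python
from typing import Any

from autointent.configs import initialize_embedder_config


class ToyScorer:
    """Invented stand-in for the scorers changed above."""

    def __init__(self, embedder_config: Any = None) -> None:
        self.embedder_config = initialize_embedder_config(embedder_config)


ToyScorer()  # None -> default embedder config
ToyScorer("my-embedding-model")  # plain model-name string
ToyScorer({"model_name": "my-embedding-model"})  # dict of config fields
```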
