@@ -2,13 +2,14 @@
 import logging
 from functools import partial
 from pathlib import Path
-from typing import Literal, overload
+from typing import Literal, TypedDict, cast, overload
 
 import aiometer
 import numpy as np
 import numpy.typing as npt
 import openai
 import torch
+from typing_extensions import NotRequired
 
 from autointent._hash import Hasher
 from autointent.configs import TaskTypeEnum
@@ -20,6 +21,12 @@
 logger = logging.getLogger(__name__)
 
 
+class EmbeddingsCreateKwargs(TypedDict):
+    input: list[str]
+    model: str
+    dimensions: NotRequired[int]
+
+
 class OpenaiEmbeddingBackend(BaseEmbeddingBackend):
     """OpenAI-based embedding backend implementation."""
 
@@ -30,9 +37,10 @@ def __init__(self, config: OpenaiEmbeddingConfig) -> None:
             config: Configuration for OpenAI embeddings.
         """
         self.config = config
-        self._client = None
-        self._async_client = None
-        self._event_loop = None
+        self._client: openai.OpenAI | None = None
+        self._async_client: openai.AsyncOpenAI | None = None
+        self._event_loop: asyncio.AbstractEventLoop | None = None
+
         if config.max_concurrent is not None:
             self._init_event_loop()
 
@@ -124,7 +132,7 @@ def embed(
         embeddings_path = get_embeddings_path(hasher.hexdigest())
         if embeddings_path.exists():
             logger.debug("loading embeddings from %s", str(embeddings_path))
-            embeddings_np = np.load(embeddings_path).astype(np.float32)
+            embeddings_np = cast(npt.NDArray[np.float32], np.load(embeddings_path))
             if return_tensors:
                 return torch.from_numpy(embeddings_np)
             return embeddings_np
@@ -162,7 +170,7 @@ def _process_embeddings_sync(self, utterances: list[str]) -> np.ndarray:
             batch = utterances[i : i + self.config.batch_size]
 
             # Prepare API call parameters
-            kwargs = {
+            kwargs: EmbeddingsCreateKwargs = {
                 "input": batch,
                 "model": self.config.model_name,
             }
@@ -198,6 +206,9 @@ def _process_embeddings_async(self, utterances: list[str]) -> np.ndarray:
             max_at_once=self.config.max_concurrent,
             max_per_second=self.config.max_per_second,
         )
+        if self._event_loop is None:
+            msg = "Event loop is not initialized"
+            raise RuntimeError(msg)
         batch_results = self._event_loop.run_until_complete(task)
 
         # Flatten results
@@ -210,7 +221,7 @@ async def _process_batch_async(self, batch: list[str]) -> list[list[float]]:
         client = self._get_async_client()
 
         # Prepare API call parameters
-        kwargs = {
+        kwargs: EmbeddingsCreateKwargs = {
             "input": batch,
             "model": self.config.model_name,
         }
@@ -246,7 +257,7 @@ def similarity(
 
         # Calculate cosine similarity
         similarity_matrix = np.dot(normalized1, normalized2.T)
-        return similarity_matrix.astype(np.float32)
+        return cast(npt.NDArray[np.float32], similarity_matrix)
 
     def dump(self, path: Path) -> None:
         """Save the backend state to disk.
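
The `EmbeddingsCreateKwargs` TypedDict introduced above types the keyword-argument dict that is presumably unpacked into the OpenAI embeddings call; `NotRequired` declares `dimensions` as a key that may be absent entirely, rather than a key whose value may be None. A minimal sketch of the pattern, where `build_kwargs` is a hypothetical helper (not part of this commit):

from typing import TypedDict

from typing_extensions import NotRequired


class EmbeddingsCreateKwargs(TypedDict):
    input: list[str]
    model: str
    dimensions: NotRequired[int]  # key may be omitted entirely


def build_kwargs(batch: list[str], model: str, dims: int | None) -> EmbeddingsCreateKwargs:
    # Start with the required keys, then add optional ones conditionally;
    # the type checker verifies every key and value type on assignment.
    kwargs: EmbeddingsCreateKwargs = {"input": batch, "model": model}
    if dims is not None:
        kwargs["dimensions"] = dims  # allowed because the key is NotRequired
    return kwargs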
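The two `.astype(np.float32)` → `cast(...)` replacements swap a runtime conversion for a static-only assertion: `typing.cast` has no runtime effect, unlike `.astype`, which converts and copies the array. This is sound only if the array already holds float32 at those points, e.g., if the cached embeddings were saved as float32 and `np.dot` over float32 inputs returns float32. A small illustration of the difference:

from typing import cast

import numpy as np
import numpy.typing as npt

arr = np.zeros(3)  # default dtype is float64

converted = arr.astype(np.float32)  # real conversion: a new float32 array
assert converted.dtype == np.float32

asserted = cast(npt.NDArray[np.float32], arr)  # no-op at runtime
assert asserted.dtype == np.float64  # dtype unchanged; only the checker sees float32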