Skip to content

Commit 89b313d

Browse files
committed
add openai embedding backend
1 parent 29981f6 commit 89b313d

File tree

1 file changed

+181
-0
lines changed

1 file changed

+181
-0
lines changed
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import logging
2+
import time
3+
from pathlib import Path
4+
from typing import Literal, overload
5+
6+
import numpy as np
7+
import numpy.typing as npt
8+
import openai
9+
import torch
10+
from appdirs import user_cache_dir
11+
12+
from autointent._hash import Hasher
13+
from autointent.configs import TaskTypeEnum
14+
from autointent.configs._embedder import OpenaiEmbeddingConfig
15+
16+
from .base import BaseEmbeddingBackend
17+
18+
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
19+
20+
21+
def _get_embeddings_path(filename: str) -> Path:
    """Build the on-disk location of a cached embeddings array.

    The file lives in an ``embeddings`` sub-directory of the directory
    returned by ``user_cache_dir("autointent")`` and carries the ``.npy``
    extension. Nothing is created on disk by this function; it only
    composes the path.

    Args:
        filename: Base name of the embeddings file, without extension.

    Returns:
        Path of the form ``<cache dir>/embeddings/<filename>.npy``.
    """
    cache_root = Path(user_cache_dir("autointent"))
    embeddings_dir = cache_root / "embeddings"
    return embeddings_dir / f"{filename}.npy"
36+
37+
38+
class OpenaiEmbeddingBackend(BaseEmbeddingBackend):
    """OpenAI-based embedding backend implementation.

    Embeddings are requested from the OpenAI embeddings endpoint in batches
    of ``config.batch_size``. The API client is created lazily on first use.
    When ``config.use_cache`` is true, results are stored as ``.npy`` files
    keyed by the model name, the requested dimensions and the utterances,
    and are loaded from disk on subsequent identical requests.
    """

    def __init__(self, config: OpenaiEmbeddingConfig) -> None:
        """Initialize the OpenAI backend.

        Args:
            config: Configuration for OpenAI embeddings.
        """
        self.config = config
        # Instantiated on demand by ``_get_client``.
        self._client: openai.OpenAI | None = None

    def _get_client(self) -> openai.OpenAI:
        """Get or create OpenAI client instance.

        Returns:
            The client built from ``config.api_key``, ``config.timeout`` and
            ``config.max_retries``; the same instance is reused afterwards.
        """
        if self._client is None:
            self._client = openai.OpenAI(
                api_key=self.config.api_key,
                timeout=self.config.timeout,
                max_retries=self.config.max_retries,
            )
        return self._client

    def clear_ram(self) -> None:
        """Clear the backend from RAM. For OpenAI, this is a no-op."""
        # OpenAI API doesn't store models in RAM, so nothing to clear

    def get_hash(self) -> int:
        """Compute a hash value for identifying embedding model.

        Returns:
            Integer digest derived from the model name and the requested
            number of dimensions.
        """
        hasher = Hasher()
        hasher.update(self.config.model_name)
        hasher.update(str(self.config.dimensions))
        return hasher.intdigest()

    @overload
    def embed(
        self, utterances: list[str], task_type: TaskTypeEnum | None = None, *, return_tensors: Literal[True]
    ) -> torch.Tensor: ...

    @overload
    def embed(
        self, utterances: list[str], task_type: TaskTypeEnum | None = None, *, return_tensors: Literal[False] = False
    ) -> npt.NDArray[np.float32]: ...

    def embed(
        self, utterances: list[str], task_type: TaskTypeEnum | None = None, return_tensors: bool = False
    ) -> npt.NDArray[np.float32] | torch.Tensor:
        """Calculate embeddings for a list of utterances.

        Args:
            utterances: List of input texts to calculate embeddings for.
            task_type: Type of task for which embeddings are calculated (unused for OpenAI).
            return_tensors: If True, return a PyTorch tensor; otherwise, return a numpy array.

        Returns:
            A numpy array or PyTorch tensor of embeddings, one row per utterance.

        Raises:
            ValueError: If ``utterances`` is empty.
            RuntimeError: If a request to the OpenAI API (or reading its
                response) fails; the original exception is chained as the cause.
        """
        if len(utterances) == 0:
            msg = "Empty input"
            logger.error(msg)
            raise ValueError(msg)

        # Stays None when caching is disabled; otherwise the cache file for this request.
        embeddings_path: Path | None = None
        if self.config.use_cache:
            logger.debug("Using cached embeddings for %s", self.config.model_name)
            hasher = Hasher()
            hasher.update(self.get_hash())
            hasher.update(utterances)

            embeddings_path = _get_embeddings_path(hasher.hexdigest())
            if embeddings_path.exists():
                logger.debug("loading embeddings from %s", str(embeddings_path))
                embeddings_np = np.load(embeddings_path).astype(np.float32)
                if return_tensors:
                    return torch.from_numpy(embeddings_np)
                return embeddings_np

        client = self._get_client()
        batch_size = self.config.batch_size
        total = len(utterances)

        logger.debug(
            "Calculating embeddings with OpenAI model %s, batch_size=%d, dimensions=%s",
            self.config.model_name,
            batch_size,
            str(self.config.dimensions),
        )

        all_embeddings = []

        # Process in batches
        for start in range(0, total, batch_size):
            # Prepare API call parameters
            kwargs = {
                "input": utterances[start : start + batch_size],
                "model": self.config.model_name,
            }
            # Only forward ``dimensions`` when it was explicitly configured.
            if self.config.dimensions is not None:
                kwargs["dimensions"] = self.config.dimensions

            try:
                response = client.embeddings.create(**kwargs)
                all_embeddings.extend(data.embedding for data in response.data)
            except Exception as e:
                msg = "Error calling OpenAI API"
                logger.exception(msg)
                raise RuntimeError(msg) from e

            # Add small delay to avoid rate limiting (skipped after the last batch)
            if start + batch_size < total:
                time.sleep(0.1)

        embeddings_np = np.array(all_embeddings, dtype=np.float32)

        if embeddings_path is not None:
            embeddings_path.parent.mkdir(parents=True, exist_ok=True)
            np.save(embeddings_path, embeddings_np)

        if return_tensors:
            return torch.from_numpy(embeddings_np)
        return embeddings_np

    def similarity(
        self, embeddings1: npt.NDArray[np.float32], embeddings2: npt.NDArray[np.float32]
    ) -> npt.NDArray[np.float32]:
        """Calculate cosine similarity between two sets of embeddings.

        Rows whose norm is zero are treated as having similarity 0 to every
        other row instead of producing NaN.

        Args:
            embeddings1: First set of embeddings (size n).
            embeddings2: Second set of embeddings (size m).

        Returns:
            A numpy array of similarities (size n x m).
        """
        # Normalize embeddings
        norm1 = np.linalg.norm(embeddings1, axis=1, keepdims=True)
        norm2 = np.linalg.norm(embeddings2, axis=1, keepdims=True)

        # A zero vector has no direction; dividing by its zero norm would give
        # NaN rows and a RuntimeWarning. Dividing by 1 keeps the row at zero,
        # so its similarity to anything is 0. Non-zero rows are unaffected.
        norm1[norm1 == 0] = 1.0
        norm2[norm2 == 0] = 1.0

        normalized1 = embeddings1 / norm1
        normalized2 = embeddings2 / norm2

        # Calculate cosine similarity
        similarity_matrix = np.dot(normalized1, normalized2.T)
        return similarity_matrix.astype(np.float32)

0 commit comments

Comments
 (0)