Skip to content

Commit 6a4fd30

Browse files
Add option to import OpenAIEmbeddings from langchain_openai (#328)
`OpenAIEmbeddings` was removed from the main langchain repo and needs to be imported from `langchain_openai` for newer langchain versions.
1 parent db9eae1 commit 6a4fd30

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

apis/python/src/tiledb/vector_search/embeddings/langchain_embedding.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
# class LangChainEmbedding(ObjectEmbedding):
99
class LangChainEmbedding:
1010
"""
11-
Embedding functions from `langchain.embeddings` package.
11+
Embedding functions from Langchain.
12+
13+
This attempts to import the embedding_class from the following modules:
14+
- langchain_openai
15+
- langchain.embeddings
1216
"""
1317

1418
def __init__(
@@ -37,9 +41,14 @@ def vector_type(self) -> np.dtype:
3741
def load(self) -> None:
3842
import importlib
3943

40-
embeddings_module = importlib.import_module("langchain.embeddings")
41-
embedding_class_ = getattr(embeddings_module, self.embedding_class)
42-
self.embedding = embedding_class_(**self.embedding_kwargs)
44+
try:
45+
embeddings_module = importlib.import_module("langchain_openai")
46+
embedding_class_ = getattr(embeddings_module, self.embedding_class)
47+
self.embedding = embedding_class_(**self.embedding_kwargs)
48+
except ImportError:
49+
embeddings_module = importlib.import_module("langchain.embeddings")
50+
embedding_class_ = getattr(embeddings_module, self.embedding_class)
51+
self.embedding = embedding_class_(**self.embedding_kwargs)
4352

4453
def embed(self, objects: OrderedDict, metadata: OrderedDict) -> np.ndarray:
4554
return np.array(

0 commit comments

Comments
 (0)