Skip to content

Commit ae0c5f0

Browse files
Add support for LangChain embeddings in Object API (#280)
This adds support for LangChain embeddings in the Object API. It also adds support for passing environment variables to the ingestion execution.
1 parent 86b55fd commit ae0c5f0

File tree

5 files changed

+91
-2
lines changed

5 files changed

+91
-2
lines changed

apis/python/examples/object_api/text_search_documents.ipynb

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"import tiledb\n",
2323
"from tiledb.vector_search.object_api import object_index\n",
2424
"from tiledb.vector_search.object_readers import DirectoryTextReader\n",
25-
"from tiledb.vector_search.embeddings import SentenceTransformersEmbedding\n",
25+
"from tiledb.vector_search.embeddings import SentenceTransformersEmbedding, LangChainEmbedding\n",
2626
"\n",
2727
"dataset = \"documents\"\n",
2828
"base_uri = f\"/tmp/{dataset}_demo\"\n",
@@ -87,6 +87,20 @@
8787
" text_splitter_kwargs={\"chunk_size\":1000}\n",
8888
" )\n",
8989
"embedding = SentenceTransformersEmbedding(model_name_or_path='BAAI/bge-small-en-v1.5', dimensions=384)\n",
90+
"# embedding = LangChainEmbedding(\n",
91+
"# dimensions=384, \n",
92+
"# embedding_class=\"HuggingFaceEmbeddings\", \n",
93+
"# embedding_kwargs={\n",
94+
"# \"model_name\": 'BAAI/bge-small-en-v1.5', \n",
95+
"# }\n",
96+
"# )\n",
97+
"# embedding = LangChainEmbedding(\n",
98+
"# dimensions=1536, \n",
99+
"# embedding_class=\"OpenAIEmbeddings\", \n",
100+
"# embedding_kwargs={\n",
101+
"# \"model\": 'text-embedding-ada-002', \n",
102+
"# }\n",
103+
"# )\n",
90104
"index = object_index.create(\n",
91105
" uri=index_uri,\n",
92106
" index_type=\"IVF_FLAT\",\n",

apis/python/src/tiledb/vector_search/embeddings/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .image_resnetv2_embedding import ImageResNetV2Embedding
2+
from .langchain_embedding import LangChainEmbedding
23
from .object_embedding import ObjectEmbedding
34
from .random_embedding import RandomEmbedding
45
from .sentence_transformers_embedding import SentenceTransformersEmbedding
@@ -10,4 +11,5 @@
1011
"ImageResNetV2Embedding",
1112
"RandomEmbedding",
1213
"SentenceTransformersEmbedding",
14+
"LangChainEmbedding",
1315
]
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from typing import Dict, Optional, OrderedDict
2+
3+
import numpy as np
4+
5+
# from tiledb.vector_search.embeddings import ObjectEmbedding
6+
7+
8+
# class LangChainEmbedding(ObjectEmbedding):
9+
class LangChainEmbedding:
    """
    Embedding functions from the `langchain.embeddings` package.

    The embedding implementation is selected by class name and instantiated
    lazily in `load()`, so constructing this object does not import
    `langchain`.
    """

    def __init__(
        self,
        dimensions: int,
        embedding_class: str = "OpenAIEmbeddings",
        embedding_kwargs: Optional[Dict] = None,
    ):
        """
        Store the configuration needed to build the embedding in `load()`.

        Parameters
        ----------
        dimensions: int
            Number of dimensions reported by `dimensions()`.
        embedding_class: str
            Name of the attribute to look up in the `langchain.embeddings`
            module.
        embedding_kwargs: Optional[Dict]
            Keyword arguments passed to the embedding class constructor.
            `None` means the class is constructed with no keyword arguments.
        """
        self.dim_num = dimensions
        self.embedding_class = embedding_class
        # Kept exactly as passed (possibly None) so that `init_kwargs()`
        # round-trips the constructor arguments unchanged.
        self.embedding_kwargs = embedding_kwargs

    def init_kwargs(self) -> Dict:
        """Return the keyword arguments that recreate this object via `__init__`."""
        return {
            "dimensions": self.dim_num,
            "embedding_class": self.embedding_class,
            "embedding_kwargs": self.embedding_kwargs,
        }

    def dimensions(self) -> int:
        """Return the number of dimensions given at construction time."""
        return self.dim_num

    def vector_type(self) -> np.dtype:
        """Return the element type of the vectors produced by `embed()`."""
        return np.float32

    def load(self) -> None:
        """
        Import `langchain.embeddings` and instantiate the configured class.

        The instance is stored in `self.embedding`. Raises whatever
        `importlib.import_module` or `getattr` raise when the module or the
        class name cannot be resolved.
        """
        import importlib

        embeddings_module = importlib.import_module("langchain.embeddings")
        embedding_class_ = getattr(embeddings_module, self.embedding_class)
        # `embedding_kwargs` defaults to None; unpacking None with `**` raises
        # TypeError, so treat None as "no keyword arguments".
        self.embedding = embedding_class_(**(self.embedding_kwargs or {}))

    def embed(self, objects: OrderedDict, metadata: OrderedDict) -> np.ndarray:
        """
        Embed `objects["text"]` with the loaded embedding.

        `load()` must have been called first, since this reads
        `self.embedding`. `metadata` is accepted but not used here.
        Returns the result of `embed_documents` converted to a float32 array.
        """
        return np.array(
            self.embedding.embed_documents(objects["text"]), dtype=np.float32
        )

apis/python/src/tiledb/vector_search/object_api/embeddings_ingestion.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def ingest_embeddings_with_driver(
2525
vector_indexing_mode: Mode = Mode.LOCAL,
2626
config: Optional[Mapping[str, Any]] = None,
2727
namespace: Optional[str] = None,
28+
environment_variables: Dict = {},
2829
**kwargs,
2930
):
3031
def ingest_embeddings(
@@ -45,6 +46,7 @@ def ingest_embeddings(
4546
vector_indexing_mode: Mode = Mode.LOCAL,
4647
config: Optional[Mapping[str, Any]] = None,
4748
namespace: Optional[str] = None,
49+
environment_variables: Dict = {},
4850
**kwargs,
4951
):
5052
import tiledb
@@ -133,6 +135,7 @@ def compute_embeddings_udf(
133135
trace_id: Optional[str] = None,
134136
config: Optional[Mapping[str, Any]] = None,
135137
extra_worker_modules: Optional[List[str]] = None,
138+
environment_variables: Dict = {},
136139
):
137140
def install_extra_driver_modules():
138141
if extra_worker_modules is not None:
@@ -150,6 +153,8 @@ def install_extra_driver_modules():
150153

151154
install_extra_driver_modules()
152155

156+
import os
157+
153158
import numpy as np
154159

155160
import tiledb
@@ -187,6 +192,8 @@ def instantiate_object(code, class_name, **kwargs):
187192
class_name=object_embedding_class_name,
188193
**object_embedding_kwargs,
189194
)
195+
for var, val in environment_variables.items():
196+
os.environ[var] = val
190197
with tiledb.scope_ctx(ctx_or_config=config):
191198
logger.debug("Loading model...")
192199
object_embedding.load()
@@ -263,6 +270,7 @@ def create_dag(
263270
trace_id: Optional[str] = None,
264271
embeddings_generation_mode: Mode = Mode.LOCAL,
265272
config: Optional[Mapping[str, Any]] = None,
273+
environment_variables: Dict = {},
266274
namespace: Optional[str] = None,
267275
) -> dag.DAG:
268276
if embeddings_generation_mode == Mode.BATCH:
@@ -327,6 +335,7 @@ def create_dag(
327335
verbose=verbose,
328336
trace_id=trace_id,
329337
config=config,
338+
environment_variables=environment_variables,
330339
extra_worker_modules=extra_udf_worker_modules,
331340
name="generate-embeddings-" + str(task_id),
332341
resources=worker_resources,
@@ -340,6 +349,10 @@ def create_dag(
340349
# End internal function definitions
341350
# --------------------------------------------------------------------
342351

352+
import os
353+
354+
for var, val in environment_variables.items():
355+
os.environ[var] = val
343356
with tiledb.scope_ctx(ctx_or_config=config):
344357
logger = setup(config, verbose)
345358
logger.debug("Generating embeddings")
@@ -348,7 +361,11 @@ def create_dag(
348361

349362
from tiledb.vector_search.object_api import object_index
350363

351-
ob_index = object_index.ObjectIndex(object_index_uri, config=config)
364+
ob_index = object_index.ObjectIndex(
365+
object_index_uri,
366+
config=config,
367+
environment_variables=environment_variables,
368+
)
352369
partitions = ob_index.object_reader.get_partitions(**kwargs)
353370
object_partitions = len(partitions)
354371
object_partitions_per_worker = 1
@@ -394,6 +411,7 @@ def create_dag(
394411
trace_id=trace_id,
395412
embeddings_generation_mode=embeddings_generation_mode,
396413
config=config,
414+
environment_variables=environment_variables,
397415
namespace=namespace,
398416
)
399417
logger.debug("Submitting ingestion graph")
@@ -465,6 +483,7 @@ def submit_local(d, func, *args, **kwargs):
465483
embeddings_generation_mode=embeddings_generation_mode,
466484
vector_indexing_mode=vector_indexing_mode,
467485
config=config,
486+
environment_variables=environment_variables,
468487
**kwargs,
469488
name="ingest-embeddings-driver",
470489
resources={"cpu": "1", "memory": "4Gi"}

apis/python/src/tiledb/vector_search/object_api/object_index.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,13 @@ def __init__(
2828
timestamp=None,
2929
load_embedding: bool = True,
3030
load_metadata_in_memory: bool = True,
31+
environment_variables: Dict = {},
3132
**kwargs,
3233
):
34+
import os
35+
36+
for var, val in environment_variables.items():
37+
os.environ[var] = val
3338
with tiledb.scope_ctx(ctx_or_config=config):
3439
self.uri = uri
3540
self.config = config
@@ -235,6 +240,7 @@ def update_index(
235240
vector_indexing_mode: Mode = Mode.LOCAL,
236241
config: Optional[Mapping[str, Any]] = None,
237242
namespace: Optional[str] = None,
243+
environment_variables: Dict = {},
238244
**kwargs,
239245
):
240246
embeddings_array_name = storage_formats[self.index.storage_version][
@@ -268,6 +274,7 @@ def update_index(
268274
embeddings_generation_mode=embeddings_generation_mode,
269275
vector_indexing_mode=vector_indexing_mode,
270276
config=config,
277+
environment_variables=environment_variables,
271278
**kwargs,
272279
)
273280

0 commit comments

Comments
 (0)