11from dataclasses import dataclass
2- from typing import TYPE_CHECKING , List , Optional
2+ from typing import TYPE_CHECKING , Iterable , List , Optional , cast
33
44import numpy as np
55from pydantic import Field , SecretStr
99from unstructured .utils import requires_dependencies
1010
1111if TYPE_CHECKING :
12- from langchain_voyageai import VoyageAIEmbeddings
12+ from voyageai import Client
13+
14+ DEFAULT_VOYAGE_2_BATCH_SIZE = 72
15+ DEFAULT_VOYAGE_3_LITE_BATCH_SIZE = 30
16+ DEFAULT_VOYAGE_3_BATCH_SIZE = 10
17+ DEFAULT_BATCH_SIZE = 7
1318
1419
1520class VoyageAIEmbeddingConfig (EmbeddingConfig ):
1621 api_key : SecretStr
1722 model_name : str
23+ show_progress_bar : bool = False
1824 batch_size : Optional [int ] = Field (default = None )
1925 truncation : Optional [bool ] = Field (default = None )
26+ output_dimension : Optional [int ] = Field (default = None )
2027
2128 @requires_dependencies (
22- ["langchain" , "langchain_voyageai " ],
29+ ["voyageai " ],
2330 extras = "embed-voyageai" ,
2431 )
25- def get_client (self ) -> "VoyageAIEmbeddings" :
26- """Creates a Langchain VoyageAI python client to embed elements."""
27- from langchain_voyageai import VoyageAIEmbeddings
28-
29- return VoyageAIEmbeddings (
30- voyage_api_key = self .api_key ,
31- model = self .model_name ,
32- batch_size = self .batch_size ,
33- truncation = self .truncation ,
32+ def get_client (self ) -> "Client" :
33+ """Creates a VoyageAI python client to embed elements."""
34+ from voyageai import Client
35+
36+ return Client (
37+ api_key = self .api_key .get_secret_value (),
3438 )
3539
40+ def get_batch_size (self ):
41+ if self .batch_size is None :
42+ if self .model_name in ["voyage-2" , "voyage-02" ]:
43+ self .batch_size = DEFAULT_VOYAGE_2_BATCH_SIZE
44+ elif self .model_name == "voyage-3-lite" :
45+ self .batch_size = DEFAULT_VOYAGE_3_LITE_BATCH_SIZE
46+ elif self .model_name == "voyage-3" :
47+ self .batch_size = DEFAULT_VOYAGE_3_BATCH_SIZE
48+ else :
49+ self .batch_size = DEFAULT_BATCH_SIZE
50+ return self .batch_size
51+
3652
3753@dataclass
3854class VoyageAIEmbeddingEncoder (BaseEmbeddingEncoder ):
@@ -56,12 +72,29 @@ def is_unit_vector(self) -> bool:
5672
5773 def embed_documents (self , elements : List [Element ]) -> List [Element ]:
5874 client = self .config .get_client ()
59- embeddings = client .embed_documents ([str (e ) for e in elements ])
75+ embeddings : List [List [float ]] = []
76+
77+ _iter = self ._get_batch_iterator (elements )
78+ for i in _iter :
79+ r = client .embed (
80+ texts = [str (e ) for e in elements [i : i + self .config .get_batch_size ()]],
81+ model = self .config .model_name ,
82+ input_type = "document" ,
83+ truncation = self .config .truncation ,
84+ output_dimension = self .config .output_dimension ,
85+ ).embeddings
86+ embeddings .extend (cast (Iterable [List [float ]], r ))
6087 return self ._add_embeddings_to_elements (elements , embeddings )
6188
6289 def embed_query (self , query : str ) -> List [float ]:
6390 client = self .config .get_client ()
64- return client .embed_query (query )
91+ return client .embed (
92+ texts = [query ],
93+ model = self .config .model_name ,
94+ input_type = "query" ,
95+ truncation = self .config .truncation ,
96+ output_dimension = self .config .output_dimension ,
97+ ).embeddings [0 ]
6598
6699 @staticmethod
67100 def _add_embeddings_to_elements (elements , embeddings ) -> List [Element ]:
@@ -71,3 +104,19 @@ def _add_embeddings_to_elements(elements, embeddings) -> List[Element]:
71104 element .embeddings = embeddings [i ]
72105 elements_w_embedding .append (element )
73106 return elements
107+
108+ def _get_batch_iterator (self , elements : List [Element ]) -> Iterable :
109+ if self .config .show_progress_bar :
110+ try :
111+ from tqdm .auto import tqdm # type: ignore
112+ except ImportError as e :
113+ raise ImportError (
114+ "Must have tqdm installed if `show_progress_bar` is set to True. "
115+ "Please install with `pip install tqdm`."
116+ ) from e
117+
118+ _iter = tqdm (range (0 , len (elements ), self .config .get_batch_size ()))
119+ else :
120+ _iter = range (0 , len (elements ), self .config .get_batch_size ()) # type: ignore
121+
122+ return _iter
0 commit comments