88from pathlib import Path
99from typing import Union , List , Optional , Dict , Generator
1010from tqdm .auto import tqdm
11+ import warnings
1112
1213try :
1314 import faiss
@@ -37,7 +38,8 @@ class FAISSDocumentStore(SQLDocumentStore):
3738 def __init__ (
3839 self ,
3940 sql_url : str = "sqlite:///faiss_document_store.db" ,
40- vector_dim : int = 768 ,
41+ vector_dim : int = None ,
42+ embedding_dim : int = 768 ,
4143 faiss_index_factory_str : str = "Flat" ,
4244 faiss_index : Optional ["faiss.swigfaiss.Index" ] = None ,
4345 return_embedding : bool = False ,
@@ -53,7 +55,8 @@ def __init__(
5355 """
5456 :param sql_url: SQL connection URL for database. It defaults to local file based SQLite DB. For large scale
5557 deployment, Postgres is recommended.
56- :param vector_dim: the embedding vector size.
58+ :param vector_dim: Deprecated. Use embedding_dim instead.
59+ :param embedding_dim: The embedding vector size. Default: 768.
5760 :param faiss_index_factory_str: Create a new FAISS index of the specified type.
5861 The type is determined from the given string following the conventions
5962 of the original FAISS index factory.
@@ -75,7 +78,7 @@ def __init__(
7578 :param index: Name of index in document store to use.
7679 :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
7780 more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model.
78- In both cases, the returned values in Document.score are normalized to be in range [0,1]:
81+ In both cases, the returned values in Document.score are normalized to be in range [0,1]:
7982 For `dot_product`: expit(np.asarray(raw_score / 100))
8083 FOr `cosine`: (raw_score + 1) / 2
8184 :param embedding_field: Name of field containing an embedding vector.
@@ -89,7 +92,7 @@ def __init__(
8992 exists.
9093 :param faiss_index_path: Stored FAISS index file. Can be created via calling `save()`.
9194 If specified no other params besides faiss_config_path must be specified.
92- :param faiss_config_path: Stored FAISS initial configuration parameters.
95+ :param faiss_config_path: Stored FAISS initial configuration parameters.
9396 Can be created via calling `save()`
9497 """
9598 # special case if we want to load an existing index from disk
@@ -103,14 +106,15 @@ def __init__(
103106
104107 # save init parameters to enable export of component config as YAML
105108 self .set_config (
106- sql_url = sql_url ,
107- vector_dim = vector_dim ,
109+ sql_url = sql_url ,
110+ vector_dim = vector_dim ,
111+ embedding_dim = embedding_dim ,
108112 faiss_index_factory_str = faiss_index_factory_str ,
109113 return_embedding = return_embedding ,
110- duplicate_documents = duplicate_documents ,
111- index = index ,
114+ duplicate_documents = duplicate_documents ,
115+ index = index ,
112116 similarity = similarity ,
113- embedding_field = embedding_field ,
117+ embedding_field = embedding_field ,
114118 progress_bar = progress_bar
115119 )
116120
@@ -124,14 +128,20 @@ def __init__(
124128 raise ValueError ("The FAISS document store can currently only support dot_product, cosine and l2 similarity. "
125129 "Please set similarity to one of the above." )
126130
127- self .vector_dim = vector_dim
131+ if vector_dim is not None :
132+ warnings .warn ("The 'vector_dim' parameter is deprecated, "
133+ "use 'embedding_dim' instead." , DeprecationWarning , 2 )
134+ self .embedding_dim = vector_dim
135+ else :
136+ self .embedding_dim = embedding_dim
137+
128138 self .faiss_index_factory_str = faiss_index_factory_str
129139 self .faiss_indexes : Dict [str , faiss .swigfaiss .Index ] = {}
130140 if faiss_index :
131141 self .faiss_indexes [index ] = faiss_index
132142 else :
133143 self .faiss_indexes [index ] = self ._create_new_index (
134- vector_dim = self .vector_dim ,
144+ embedding_dim = self .embedding_dim ,
135145 index_factory = faiss_index_factory_str ,
136146 metric_type = self .metric_type ,
137147 ** kwargs
@@ -158,7 +168,7 @@ def _validate_params_load_from_disk(self, sig: Signature, locals: dict, kwargs:
158168 if param .name not in allowed_params and param .default != locals [param .name ]:
159169 invalid_param_set = True
160170 break
161-
171+
162172 if invalid_param_set or len (kwargs ) > 0 :
163173 raise ValueError ("if faiss_index_path is passed no other params besides faiss_config_path are allowed." )
164174
@@ -172,20 +182,20 @@ def _validate_index_sync(self):
172182 "configuration file correctly points to the same database that "
173183 "was used when creating the original index." )
174184
175- def _create_new_index (self , vector_dim : int , metric_type , index_factory : str = "Flat" , ** kwargs ):
185+ def _create_new_index (self , embedding_dim : int , metric_type , index_factory : str = "Flat" , ** kwargs ):
176186 if index_factory == "HNSW" :
177187 # faiss index factory doesn't give the same results for HNSW IP, therefore direct init.
178188 # defaults here are similar to DPR codebase (good accuracy, but very high RAM consumption)
179189 n_links = kwargs .get ("n_links" , 64 )
180- index = faiss .IndexHNSWFlat (vector_dim , n_links , metric_type )
190+ index = faiss .IndexHNSWFlat (embedding_dim , n_links , metric_type )
181191 index .hnsw .efSearch = kwargs .get ("efSearch" , 20 )#20
182192 index .hnsw .efConstruction = kwargs .get ("efConstruction" , 80 )#80
183193 if "ivf" in index_factory .lower (): # enable reconstruction of vectors for inverted index
184194 self .faiss_indexes [index ].set_direct_map_type (faiss .DirectMap .Hashtable )
185195
186196 logger .info (f"HNSW params: n_links: { n_links } , efSearch: { index .hnsw .efSearch } , efConstruction: { index .hnsw .efConstruction } " )
187197 else :
188- index = faiss .index_factory (vector_dim , index_factory , metric_type )
198+ index = faiss .index_factory (embedding_dim , index_factory , metric_type )
189199 return index
190200
191201 def write_documents (self , documents : Union [List [dict ], List [Document ]], index : Optional [str ] = None ,
@@ -217,7 +227,7 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
217227
218228 if not self .faiss_indexes .get (index ):
219229 self .faiss_indexes [index ] = self ._create_new_index (
220- vector_dim = self .vector_dim ,
230+ embedding_dim = self .embedding_dim ,
221231 index_factory = self .faiss_index_factory_str ,
222232 metric_type = faiss .METRIC_INNER_PRODUCT ,
223233 )
@@ -544,7 +554,7 @@ def save(self, index_path: Union[str, Path], config_path: Optional[Union[str, Pa
544554 :param config_path: Path to save the initial configuration parameters to.
545555 Defaults to the same as the file path, save the extension (.json).
546556 This file contains all the parameters passed to FAISSDocumentStore()
547- at creation time (for example the SQL path, vector_dim , etc), and will be
557+ at creation time (for example the SQL path, embedding_dim , etc), and will be
548558 used by the `load` method to restore the index with the appropriate configuration.
549559 :return: None
550560 """
@@ -574,7 +584,7 @@ def _load_init_params_from_config(self, index_path: Union[str, Path], config_pat
574584
575585 # Add other init params to override the ones defined in the init params file
576586 init_params ["faiss_index" ] = faiss_index
577- init_params ["vector_dim " ] = faiss_index .d
587+ init_params ["embedding_dim " ] = faiss_index .d
578588
579589 return init_params
580590
0 commit comments