11from elasticsearch import Elasticsearch , BadRequestError
22from typing import Optional
33import ssl
4- from elasticsearch .helpers import bulk ,scan
4+ from elasticsearch .helpers import bulk , scan
55from open_webui .retrieval .vector .main import VectorItem , SearchResult , GetResult
66from open_webui .config import (
77 ELASTICSEARCH_URL ,
8- ELASTICSEARCH_CA_CERTS ,
8+ ELASTICSEARCH_CA_CERTS ,
99 ELASTICSEARCH_API_KEY ,
1010 ELASTICSEARCH_USERNAME ,
11- ELASTICSEARCH_PASSWORD ,
11+ ELASTICSEARCH_PASSWORD ,
1212 ELASTICSEARCH_CLOUD_ID ,
1313 ELASTICSEARCH_INDEX_PREFIX ,
1414 SSL_ASSERT_FINGERPRINT ,
15-
1615)
1716
1817
19-
20-
2118class ElasticsearchClient :
2219 """
2320 Important:
24- in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
25- an index for each file but store it as a text field, while seperating to different index
21+ in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
22+ an index for each file but store it as a text field, while seperating to different index
2623 baesd on the embedding length.
2724 """
25+
2826 def __init__ (self ):
2927 self .index_prefix = ELASTICSEARCH_INDEX_PREFIX
3028 self .client = Elasticsearch (
3129 hosts = [ELASTICSEARCH_URL ],
3230 ca_certs = ELASTICSEARCH_CA_CERTS ,
3331 api_key = ELASTICSEARCH_API_KEY ,
3432 cloud_id = ELASTICSEARCH_CLOUD_ID ,
35- basic_auth = (ELASTICSEARCH_USERNAME ,ELASTICSEARCH_PASSWORD ) if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD else None ,
36- ssl_assert_fingerprint = SSL_ASSERT_FINGERPRINT
37-
33+ basic_auth = (
34+ (ELASTICSEARCH_USERNAME , ELASTICSEARCH_PASSWORD )
35+ if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
36+ else None
37+ ),
38+ ssl_assert_fingerprint = SSL_ASSERT_FINGERPRINT ,
3839 )
39- #Status: works
40- def _get_index_name (self ,dimension :int )-> str :
40+
41+ # Status: works
42+ def _get_index_name (self , dimension : int ) -> str :
4143 return f"{ self .index_prefix } _d{ str (dimension )} "
42-
43- #Status: works
44+
45+ # Status: works
4446 def _scan_result_to_get_result (self , result ) -> GetResult :
4547 if not result :
4648 return None
@@ -55,7 +57,7 @@ def _scan_result_to_get_result(self, result) -> GetResult:
5557
5658 return GetResult (ids = [ids ], documents = [documents ], metadatas = [metadatas ])
5759
58- #Status: works
60+ # Status: works
5961 def _result_to_get_result (self , result ) -> GetResult :
6062 if not result ["hits" ]["hits" ]:
6163 return None
@@ -70,7 +72,7 @@ def _result_to_get_result(self, result) -> GetResult:
7072
7173 return GetResult (ids = [ids ], documents = [documents ], metadatas = [metadatas ])
7274
73- #Status: works
75+ # Status: works
7476 def _result_to_search_result (self , result ) -> SearchResult :
7577 ids = []
7678 distances = []
@@ -84,19 +86,21 @@ def _result_to_search_result(self, result) -> SearchResult:
8486 metadatas .append (hit ["_source" ].get ("metadata" ))
8587
8688 return SearchResult (
87- ids = [ids ], distances = [distances ], documents = [documents ], metadatas = [metadatas ]
89+ ids = [ids ],
90+ distances = [distances ],
91+ documents = [documents ],
92+ metadatas = [metadatas ],
8893 )
89- #Status: works
94+
95+ # Status: works
9096 def _create_index (self , dimension : int ):
9197 body = {
9298 "mappings" : {
9399 "dynamic_templates" : [
94100 {
95- "strings" : {
96- "match_mapping_type" : "string" ,
97- "mapping" : {
98- "type" : "keyword"
99- }
101+ "strings" : {
102+ "match_mapping_type" : "string" ,
103+ "mapping" : {"type" : "keyword" },
100104 }
101105 }
102106 ],
@@ -111,68 +115,52 @@ def _create_index(self, dimension: int):
111115 },
112116 "text" : {"type" : "text" },
113117 "metadata" : {"type" : "object" },
114- }
118+ },
115119 }
116120 }
117121 self .client .indices .create (index = self ._get_index_name (dimension ), body = body )
118- #Status: works
122+
123+ # Status: works
119124
120125 def _create_batches (self , items : list [VectorItem ], batch_size = 100 ):
121126 for i in range (0 , len (items ), batch_size ):
122- yield items [i : min (i + batch_size ,len (items ))]
127+ yield items [i : min (i + batch_size , len (items ))]
123128
124- #Status: works
125- def has_collection (self ,collection_name ) -> bool :
129+ # Status: works
130+ def has_collection (self , collection_name ) -> bool :
126131 query_body = {"query" : {"bool" : {"filter" : []}}}
127- query_body ["query" ]["bool" ]["filter" ].append ({"term" : {"collection" : collection_name }})
132+ query_body ["query" ]["bool" ]["filter" ].append (
133+ {"term" : {"collection" : collection_name }}
134+ )
128135
129136 try :
130- result = self .client .count (
131- index = f"{ self .index_prefix } *" ,
132- body = query_body
133- )
134-
135- return result .body ["count" ]> 0
137+ result = self .client .count (index = f"{ self .index_prefix } *" , body = query_body )
138+
139+ return result .body ["count" ] > 0
136140 except Exception as e :
137141 return None
138-
139142
140-
141143 def delete_collection (self , collection_name : str ):
142- query = {
143- "query" : {
144- "term" : {"collection" : collection_name }
145- }
146- }
144+ query = {"query" : {"term" : {"collection" : collection_name }}}
147145 self .client .delete_by_query (index = f"{ self .index_prefix } *" , body = query )
148- #Status: works
146+
147+ # Status: works
149148 def search (
150149 self , collection_name : str , vectors : list [list [float ]], limit : int
151150 ) -> Optional [SearchResult ]:
152151 query = {
153152 "size" : limit ,
154- "_source" : [
155- "text" ,
156- "metadata"
157- ],
153+ "_source" : ["text" , "metadata" ],
158154 "query" : {
159155 "script_score" : {
160156 "query" : {
161- "bool" : {
162- "filter" : [
163- {
164- "term" : {
165- "collection" : collection_name
166- }
167- }
168- ]
169- }
157+ "bool" : {"filter" : [{"term" : {"collection" : collection_name }}]}
170158 },
171159 "script" : {
172160 "source" : "cosineSimilarity(params.vector, 'vector') + 1.0" ,
173161 "params" : {
174162 "vector" : vectors [0 ]
175- }, # Assuming single query vector
163+ }, # Assuming single query vector
176164 },
177165 }
178166 },
@@ -183,7 +171,8 @@ def search(
183171 )
184172
185173 return self ._result_to_search_result (result )
186- #Status: only tested halfwat
174+
175+ # Status: only tested halfwat
187176 def query (
188177 self , collection_name : str , filter : dict , limit : Optional [int ] = None
189178 ) -> Optional [GetResult ]:
@@ -197,7 +186,9 @@ def query(
197186
198187 for field , value in filter .items ():
199188 query_body ["query" ]["bool" ]["filter" ].append ({"term" : {field : value }})
200- query_body ["query" ]["bool" ]["filter" ].append ({"term" : {"collection" : collection_name }})
189+ query_body ["query" ]["bool" ]["filter" ].append (
190+ {"term" : {"collection" : collection_name }}
191+ )
201192 size = limit if limit else 10
202193
203194 try :
@@ -206,59 +197,53 @@ def query(
206197 body = query_body ,
207198 size = size ,
208199 )
209-
200+
210201 return self ._result_to_get_result (result )
211202
212203 except Exception as e :
213204 return None
214- #Status: works
215- def _has_index (self ,dimension :int ):
216- return self .client .indices .exists (index = self ._get_index_name (dimension = dimension ))
217205
206+ # Status: works
207+ def _has_index (self , dimension : int ):
208+ return self .client .indices .exists (
209+ index = self ._get_index_name (dimension = dimension )
210+ )
218211
219212 def get_or_create_index (self , dimension : int ):
220213 if not self ._has_index (dimension = dimension ):
221214 self ._create_index (dimension = dimension )
222- #Status: works
215+
216+ # Status: works
223217 def get (self , collection_name : str ) -> Optional [GetResult ]:
224218 # Get all the items in the collection.
225219 query = {
226- "query" : {
227- "bool" : {
228- "filter" : [
229- {
230- "term" : {
231- "collection" : collection_name
232- }
233- }
234- ]
235- }
236- }, "_source" : ["text" , "metadata" ]}
220+ "query" : {"bool" : {"filter" : [{"term" : {"collection" : collection_name }}]}},
221+ "_source" : ["text" , "metadata" ],
222+ }
237223 results = list (scan (self .client , index = f"{ self .index_prefix } *" , query = query ))
238-
224+
239225 return self ._scan_result_to_get_result (results )
240226
241- #Status: works
227+ # Status: works
242228 def insert (self , collection_name : str , items : list [VectorItem ]):
243229 if not self ._has_index (dimension = len (items [0 ]["vector" ])):
244230 self ._create_index (dimension = len (items [0 ]["vector" ]))
245231
246-
247232 for batch in self ._create_batches (items ):
248233 actions = [
249- {
250- "_index" :self ._get_index_name (dimension = len (items [0 ]["vector" ])),
251- "_id" : item ["id" ],
252- "_source" : {
253- "collection" : collection_name ,
254- "vector" : item ["vector" ],
255- "text" : item ["text" ],
256- "metadata" : item ["metadata" ],
257- },
258- }
234+ {
235+ "_index" : self ._get_index_name (dimension = len (items [0 ]["vector" ])),
236+ "_id" : item ["id" ],
237+ "_source" : {
238+ "collection" : collection_name ,
239+ "vector" : item ["vector" ],
240+ "text" : item ["text" ],
241+ "metadata" : item ["metadata" ],
242+ },
243+ }
259244 for item in batch
260245 ]
261- bulk (self .client ,actions )
246+ bulk (self .client , actions )
262247
263248 # Upsert documents using the update API with doc_as_upsert=True.
264249 def upsert (self , collection_name : str , items : list [VectorItem ]):
@@ -280,8 +265,7 @@ def upsert(self, collection_name: str, items: list[VectorItem]):
280265 }
281266 for item in batch
282267 ]
283- bulk (self .client ,actions )
284-
268+ bulk (self .client , actions )
285269
286270 # Delete specific documents from a collection by filtering on both collection and document IDs.
287271 def delete (
@@ -292,21 +276,16 @@ def delete(
292276 ):
293277
294278 query = {
295- "query" : {
296- "bool" : {
297- "filter" : [
298- {"term" : {"collection" : collection_name }}
299- ]
300- }
301- }
279+ "query" : {"bool" : {"filter" : [{"term" : {"collection" : collection_name }}]}}
302280 }
303- # logic based on chromaDB
281+ # logic based on chromaDB
304282 if ids :
305283 query ["query" ]["bool" ]["filter" ].append ({"terms" : {"_id" : ids }})
306284 elif filter :
307285 for field , value in filter .items ():
308- query ["query" ]["bool" ]["filter" ].append ({"term" : {f"metadata.{ field } " : value }})
309-
286+ query ["query" ]["bool" ]["filter" ].append (
287+ {"term" : {f"metadata.{ field } " : value }}
288+ )
310289
311290 self .client .delete_by_query (index = f"{ self .index_prefix } *" , body = query )
312291
0 commit comments