Skip to content

Commit d4fca9d

Browse files
committed
chore: format
1 parent aaaebfa commit d4fca9d

File tree

1 file changed

+81
-102
lines changed

1 file changed

+81
-102
lines changed

backend/open_webui/retrieval/vector/dbs/elasticsearch.py

Lines changed: 81 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,48 @@
11
from elasticsearch import Elasticsearch, BadRequestError
22
from typing import Optional
33
import ssl
4-
from elasticsearch.helpers import bulk,scan
4+
from elasticsearch.helpers import bulk, scan
55
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
66
from open_webui.config import (
77
ELASTICSEARCH_URL,
8-
ELASTICSEARCH_CA_CERTS,
8+
ELASTICSEARCH_CA_CERTS,
99
ELASTICSEARCH_API_KEY,
1010
ELASTICSEARCH_USERNAME,
11-
ELASTICSEARCH_PASSWORD,
11+
ELASTICSEARCH_PASSWORD,
1212
ELASTICSEARCH_CLOUD_ID,
1313
ELASTICSEARCH_INDEX_PREFIX,
1414
SSL_ASSERT_FINGERPRINT,
15-
1615
)
1716

1817

19-
20-
2118
class ElasticsearchClient:
2219
"""
2320
Important:
24-
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
25-
an index for each file but store it as a text field, while seperating to different index
21+
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
22+
an index for each file but store it as a text field, while seperating to different index
2623
baesd on the embedding length.
2724
"""
25+
2826
def __init__(self):
2927
self.index_prefix = ELASTICSEARCH_INDEX_PREFIX
3028
self.client = Elasticsearch(
3129
hosts=[ELASTICSEARCH_URL],
3230
ca_certs=ELASTICSEARCH_CA_CERTS,
3331
api_key=ELASTICSEARCH_API_KEY,
3432
cloud_id=ELASTICSEARCH_CLOUD_ID,
35-
basic_auth=(ELASTICSEARCH_USERNAME,ELASTICSEARCH_PASSWORD) if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD else None,
36-
ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT
37-
33+
basic_auth=(
34+
(ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
35+
if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
36+
else None
37+
),
38+
ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
3839
)
39-
#Status: works
40-
def _get_index_name(self,dimension:int)->str:
40+
41+
# Status: works
42+
def _get_index_name(self, dimension: int) -> str:
4143
return f"{self.index_prefix}_d{str(dimension)}"
42-
43-
#Status: works
44+
45+
# Status: works
4446
def _scan_result_to_get_result(self, result) -> GetResult:
4547
if not result:
4648
return None
@@ -55,7 +57,7 @@ def _scan_result_to_get_result(self, result) -> GetResult:
5557

5658
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
5759

58-
#Status: works
60+
# Status: works
5961
def _result_to_get_result(self, result) -> GetResult:
6062
if not result["hits"]["hits"]:
6163
return None
@@ -70,7 +72,7 @@ def _result_to_get_result(self, result) -> GetResult:
7072

7173
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
7274

73-
#Status: works
75+
# Status: works
7476
def _result_to_search_result(self, result) -> SearchResult:
7577
ids = []
7678
distances = []
@@ -84,19 +86,21 @@ def _result_to_search_result(self, result) -> SearchResult:
8486
metadatas.append(hit["_source"].get("metadata"))
8587

8688
return SearchResult(
87-
ids=[ids], distances=[distances], documents=[documents], metadatas=[metadatas]
89+
ids=[ids],
90+
distances=[distances],
91+
documents=[documents],
92+
metadatas=[metadatas],
8893
)
89-
#Status: works
94+
95+
# Status: works
9096
def _create_index(self, dimension: int):
9197
body = {
9298
"mappings": {
9399
"dynamic_templates": [
94100
{
95-
"strings": {
96-
"match_mapping_type": "string",
97-
"mapping": {
98-
"type": "keyword"
99-
}
101+
"strings": {
102+
"match_mapping_type": "string",
103+
"mapping": {"type": "keyword"},
100104
}
101105
}
102106
],
@@ -111,68 +115,52 @@ def _create_index(self, dimension: int):
111115
},
112116
"text": {"type": "text"},
113117
"metadata": {"type": "object"},
114-
}
118+
},
115119
}
116120
}
117121
self.client.indices.create(index=self._get_index_name(dimension), body=body)
118-
#Status: works
122+
123+
# Status: works
119124

120125
def _create_batches(self, items: list[VectorItem], batch_size=100):
121126
for i in range(0, len(items), batch_size):
122-
yield items[i : min(i + batch_size,len(items))]
127+
yield items[i : min(i + batch_size, len(items))]
123128

124-
#Status: works
125-
def has_collection(self,collection_name) -> bool:
129+
# Status: works
130+
def has_collection(self, collection_name) -> bool:
126131
query_body = {"query": {"bool": {"filter": []}}}
127-
query_body["query"]["bool"]["filter"].append({"term": {"collection": collection_name}})
132+
query_body["query"]["bool"]["filter"].append(
133+
{"term": {"collection": collection_name}}
134+
)
128135

129136
try:
130-
result = self.client.count(
131-
index=f"{self.index_prefix}*",
132-
body=query_body
133-
)
134-
135-
return result.body["count"]>0
137+
result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
138+
139+
return result.body["count"] > 0
136140
except Exception as e:
137141
return None
138-
139142

140-
141143
def delete_collection(self, collection_name: str):
142-
query = {
143-
"query": {
144-
"term": {"collection": collection_name}
145-
}
146-
}
144+
query = {"query": {"term": {"collection": collection_name}}}
147145
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
148-
#Status: works
146+
147+
# Status: works
149148
def search(
150149
self, collection_name: str, vectors: list[list[float]], limit: int
151150
) -> Optional[SearchResult]:
152151
query = {
153152
"size": limit,
154-
"_source": [
155-
"text",
156-
"metadata"
157-
],
153+
"_source": ["text", "metadata"],
158154
"query": {
159155
"script_score": {
160156
"query": {
161-
"bool": {
162-
"filter": [
163-
{
164-
"term": {
165-
"collection": collection_name
166-
}
167-
}
168-
]
169-
}
157+
"bool": {"filter": [{"term": {"collection": collection_name}}]}
170158
},
171159
"script": {
172160
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
173161
"params": {
174162
"vector": vectors[0]
175-
}, # Assuming single query vector
163+
}, # Assuming single query vector
176164
},
177165
}
178166
},
@@ -183,7 +171,8 @@ def search(
183171
)
184172

185173
return self._result_to_search_result(result)
186-
#Status: only tested halfwat
174+
175+
# Status: only tested halfwat
187176
def query(
188177
self, collection_name: str, filter: dict, limit: Optional[int] = None
189178
) -> Optional[GetResult]:
@@ -197,7 +186,9 @@ def query(
197186

198187
for field, value in filter.items():
199188
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
200-
query_body["query"]["bool"]["filter"].append({"term": {"collection": collection_name}})
189+
query_body["query"]["bool"]["filter"].append(
190+
{"term": {"collection": collection_name}}
191+
)
201192
size = limit if limit else 10
202193

203194
try:
@@ -206,59 +197,53 @@ def query(
206197
body=query_body,
207198
size=size,
208199
)
209-
200+
210201
return self._result_to_get_result(result)
211202

212203
except Exception as e:
213204
return None
214-
#Status: works
215-
def _has_index(self,dimension:int):
216-
return self.client.indices.exists(index=self._get_index_name(dimension=dimension))
217205

206+
# Status: works
207+
def _has_index(self, dimension: int):
208+
return self.client.indices.exists(
209+
index=self._get_index_name(dimension=dimension)
210+
)
218211

219212
def get_or_create_index(self, dimension: int):
220213
if not self._has_index(dimension=dimension):
221214
self._create_index(dimension=dimension)
222-
#Status: works
215+
216+
# Status: works
223217
def get(self, collection_name: str) -> Optional[GetResult]:
224218
# Get all the items in the collection.
225219
query = {
226-
"query": {
227-
"bool": {
228-
"filter": [
229-
{
230-
"term": {
231-
"collection": collection_name
232-
}
233-
}
234-
]
235-
}
236-
}, "_source": ["text", "metadata"]}
220+
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
221+
"_source": ["text", "metadata"],
222+
}
237223
results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
238-
224+
239225
return self._scan_result_to_get_result(results)
240226

241-
#Status: works
227+
# Status: works
242228
def insert(self, collection_name: str, items: list[VectorItem]):
243229
if not self._has_index(dimension=len(items[0]["vector"])):
244230
self._create_index(dimension=len(items[0]["vector"]))
245231

246-
247232
for batch in self._create_batches(items):
248233
actions = [
249-
{
250-
"_index":self._get_index_name(dimension=len(items[0]["vector"])),
251-
"_id": item["id"],
252-
"_source": {
253-
"collection": collection_name,
254-
"vector": item["vector"],
255-
"text": item["text"],
256-
"metadata": item["metadata"],
257-
},
258-
}
234+
{
235+
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
236+
"_id": item["id"],
237+
"_source": {
238+
"collection": collection_name,
239+
"vector": item["vector"],
240+
"text": item["text"],
241+
"metadata": item["metadata"],
242+
},
243+
}
259244
for item in batch
260245
]
261-
bulk(self.client,actions)
246+
bulk(self.client, actions)
262247

263248
# Upsert documents using the update API with doc_as_upsert=True.
264249
def upsert(self, collection_name: str, items: list[VectorItem]):
@@ -280,8 +265,7 @@ def upsert(self, collection_name: str, items: list[VectorItem]):
280265
}
281266
for item in batch
282267
]
283-
bulk(self.client,actions)
284-
268+
bulk(self.client, actions)
285269

286270
# Delete specific documents from a collection by filtering on both collection and document IDs.
287271
def delete(
@@ -292,21 +276,16 @@ def delete(
292276
):
293277

294278
query = {
295-
"query": {
296-
"bool": {
297-
"filter": [
298-
{"term": {"collection": collection_name}}
299-
]
300-
}
301-
}
279+
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
302280
}
303-
#logic based on chromaDB
281+
# logic based on chromaDB
304282
if ids:
305283
query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
306284
elif filter:
307285
for field, value in filter.items():
308-
query["query"]["bool"]["filter"].append({"term": {f"metadata.{field}": value}})
309-
286+
query["query"]["bool"]["filter"].append(
287+
{"term": {f"metadata.{field}": value}}
288+
)
310289

311290
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
312291

0 commit comments

Comments
 (0)