Skip to content

Commit ce7cf62

Browse files
committed
refac: dedup
1 parent 0a4dbf7 commit ce7cf62

File tree

1 file changed

+25
-29
lines changed

1 file changed

+25
-29
lines changed

backend/open_webui/retrieval/utils.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import asyncio
77
import requests
8+
import hashlib
89

910
from huggingface_hub import snapshot_download
1011
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
@@ -175,46 +176,41 @@ def merge_get_results(get_results: list[dict]) -> dict:
175176

176177
def merge_and_sort_query_results(
    query_results: list[dict], k: int, reverse: bool = False
) -> dict:
    """Merge multiple vector-DB query results, dedupe, sort, and truncate.

    Each entry in *query_results* is expected to carry single-query result
    lists under ``"distances"``, ``"documents"`` and ``"metadatas"``, each
    wrapped in an outer one-element list (Chroma-style) — TODO confirm this
    shape against the callers. Duplicate string documents are dropped
    (first occurrence wins); merged triples are sorted by distance and
    sliced to the top ``k``.

    :param query_results: per-collection query-result dicts to merge.
    :param k: maximum number of results to keep; ``k <= 0`` yields empty lists.
    :param reverse: sort descending when True (e.g. similarity scores
        rather than distances).
    :return: dict with ``"distances"``, ``"documents"`` and ``"metadatas"``,
        each wrapped in an outer single-element list to mirror the input shape.
    """
    combined = []
    seen_hashes = set()  # hashes of string documents already kept

    for data in query_results:
        distances = data["distances"][0]
        documents = data["documents"][0]
        metadatas = data["metadatas"][0]

        for distance, document, metadata in zip(distances, documents, metadatas):
            if isinstance(document, str):
                # md5 is acceptable here: used only for dedup, not security.
                doc_hash = hashlib.md5(document.encode()).hexdigest()
                if doc_hash in seen_hashes:
                    continue  # duplicate string document — skip it
                seen_hashes.add(doc_hash)
            # BUGFIX: non-string documents were silently discarded by the
            # dedup refactor; keep them (they simply bypass hash-based dedup).
            combined.append((distance, document, metadata))

    # Sort the merged triples by distance.
    combined.sort(key=lambda x: x[0], reverse=reverse)

    # BUGFIX: guard on the *sliced* list, not `combined` — with results
    # present but k <= 0, `zip(*[])` would raise ValueError on unpacking.
    top_k = combined[:k] if k > 0 else []
    sorted_distances, sorted_documents, sorted_metadatas = (
        zip(*top_k) if top_k else ((), (), ())
    )

    # Re-wrap in the single-element outer lists the callers expect.
    return {
        "distances": [list(sorted_distances)],
        "documents": [list(sorted_documents)],
        "metadatas": [list(sorted_metadatas)],
    }

219215
def get_all_items_from_collections(collection_names: list[str]) -> dict:
220216
results = []

0 commit comments

Comments
 (0)