-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathclient.py
More file actions
32 lines (25 loc) · 1.08 KB
/
client.py
File metadata and controls
32 lines (25 loc) · 1.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import chromadb
from sentence_transformers import SentenceTransformer
class Client:
    """Find stored studies that are semantically close to a given study.

    Wraps a vector-store collection together with an embedding model: the
    query text is embedded locally with ``embed_model`` and the resulting
    vector is sent to the collection's ``query`` method.
    """

    def __init__(
        self,
        client: "chromadb.Client",
        embed_model: "SentenceTransformer",
        collection: "chromadb.Collection | str",
    ):
        """Store the collaborators used by ``retrieve_relevant_studies``.

        Args:
            client: Vector-store client. Only used here to look up the
                collection when ``collection`` is given as a name.
            embed_model: Model whose ``encode(text)`` result exposes
                ``tolist()``.
            collection: Either an object exposing ``query(...)`` or the name
                of a collection to fetch through ``client.get_collection``.
        """
        self.client = client
        self.embed_model = embed_model
        # The parameter was historically annotated as ``str`` while the
        # retrieval method calls ``.query`` on the stored attribute, which a
        # plain string cannot satisfy. Resolve names up front so both a name
        # and an already-opened collection object are accepted.
        if isinstance(collection, str):
            collection = client.get_collection(collection)
        self.collection = collection

    def retrieve_relevant_studies(
        self,
        title: str,
        description: str,
        existing_study: str,
        n_results: int = 5,
    ) -> list:
        """Return up to ``n_results`` nearest studies, excluding one id.

        Args:
            title: Title of the study to search with.
            description: Description of the study to search with.
            existing_study: Id to drop from the results (typically the study
                the query was built from, so it does not match itself).
            n_results: Maximum number of studies to return. Values of zero
                or less yield an empty list without querying the collection.

        Returns:
            A list of dicts with the keys ``"id"``, ``"distance"`` and
            ``"document"``, in the order the collection returned them.
        """
        if n_results <= 0:
            return []

        # NOTE(review): the "[SEP]" joiner presumably mirrors how documents
        # were embedded at indexing time -- confirm against the indexer.
        query = f"{title} [SEP] {description}"
        query_embedding = self.embed_model.encode(query).tolist()

        # Ask for one extra hit so that dropping ``existing_study`` can still
        # leave ``n_results`` entries.
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results + 1,
        )

        # A single query embedding was sent, so only the first (index 0)
        # entry of each result field is read.
        hits = zip(
            results["ids"][0],
            results["distances"][0],
            results["documents"][0],
        )
        matches = [
            {"id": study_id, "distance": distance, "document": document}
            for study_id, distance, document in hits
            if study_id != existing_study
        ]
        # If ``existing_study`` was not among the hits, the extra result
        # requested above has to be trimmed off again.
        return matches[:n_results]