-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathmain.py
More file actions
62 lines (54 loc) · 1.41 KB
/
main.py
File metadata and controls
62 lines (54 loc) · 1.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
from pprint import pprint

import pymongo
from sentence_transformers import SentenceTransformer
# MongoDB connection string, read from the environment so credentials are not
# committed to source. An empty string makes pymongo.MongoClient raise a
# ConfigurationError, so fail early with an actionable message instead
# (checked before the comparatively slow model load below).
mongo_uri = os.environ.get("MONGODB_URI", "")
if not mongo_uri:
    raise SystemExit(
        "Set the MONGODB_URI environment variable to your MongoDB connection string."
    )

# Load the sentence-embedding model used for both the stored documents and the query.
# https://huggingface.co/obrizum/all-MiniLM-L6-v2
model = SentenceTransformer('obrizum/all-MiniLM-L6-v2')

# Client connection object and the target collection: database "eap", collection "vector".
connection = pymongo.MongoClient(mongo_uri)
vector_collection = connection['eap']['vector']
# Sample documents to embed and store; each is a dict with a "name" field.
products = [
    {"name": "Mozzarella"},
    {"name": "Parmesan"},
    {"name": "Cheddar"},
    {"name": "Brie"},
    {"name": "Swiss"},
    {"name": "Gruyere"},
    {"name": "Feta"},
    {"name": "Gouda"},
    {"name": "Provolone"},
    {"name": "Monterey Jack"}
]

# Encode every name in one batched call (SentenceTransformer.encode accepts a
# list of sentences). .tolist() turns the numpy output into plain lists of
# floats, which BSON can store.
embeddings = model.encode([product["name"] for product in products]).tolist()
for product, embedding in zip(products, embeddings):
    product["embedding"] = embedding

# Collection.insert() no longer exists in PyMongo 4; insert_many() writes the
# whole batch in a single call. It also adds an "_id" key to each dict.
# NOTE(review): re-running this script inserts the same products again
# (duplicates); clear the collection first if that matters.
vector_collection.insert_many(products)
# Embed the free-text query with the same model used for the stored documents
# so the vectors are comparable.
query = "cheese"
vector_query = model.encode(query).tolist()

# Atlas Search k-nearest-neighbour query over the "embedding" field, then a
# projection that drops the raw vector and _id and exposes the search score.
# Assumes an Atlas Search index on eap.vector maps "embedding" as a knnVector
# whose dimensions match the model output — TODO confirm.
# NOTE(review): knnBeta is deprecated by Atlas in favour of the $vectorSearch
# stage, which requires a separate vector-search index definition.
pipeline = [
    {
        "$search": {
            "knnBeta": {
                "vector": vector_query,
                "path": "embedding",
                "k": 10
            }
        }
    },
    {
        "$project": {
            "embedding": 0,
            "_id": 0,
            'score': {
                '$meta': 'searchScore'
            }
        }
    }
]

# Run the pipeline against the same collection the products were inserted into.
results = list(vector_collection.aggregate(pipeline))
pprint(results)