2
2
import json
3
3
import os
4
4
import sys
5
+ import logging
5
6
from datetime import date
6
7
from sentence_transformers import SentenceTransformer
7
8
import numpy as np
8
9
from sklearn .metrics .pairwise import cosine_similarity
9
- from embeddings import load_model
10
10
import traceback
11
11
12
12
openai .api_base = os .getenv ("OPENAI_BASE_URL" )
15
15
class RAGSystem :
16
16
def __init__ (self , knowledge_base_path = './data/knowledge_base.json' ):
17
17
self .knowledge_base_path = knowledge_base_path
18
+
18
19
self .knowledge_base = self .load_knowledge_base ()
19
- self .model = load_model ()
20
- self .doc_embeddings = self .embed_knowledge_base ()
20
+ self .model = SentenceTransformer ("all-MiniLM-L6-v2" )
21
+
22
+ # load existing embeddings if available
23
+ logging .info ("Embedding knowledge base..." )
24
+ if os .path .exists ('./data/doc_embeddings.npy' ):
25
+ self .doc_embeddings = np .load ('./data/doc_embeddings.npy' )
26
+ logging .info ("Loaded existing document embeddings from disk." )
27
+ else :
28
+ logging .info ("No existing document embeddings found, creating new embeddings." )
29
+ self .doc_embeddings = self .embed_knowledge_base ()
30
+ # cache doc_embeddings to disk
31
+ np .save ('./data/doc_embeddings.npy' , self .doc_embeddings .cpu ().numpy ())
32
+ logging .info ("Knowledge base embeddings created" )
21
33
self .conversation_history = []
22
34
23
35
def load_knowledge_base (self ):
@@ -38,7 +50,7 @@ def get_query_embedding(self, query):
38
50
return query_embedding
39
51
40
52
def get_doc_embeddings (self ):
41
- return self .doc_embeddings . cpu ()
53
+ return self .doc_embeddings
42
54
43
55
def compute_document_scores (self , query_embedding , doc_embeddings , high_match_threshold ):
44
56
text_similarities = cosine_similarity (query_embedding , doc_embeddings )[0 ]
@@ -188,3 +200,11 @@ def get_context(self, retrieved_docs):
188
200
for doc in retrieved_docs :
189
201
retrieved_text .append (f"{ doc ['about' ]} . { doc ['text' ]} " )
190
202
return "\n \n " .join (retrieved_text )
203
+
204
+ if __name__ == "__main__" :
205
+ logging .basicConfig (
206
+ level = logging .INFO ,
207
+ format = "%(asctime)s %(levelname)s %(message)s" ,
208
+ datefmt = "%Y-%m-%d %H:%M:%S"
209
+ )
210
+ RAGSystem ()
0 commit comments