@@ -30,17 +30,14 @@ def embed_knowledge_base(self):
30
30
def normalize_query(self, query):
    """Canonicalize a raw user query for embedding lookup.

    Lowercases the text and trims surrounding whitespace so that
    queries differing only in case or padding map to the same key.

    Args:
        query: Raw query string from the caller.

    Returns:
        The normalized (lowercased, stripped) query string.
    """
    cleaned = query.strip()
    return cleaned.lower()
33
def get_query_embedding(self, query):
    """Embed a user query with the sentence-transformer model.

    The query is normalized first (see ``normalize_query``) and the
    resulting embedding is moved to the CPU so downstream similarity
    math (e.g. scikit-learn cosine_similarity) can consume it.

    Args:
        query: Raw query string.

    Returns:
        A CPU tensor of shape (1, dim) — the embedding of the
        normalized query.
    """
    # Normalize, encode as a single-item batch, then pull to CPU.
    return self.model.encode(
        [self.normalize_query(query)], convert_to_tensor=True
    ).cpu()
def get_doc_embeddings(self):
    """Return the precomputed knowledge-base embeddings on the CPU.

    Moves ``self.doc_embeddings`` (built by ``embed_knowledge_base``)
    off any accelerator device so they can be fed to CPU-side
    similarity routines.

    Returns:
        The document-embedding tensor, on the CPU.
    """
    embeddings = self.doc_embeddings
    return embeddings.cpu()
def compute_document_scores (self , query_embedding , doc_embeddings , high_match_threshold ):
46
43
text_similarities = cosine_similarity (query_embedding , doc_embeddings )[0 ]
@@ -66,12 +63,9 @@ def compute_document_scores(self, query_embedding, doc_embeddings, high_match_th
66
63
67
64
return result
68
65
69
- def retrieve (self , query , similarity_threshold = 0.4 , high_match_threshold = 0.8 , max_docs = 5 , use_cpu = True ):
70
- # Note: Set use_cpu=True to run on CPU, which is useful for testing or environments without a GPU.
71
- # Set use_cpu=False to leverage GPU for better performance in production.
72
-
73
- query_embedding = self .get_query_embedding (query , use_cpu )
74
- doc_embeddings = self .get_doc_embeddings (use_cpu )
66
+ def retrieve (self , query , similarity_threshold = 0.4 , high_match_threshold = 0.8 , max_docs = 5 ):
67
+ query_embedding = self .get_query_embedding (query )
68
+ doc_embeddings = self .get_doc_embeddings ()
75
69
76
70
doc_scores = self .compute_document_scores (query_embedding , doc_embeddings , high_match_threshold )
77
71
retrieved_docs = self .get_top_docs (doc_scores , similarity_threshold , max_docs )
@@ -149,11 +143,11 @@ def answer_query_stream(self, query):
149
143
150
144
collected_messages = []
151
145
for chunk in stream :
152
- if chunk ['choices' ][0 ]['finish_reason' ] is not None :
153
- break
154
146
content = chunk ['choices' ][0 ]['delta' ].get ('content' , '' )
155
147
collected_messages .append (content )
156
148
yield content
149
+ if chunk ['choices' ][0 ].get ('finish_reason' ) is not None :
150
+ break
157
151
158
152
if len (citations ) > 0 :
159
153
yield "\n \n References:\n " + "\n " .join (citations )
@@ -193,3 +187,6 @@ def get_context(self, retrieved_docs):
193
187
for doc in retrieved_docs :
194
188
retrieved_text .append (f"{ doc ['about' ]} . { doc ['text' ]} " )
195
189
return "\n \n " .join (retrieved_text )
190
+
191
+ # # Instantiate the RAGSystem
192
+ # rag_system = RAGSystem()
0 commit comments