@@ -27,38 +27,92 @@ def embed_knowledge_base(self):
2727
2828 def normalize_query (self , query ):
2929 return query .lower ().strip ()
30-
31- def retrieve (self , query , similarity_threshold = 0.7 , high_match_threshold = 0.8 , max_docs = 5 ):
30+
31+ def get_query_embedding (self , query , use_cpu = False ):
3232 normalized_query = self .normalize_query (query )
3333 query_embedding = self .model .encode ([normalized_query ], convert_to_tensor = True )
34- similarities = cosine_similarity (query_embedding , self .doc_embeddings )[0 ]
35- relevance_scores = []
36-
37- for i , doc in enumerate (self .knowledge_base ):
34+ if use_cpu :
35+ query_embedding = query_embedding .cpu ()
36+ return query_embedding
37+
38+ def get_doc_embeddings (self , use_cpu = False ):
39+ if use_cpu :
40+ return self .doc_embeddings .cpu ()
41+ return self .doc_embeddings
42+
43+ def compute_document_scores (self , query_embedding , doc_embeddings , high_match_threshold ):
44+ text_similarities = cosine_similarity (query_embedding , doc_embeddings )[0 ]
45+ about_similarities = []
46+ for doc in self .knowledge_base :
3847 about_similarity = cosine_similarity (query_embedding , self .model .encode ([doc ["about" ]]))[0 ][0 ]
39- text_similarity = similarities [i ]
40-
41- combined_score = (0.3 * about_similarity ) + (0.7 * text_similarity )
42- if about_similarity >= high_match_threshold or text_similarity >= high_match_threshold :
43- combined_score = max (about_similarity , text_similarity )
44-
45- relevance_scores .append ((i , combined_score ))
48+ about_similarities .append (about_similarity )
49+
50+ relevance_scores = self .compute_relevance_scores (text_similarities , about_similarities , high_match_threshold )
51+
52+ result = [
53+ {
54+ "index" : i ,
55+ "about" : doc ["about" ],
56+ "text" : doc ["text" ],
57+ "text_similarity" : text_similarities [i ],
58+ "about_similarity" : about_similarities [i ],
59+ "relevance_score" : relevance_scores [i ]
60+ }
61+ for i , doc in enumerate (self .knowledge_base )
62+ ]
4663
47- sorted_indices = sorted (relevance_scores , key = lambda x : x [1 ], reverse = True )
48- top_indices = [i for i , score in sorted_indices [:max_docs ] if score >= similarity_threshold ]
64+ return result
4965
50- retrieved_docs = [f'{ self .knowledge_base [i ]["about" ]} . { self .knowledge_base [i ]["text" ]} ' for i in top_indices ]
66+ def retrieve (self , query , similarity_threshold = 0.7 , high_match_threshold = 0.8 , max_docs = 5 , use_cpu = False ):
67+ # Note: Set use_cpu=True to run on CPU, which is useful for testing or environments without a GPU.
68+ # Set use_cpu=False to leverage GPU for better performance in production.
69+
70+ query_embedding = self .get_query_embedding (query , use_cpu )
71+ doc_embeddings = self .get_doc_embeddings (use_cpu )
5172
73+ doc_scores = self .compute_document_scores (query_embedding , doc_embeddings , high_match_threshold )
74+ retrieved_docs = self .get_top_docs (doc_scores , similarity_threshold , max_docs )
75+
5276 if not retrieved_docs :
53- max_index = np .argmax (similarities )
54- retrieved_docs .append (f'{ self .knowledge_base [max_index ]["about" ]} . { self .knowledge_base [max_index ]["text" ]} ' )
55-
56- return "\n \n " .join (retrieved_docs )
77+ retrieved_docs = self .get_fallback_doc ()
78+ return retrieved_docs
79+
5780
81+ def compute_relevance_scores (self , text_similarities , about_similarities , high_match_threshold ):
82+ relevance_scores = []
83+ for i , _ in enumerate (self .knowledge_base ):
84+ about_similarity = about_similarities [i ]
85+ text_similarity = text_similarities [i ]
86+ # If either about or text similarity is above the high match threshold, prioritize it
87+ if about_similarity >= high_match_threshold or text_similarity >= high_match_threshold :
88+ combined_score = max (about_similarity , text_similarity )
89+ else :
90+ combined_score = (0.3 * about_similarity ) + (0.7 * text_similarity )
91+ relevance_scores .append (combined_score )
92+
93+ return relevance_scores
94+
95+ def get_top_docs (self , doc_scores , similarity_threshold , max_docs ):
96+ sorted_docs = sorted (doc_scores , key = lambda x : x ["relevance_score" ], reverse = True )
97+ # Filter and keep up to max_docs with relevance scores above the similarity threshold
98+ top_docs = [score for score in sorted_docs [:max_docs ] if score ["relevance_score" ] >= similarity_threshold ]
99+ return top_docs
100+
101+ def get_fallback_doc (self ):
102+ return [
103+ {
104+ "about" : "No Relevant Information Found" ,
105+ "text" : (
106+ "I'm sorry, I couldn't find any relevant information for your query. "
107+ "Please try rephrasing your question or ask about a different topic. "
108+ "For further assistance, you can visit our official website or reach out to our support team."
109+ )
110+ }
111+ ]
112+
58113 def answer_query_stream (self , query ):
59114 try :
60- normalized_query = self .normalize_query (query )
61- context = self .retrieve (normalized_query )
115+ context = self .get_context (query )
62116
63117 self .conversation_history .append ({"role" : "user" , "content" : query })
64118
@@ -118,5 +172,13 @@ def rebuild_embeddings(self):
118172 self .doc_embeddings = self .embed_knowledge_base () # Rebuild the embeddings
119173 print ("Embeddings have been rebuilt." )
120174
175+ def get_context (self , query ):
176+ normalized_query = self .normalize_query (query )
177+ retrieved_docs = self .retrieve (normalized_query )
178+ retrieved_text = []
179+ for doc in retrieved_docs :
180+ retrieved_text .append (f'{ doc ["about" ]} . { doc ["text" ]} ' )
181+ return "\n \n " .join (retrieved_text )
182+
# Instantiate the RAGSystem
# Module-level instance; presumably the shared entry point for callers elsewhere
# in the application — TODO(review) confirm against importers.
rag_system = RAGSystem()
0 commit comments