@@ -26,38 +26,92 @@ def embed_knowledge_base(self):
2626
def normalize_query(self, query):
    """Canonicalize a user query for embedding lookup: trim whitespace, lowercase."""
    cleaned = query.strip()
    return cleaned.lower()
def get_query_embedding(self, query, use_cpu=False):
    """Encode *query* into a tensor embedding.

    When use_cpu is True the embedding is moved off the accelerator,
    which is useful for tests or GPU-less environments.
    """
    embedding = self.model.encode([self.normalize_query(query)], convert_to_tensor=True)
    return embedding.cpu() if use_cpu else embedding
def get_doc_embeddings(self, use_cpu=False):
    """Return the precomputed document embeddings, moved to CPU when requested."""
    embeddings = self.doc_embeddings
    if not use_cpu:
        return embeddings
    return embeddings.cpu()
def compute_document_scores(self, query_embedding, doc_embeddings, high_match_threshold):
    """Score every knowledge-base document against the query.

    Combines full-text similarity (against the precomputed doc_embeddings)
    with a per-document "about" similarity, then folds both into a single
    relevance score via compute_relevance_scores. Returns one dict per
    document carrying its index, raw fields, and all three scores.

    NOTE(review): each call re-encodes every doc's "about" field — these
    embeddings are query-independent and look cacheable; confirm before
    optimizing.
    """
    text_similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
    about_similarities = [
        cosine_similarity(query_embedding, self.model.encode([doc["about"]]))[0][0]
        for doc in self.knowledge_base
    ]
    relevance_scores = self.compute_relevance_scores(
        text_similarities, about_similarities, high_match_threshold
    )

    scored = []
    for idx, doc in enumerate(self.knowledge_base):
        scored.append(
            {
                "index": idx,
                "about": doc["about"],
                "text": doc["text"],
                "text_similarity": text_similarities[idx],
                "about_similarity": about_similarities[idx],
                "relevance_score": relevance_scores[idx],
            }
        )
    return scored
4864
def retrieve(self, query, similarity_threshold=0.7, high_match_threshold=0.8, max_docs=5, use_cpu=False):
    """Return the top-scoring documents for *query*, or a fallback doc if none qualify.

    use_cpu=True runs on CPU (useful for testing or GPU-less environments);
    leave it False in production to keep GPU acceleration.
    """
    embedding = self.get_query_embedding(query, use_cpu)
    corpus = self.get_doc_embeddings(use_cpu)

    scored = self.compute_document_scores(embedding, corpus, high_match_threshold)
    top = self.get_top_docs(scored, similarity_threshold, max_docs)

    # Never return an empty result set — callers expect at least one doc.
    return top if top else self.get_fallback_doc()
78+
5679
def compute_relevance_scores(self, text_similarities, about_similarities, high_match_threshold):
    """Fold text and "about" similarities into one relevance score per document.

    A high match on either signal (>= high_match_threshold) dominates and the
    score becomes the max of the two; otherwise the score is a weighted blend
    favouring the full-text similarity (0.7 text + 0.3 about).

    Fix: iterate the two parallel input lists directly instead of indexing
    them by enumerate(self.knowledge_base) — the computation is pure in its
    arguments, and the old form raised IndexError whenever the knowledge base
    was longer than the similarity lists.
    """
    relevance_scores = []
    for text_similarity, about_similarity in zip(text_similarities, about_similarities):
        if about_similarity >= high_match_threshold or text_similarity >= high_match_threshold:
            combined_score = max(about_similarity, text_similarity)
        else:
            combined_score = (0.3 * about_similarity) + (0.7 * text_similarity)
        relevance_scores.append(combined_score)
    return relevance_scores
def get_top_docs(self, doc_scores, similarity_threshold, max_docs):
    """Rank scored docs by relevance and keep at most max_docs above the threshold."""
    ranked = sorted(doc_scores, key=lambda entry: entry["relevance_score"], reverse=True)
    shortlist = ranked[:max_docs]
    return [entry for entry in shortlist if entry["relevance_score"] >= similarity_threshold]
def get_fallback_doc(self):
    """Return a single synthetic document used when no real doc clears the threshold."""
    fallback_text = (
        "I'm sorry, I couldn't find any relevant information for your query. "
        "Please try rephrasing your question or ask about a different topic. "
        "For further assistance, you can visit our official website or reach out to our support team."
    )
    return [{"about": "No Relevant Information Found", "text": fallback_text}]
111+
57112 def answer_query_stream (self , query ):
58113 try :
59- normalized_query = self .normalize_query (query )
60- context = self .retrieve (normalized_query )
114+ context = self .get_context (query )
61115
62116 self .conversation_history .append ({"role" : "user" , "content" : query })
63117
@@ -117,5 +171,13 @@ def rebuild_embeddings(self):
117171 self .doc_embeddings = self .embed_knowledge_base () # Rebuild the embeddings
118172 print ("Embeddings have been rebuilt." )
119173
def get_context(self, query):
    """Build the prompt context for *query*: retrieve relevant docs and flatten to text.

    Each retrieved document contributes one '<about>. <text>' paragraph;
    paragraphs are separated by a blank line.
    """
    normalized_query = self.normalize_query(query)
    retrieved_docs = self.retrieve(normalized_query)
    # Generator into join instead of a manual append loop (same output, idiomatic).
    return "\n\n".join(f'{doc["about"]}. {doc["text"]}' for doc in retrieved_docs)
# Module-level RAGSystem instance shared by the rest of the application.
rag_system = RAGSystem()
0 commit comments