Commit a876a4b
added code for cache

1 parent 6f31b36
11 files changed (+637, -9 lines)

FASTAPI-DEPLOYMENT/rhl_fastapi_v2_modify.py

Lines changed: 214 additions & 9 deletions
@@ -377,7 +377,7 @@ async def _background_update_and_save(user_id: str, user_message: str, bot_reply

 If medical advice is requested, politely decline and redirect to medical questions without technical input.

-Examples:
+Examples:

@@ -1109,8 +1109,161 @@ def _verify_with_llm(self, answer: str, description: str) -> bool:
             print(f"[VIDEO_SYSTEM] LLM verification failed: {e}")
             return False

-# Global video matching system
-video_system: VideoMatchingSystem = None
+# -------------------- CACHE SYSTEM (BERT + LLM APPROACH) --------------------
+class CacheSystem:
+    def __init__(self, cache_file_path: str = "D:\\RHL-WH\\RHL-FASTAPI\\FILES\\cache_questions.xlsx"):
+        """Initialize the cache system using BERT similarity + LLM verification"""
+        self.cache_file_path = cache_file_path
+        self.question_list = []  # List of cached questions
+        self.answer_list = []    # List of corresponding answers
+
+        # Load cache data
+        self._load_cache_data()
+
+        # Initialize BERT model for similarity (reuse video system's model)
+        if self.question_list:
+            print("[CACHE_SYSTEM] Loading BERT model for cache similarity...")
+            self.similarity_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+            print("[CACHE_SYSTEM] BERT model loaded successfully")
+        else:
+            self.similarity_model = None
+
+    def _load_cache_data(self):
+        """Load cache data from the Excel file"""
+        try:
+            df = pd.read_excel(self.cache_file_path)
+            print(f"[CACHE_SYSTEM] Loaded {len(df)} cached Q&A pairs from {self.cache_file_path}")
+
+            # Build parallel question/answer lists
+            for idx, row in df.iterrows():
+                question = row['question'].strip()
+                answer = row['answer'].strip()
+
+                if question and answer:
+                    self.question_list.append(question)
+                    self.answer_list.append(answer)
+
+            print(f"[CACHE_SYSTEM] Created cache with {len(self.question_list)} questions")
+            print(f"[CACHE_SYSTEM] Sample cached questions:")
+            for i, q in enumerate(self.question_list[:3]):
+                print(f" {i}: {q[:80]}...")
+
+        except Exception as e:
+            print(f"[CACHE_SYSTEM] Error loading cache data: {e}")
+            self.question_list = []
+            self.answer_list = []
+
+    def check_cache(self, reformulated_query: str) -> Optional[str]:
+        """Check if the reformulated query matches any cached question using BERT + LLM verification"""
+        if not self.question_list or not self.similarity_model:
+            print("[CACHE_SYSTEM] No cache data or model available")
+            return None
+
+        print("="*60)
+        print("CACHE SYSTEM BLOCK")
+        print("="*60)
+        print(f"[CACHE_SYSTEM] Checking cache for: {reformulated_query[:100]}...")
+
+        # Step 1: BERT semantic similarity
+        print("[CACHE_SYSTEM] Step 1: Computing BERT semantic similarities...")
+        cache_start = time.perf_counter()
+
+        # Encode the reformulated query and all cached questions
+        query_embedding = self.similarity_model.encode([reformulated_query])
+        question_embeddings = self.similarity_model.encode(self.question_list)
+
+        # Compute cosine similarities
+        similarities = cosine_similarity(query_embedding, question_embeddings)[0]
+
+        # Find the best match
+        best_idx = np.argmax(similarities)
+        best_similarity = similarities[best_idx]
+
+        cache_end = time.perf_counter()
+        print(f"[CACHE_SYSTEM] BERT similarity computation took {cache_end - cache_start:.3f} seconds")
+        print(f"[CACHE_SYSTEM] Best similarity score: {best_similarity:.3f}")
+        print(f"[CACHE_SYSTEM] Best cached question: {self.question_list[best_idx][:100]}...")
+
+        # Step 2: Combined LLM verification + reframing (only for the top match)
+        if best_similarity >= 0.4:  # Higher threshold for cache (more strict)
+            print("[CACHE_SYSTEM] Step 2: Combined LLM verification and reframing...")
+            llm_start = time.perf_counter()
+
+            result = self._verify_and_reframe_cache(reformulated_query, self.question_list[best_idx], self.answer_list[best_idx])
+
+            llm_end = time.perf_counter()
+            print(f"[CACHE_SYSTEM] Combined LLM verification and reframing took {llm_end - llm_start:.3f} seconds")
+
+            if result:
+                print(f"[CACHE_SYSTEM] Cache HIT! Returning reframed answer")
+                print("="*60)
+                return result
+            else:
+                print("[CACHE_SYSTEM] LLM verification failed - cache miss")
+                print("="*60)
+                return None
+        else:
+            print(f"[CACHE_SYSTEM] Similarity score {best_similarity:.3f} below threshold 0.4 - cache miss")
+            print("="*60)
+            return None

+    def _verify_and_reframe_cache(self, reformulated_query: str, cached_question: str, cached_answer: str) -> Optional[str]:
+        """Combined verification and reframing: check if the cached answer can answer the query, and reframe it if yes"""
+        prompt = f"""Analyze if the cached answer can be used to answer the reformulated query, and if yes, reframe it appropriately.
+
+Reformulated Query: {reformulated_query}
+
+Original Cached Question: {cached_question}
+
+Cached Answer: {cached_answer}
+
+Instructions:
+1. FIRST: Determine if the cached answer contains information that can answer the reformulated query
+2. If YES: Reframe the cached answer to directly address the reformulated query
+3. If NO: Return "NULL"
+
+Rules for Verification:
+- Return the reframed answer ONLY if the cached answer contains relevant information for the query
+- Return "NULL" if the cached answer is about a different medical topic entirely
+- Return "NULL" if the cached answer lacks the specific information asked for
+
+Rules for Reframing (if applicable):
+- Use ONLY information from the cached answer - NO external knowledge
+- Maintain ALL medical facts from the cached answer
+- Adjust tense, flow, and structure to match the question asked
+- Do NOT add any external information not present in the cached answer
+- Do NOT remove any medical information from the cached answer
+- Keep the answer concise and directly relevant
+
+Examples:
+- Query: "What causes fever?" + Cached Answer: "Fever symptoms include..." → NULL (no cause info)
+- Query: "What causes fever?" + Cached Answer: "Fever is caused by infections..." → Reframe to focus on causes
+- Query: "How to treat jaundice?" + Cached Answer: "Jaundice symptoms are..." → NULL (no treatment info)
+- Query: "What are signs of dehydration?" + Cached Answer: "Dehydration signs include..." → Reframe to focus on signs
+
+Response: Return the reframed answer if applicable, or "NULL" if not applicable."""
+
+        try:
+            response = gemini_llm.invoke([HumanMessage(content=prompt)]).content.strip()
+            print(f"[CACHE_SYSTEM] LLM response preview: {response[:200]}...")
+
+            # Check if the LLM returned NULL
+            if response.upper().strip() == "NULL":
+                print(f"[CACHE_SYSTEM] LLM determined cached answer cannot answer the query")
+                return None
+
+            # Return the reframed answer
+            print(f"[CACHE_SYSTEM] LLM provided reframed answer")
+            return response
+
+        except Exception as e:
+            print(f"[CACHE_SYSTEM] Combined verification and reframing failed: {e}")
+            # Fallback: return None to trigger the RAG pipeline
+            print(f"[CACHE_SYSTEM] Falling back to RAG pipeline")
+            return None
+
+# Global cache system
+cache_system: CacheSystem = None

 # -------------------- MAIN PIPELINE (called by API) --------------------
 async def medical_pipeline_api(user_id: str, user_message: str, background_tasks: BackgroundTasks) -> Dict[str, Any]:
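For a quick sanity check outside the API, the BERT half of the lookup can be reproduced on its own. The sketch below is not part of the commit: it assumes a local copy of the cache spreadsheet with the question column the loader expects, and that pandas, numpy, scikit-learn, and sentence-transformers are installed; the LLM verification/reframing step is only noted in a comment because it depends on the app's gemini_llm client.

# Standalone sketch (not from the commit): replicates CacheSystem's Step 1 only.
import time
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

df = pd.read_excel("cache_questions.xlsx")  # assumed local copy of the cache file
questions = df["question"].astype(str).str.strip().tolist()

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
query = "What are signs of jaundice?"  # illustrative query

start = time.perf_counter()
query_emb = model.encode([query])
question_embs = model.encode(questions)
sims = cosine_similarity(query_emb, question_embs)[0]
best_idx = int(np.argmax(sims))
print(f"best score {sims[best_idx]:.3f} in {time.perf_counter() - start:.3f}s")
print(f"best cached question: {questions[best_idx]}")
# In the commit, a score >= 0.4 is then passed to _verify_and_reframe_cache(),
# which asks the LLM to either reframe the cached answer or return "NULL".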
@@ -1162,17 +1315,64 @@ async def medical_pipeline_api(user_id: str, user_message: str, background_tasks

     # Skip the old reformulation step since it's now combined above

-
+    # ---- CACHE CHECK (NEW: BERT + LLM APPROACH) ----
+    print("[pipeline] Step 3: Cache check using BERT + LLM verification...")
+    cached_answer = None
+    if cache_system:
+        cache_start = time.perf_counter()
+        cached_answer = cache_system.check_cache(rewritten)
+        cache_end = time.perf_counter()
+        print(f"[pipeline] Cache check took {cache_end - cache_start:.3f} seconds")
+
+    if cached_answer:
+        print("[pipeline] CACHE HIT! Skipping RAG pipeline...")
+
+        # Apply correction prefix if needed
+        if label != "FOLLOW_UP" and correction:
+            correction_msg = "I guess you meant " + " and ".join(correction.values())
+            cached_answer = correction_msg + "\n" + cached_answer
+
+        # Find a relevant video URL for the cached answer
+        print("[pipeline] Step 4: Finding relevant video for cached answer...")
+        video_url = None
+        if video_system:
+            video_start = time.perf_counter()
+            video_url = video_system.find_relevant_video(cached_answer)
+            video_end = time.perf_counter()
+            print(f"[pipeline] Video matching took {video_end - video_start:.3f} secs")
+            if video_url:
+                print(f"[pipeline] Found relevant video: {video_url}")
+            else:
+                print("[pipeline] No relevant video found")
+
+        # Schedule a background save for the cached answer
+        print("[pipeline] schedule background save: cached_answer")
+        background_tasks.add_task(_background_update_and_save, user_id, user_message, cached_answer, "answer", history_pairs, current_summary)
+        print("[pipeline] done with cached answer")
+        t_end = time.perf_counter()
+        timer.total("request")
+        print(f"total took {t_end - start_time:.2f} secs")
+
+        # Return the cached response with the video URL
+        response = {"answer": cached_answer, "intent": "answer", "follow_up": None}
+        if video_url:
+            response["video_url"] = video_url
+        else:
+            response["video_url"] = None
+
+        return response
+
+    print("[pipeline] CACHE MISS! Proceeding with RAG pipeline...")

     # ---- HYBRID RETRIEVAL ----
-    print("[pipeline] Step 3: hybrid_retrieve")
+    print("[pipeline] Step 4: hybrid_retrieve")
     candidates = hybrid_retrieve(rewritten)  # vector + bm25 + rerank
     t4 = time.perf_counter()
     timer.mark("hybrid_retrieve")
     print(f"retrieval took {t4 - t3:.2f} secs")
     print(f"[pipeline] retrieved {len(candidates)} candidates")

-    print("[pipeline] Step 4: judge_sufficiency")
+    print("[pipeline] Step 5: judge_sufficiency")
     judge = judge_sufficiency(rewritten, candidates)
     t5 = time.perf_counter()
     timer.mark("judge_sufficiency")
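On a cache hit the endpoint short-circuits before retrieval and returns the same response shape as the RAG path. An illustrative payload, with keys taken from the diff above and values invented:

# Illustrative cache-hit response (keys from the diff; values invented)
response = {
    "answer": "I guess you meant jaundice\nJaundice signs include yellowing of the skin and eyes.",
    "intent": "answer",
    "follow_up": None,
    "video_url": None,  # set to the matched URL when video_system finds one
}

The "I guess you meant ..." prefix only appears when the classifier label is not FOLLOW_UP and a spelling correction was produced.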
@@ -1190,7 +1390,7 @@ async def medical_pipeline_api(user_id: str, user_message: str, background_tasks
         sec = fc["meta"].get("section") if fc.get("meta") else None
         followup_q = sec or (fc["text"])

-    print("[pipeline] Step 5: synthesize_answer")
+    print("[pipeline] Step 6: synthesize_answer")
     answer = synthesize_answer(rewritten, top4, followup_q, gemini_llm)
     t6 = time.perf_counter()
     timer.mark("synthesize_answer")
@@ -1202,7 +1402,7 @@ async def medical_pipeline_api(user_id: str, user_message: str, background_tasks
         answer = correction_msg + "\n" + answer

     # Find relevant video URL
-    print("[pipeline] Step 6: Finding relevant video...")
+    print("[pipeline] Step 7: Finding relevant video...")
     video_url = None
     if video_system:
         video_start = time.perf_counter()
@@ -1251,7 +1451,7 @@ async def medical_pipeline_api(user_id: str, user_message: str, background_tasks
 # -------------------- API ENDPOINTS --------------------
 @app.on_event("startup")
 async def startup_event():
-    global embedding_model, reranker, pinecone_index, llm, summarizer_llm, reformulate_llm, classifier_llm, gemini_llm, EMBED_DIM, video_system
+    global embedding_model, reranker, pinecone_index, llm, summarizer_llm, reformulate_llm, classifier_llm, gemini_llm, EMBED_DIM, video_system, cache_system
     print("[startup] Initializing models and Pinecone client...")
     t = CheckpointTimer("startup")
     embedding_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
@@ -1287,6 +1487,11 @@ async def startup_event():
     gemini_llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-lite", api_key=GOOGLE_API_KEY)
     t.mark("init_llms")

+    # Initialize cache system
+    print("[startup] Initializing cache system...")
+    cache_system = CacheSystem()
+    t.mark("init_cache_system")
+
     # Initialize video matching system
     print("[startup] Initializing video matching system...")
     video_system = VideoMatchingSystem()

FILES/cache_questions_test.xlsx (5.45 KB)
Binary file not shown.

chat_history.db (200 KB)
Binary file not shown.

create_test_cache.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+
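create_test_cache.py lands in this commit as a single blank line. Judging by its name and the columns read in _load_cache_data(), a generator for FILES/cache_questions_test.xlsx might look roughly like the sketch below; the rows are invented placeholders, not content from the commit.

# Hypothetical cache-file generator; only the column names ("question", "answer")
# come from the loader in this commit, the rows are invented.
import pandas as pd

rows = [
    {
        "question": "What are symptoms of jaundice?",
        "answer": "Jaundice symptoms include yellowing of the skin and eyes, dark urine, and tiredness.",
    },
    {
        "question": "What are signs of dehydration?",
        "answer": "Dehydration signs include thirst, dry mouth, reduced urination, and dizziness.",
    },
]

pd.DataFrame(rows).to_excel("FILES/cache_questions_test.xlsx", index=False)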

test_cache_edge_cases.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+#!/usr/bin/env python3
+"""
+Test edge cases for the combined cache system
+"""
+
+import requests
+import json
+import time
+
+def test_cache_edge_cases():
+    """Test edge cases for the combined cache verification and reframing"""
+
+    base_url = "http://localhost:8000"
+
+    # Edge case test scenarios
+    edge_cases = [
+        {
+            "question": "What are symptoms of jaundice?",
+            "description": "Exact match - should work"
+        },
+        {
+            "question": "What causes jaundice?",
+            "description": "Different focus - cached answer about symptoms, should return NULL"
+        },
+        {
+            "question": "How to treat jaundice?",
+            "description": "Different focus - cached answer about symptoms, should return NULL"
+        },
+        {
+            "question": "What are complications of jaundice?",
+            "description": "Different focus - cached answer about symptoms, should return NULL"
+        },
+        {
+            "question": "What are signs of jaundice?",
+            "description": "Semantic match - symptoms = signs, should reframe"
+        },
+        {
+            "question": "What are jaundice symptoms?",
+            "description": "Word order change - should reframe"
+        },
+        {
+            "question": "Tell me about jaundice symptoms",
+            "description": "Different phrasing - should reframe"
+        }
+    ]
+
+    print("="*80)
+    print("CACHE SYSTEM EDGE CASES TEST")
+    print("="*80)
+    print("Testing edge cases for combined verification and reframing...")
+    print()
+
+    for i, test_case in enumerate(edge_cases, 1):
+        print(f"--- EDGE CASE {i}: {test_case['description']} ---")
+        print(f"Question: {test_case['question']}")
+
+        try:
+            # Make the API request
+            start_time = time.time()
+            response = requests.get(
+                f"{base_url}/chat",
+                params={
+                    "user_id": f"edge_case_user_{i}",
+                    "message": test_case['question']
+                },
+                timeout=30
+            )
+            end_time = time.time()
+
+            if response.status_code == 200:
+                data = response.json()
+                response_time = end_time - start_time
+
+                print(f"Response time: {response_time:.2f} seconds")
+
+                # Determine whether this was a cache hit or miss
+                if response_time < 3.0:
+                    print("RESULT: CACHE HIT")
+                    answer = data.get('answer', '')
+                    print(f"Answer preview: {answer[:150]}...")
+
+                    # Check if the answer seems appropriate
+                    if "jaundice" in answer.lower():
+                        print("CONTENT: Answer contains jaundice information ✓")
+                    else:
+                        print("CONTENT: Answer doesn't contain jaundice information ❌")
+
+                else:
+                    print("RESULT: CACHE MISS (went through RAG)")
+                    answer = data.get('answer', '')
+                    print(f"Answer preview: {answer[:150]}...")
+
+            else:
+                print(f"API Error: {response.status_code} - {response.text}")
+
+        except requests.exceptions.RequestException as e:
+            print(f"Request failed: {e}")
+        except Exception as e:
+            print(f"Unexpected error: {e}")
+
+        print("-" * 60)
+        print()
+
+    print("="*80)
+    print("EDGE CASES TEST COMPLETE")
+    print("="*80)
+    print("Expected behavior:")
+    print("- Exact matches: Cache hit with reframed answer")
+    print("- Different focus (causes vs symptoms): Cache miss, goes to RAG")
+    print("- Semantic matches (symptoms = signs): Cache hit with reframed answer")
+    print("- Different phrasings: Cache hit with reframed answer")
+    print("="*80)
+
+if __name__ == "__main__":
+    test_cache_edge_cases()
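Running the script assumes the FastAPI app is already serving GET /chat on localhost:8000 with user_id and message query parameters, which is the route the loop above calls. A single-question smoke test along the same lines:

# One-off request against the running API; same route and params as the test script.
import requests

r = requests.get(
    "http://localhost:8000/chat",
    params={"user_id": "smoke_test_user", "message": "What are symptoms of jaundice?"},
    timeout=30,
)
r.raise_for_status()
print(r.json().get("answer", "")[:150])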
