-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmedical_chatbot_main_application.py
More file actions
363 lines (322 loc) · 13.6 KB
/
medical_chatbot_main_application.py
File metadata and controls
363 lines (322 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
import os
import torch
from pymongo import MongoClient
from sentence_transformers import SentenceTransformer, CrossEncoder
from typing import List
from vertexai.preview.generative_models import GenerativeModel
import vertexai
import gradio as gr
# === CONFIGURATION ===
# Deployment-specific settings; fill these in before starting the app.
MONGODB_URI = ""  # MongoDB cluster URL
SERVICE_ACCOUNT = ""  # Service account; exported as GOOGLE_APPLICATION_CREDENTIALS below
PROJECT_ID = ""  # Project ID from GCP
DB_NAME = ""  # DB we created in MongoDB
COLLECTION_NAME = ""  # Collection in the DB
EMBEDDING_DIM = 384  # For all-MiniLM-L6-v2

# === DEVICE SETUP ===
# Run the embedding and reranking models on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# === GEMINI SETUP ===
# Only export the credentials setting when one is configured here. Assigning an
# empty string would overwrite a value already exported in the environment.
if SERVICE_ACCOUNT:
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = SERVICE_ACCOUNT
vertexai.init(project=PROJECT_ID, location="us-central1")
# Gemini model for answering clinical questions based on transcription context.
# ask_gemini() sends its prompts to this model; the system instruction tells it
# to answer solely from the medical context it is given.
gemini = GenerativeModel(
    "models/gemini-2.0-flash-lite-001",
    system_instruction="""
You are a specialized medical AI assistant trained to analyze medical transcriptions.
Your role is to extract relevant medical information and provide accurate, helpful responses based solely
on the provided medical context. Always prioritize patient safety and medical accuracy.
"""
)
# Gemini model configuration for symptom-based diagnosis responses.
# suggest_diagnosis() sends its prompts to this model; the system instruction
# asks for a likely diagnosis plus next clinical steps in a conversational tone.
gemini_diag = GenerativeModel(
    "models/gemini-2.0-flash-lite-001",
    system_instruction="""
You are a warm and concise medical assistant. When given symptom descriptions and context, your job is to suggest the most likely diagnosis and the next clinical steps.
Speak clearly and kindly like a helpful doctor. DO NOT reference records, documents, or phrases like "it's important to figure out what's going on" or "there could be a few possibilities."
Instead, go straight to the point using a natural, reassuring tone.
Example: "Hey! Based on what you're experiencing, it looks like you might be dealing with ____. A good next step would be _____."
"""
)
# === MODEL SETUP ===
# embed_model encodes query text into the vector sent to the MongoDB search;
# reranker scores (query, transcription) pairs. Both are loaded onto `device`.
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device=device)
# === MONGODB SETUP ===
# Module-level client and collection handle used by both search functions.
client = MongoClient(MONGODB_URI)
db = client[DB_NAME]
collection = db[COLLECTION_NAME]
# === SEARCH + DIAGNOSIS ===
def suggest_diagnosis(symptom_description: str, top_k: int = 10) -> str:
    """Suggest a likely diagnosis for the given symptoms.

    Embeds the symptom text, retrieves similar records through the
    ``rich_vec_index`` search index, reranks them with the cross-encoder and
    asks the diagnosis Gemini model to answer using those records as context.

    Args:
        symptom_description: Free-text description of the patient's symptoms.
        top_k: Number of retrieved records to rerank and pass as context.

    Returns:
        The model's answer as stripped text, or a short notice when the
        search produced no usable records.
    """
    query_vec = embed_model.encode(symptom_description, convert_to_tensor=True).tolist()
    pipeline = [
        {
            "$search": {
                "index": "rich_vec_index",
                "knnBeta": {"vector": query_vec, "path": "embedding", "k": top_k * 2},
            }
        },
        {"$project": {"transcription": 1, "parsed_entities": 1, "score": {"$meta": "searchScore"}}},
    ]
    raw_results = list(collection.aggregate(pipeline))

    # NOTE(review): the search asks for top_k * 2 neighbours but only the first
    # top_k are reranked, as in the original code -- confirm whether the whole
    # pool was meant to be reranked before truncating.
    # Records without transcription text cannot be scored or quoted, so skip them.
    candidates = [doc for doc in raw_results[:top_k] if doc.get("transcription")]
    if not candidates:
        # Nothing to rerank and no context to ground an answer on; avoid
        # handing the reranker an empty batch or prompting with no cases.
        return "❌ No similar clinical cases were found for these symptoms."

    pairs = [(symptom_description, doc["transcription"]) for doc in candidates]
    rerank_scores = reranker.predict(pairs)
    # Sort on the score only, so tied scores never fall back to comparing dicts.
    reranked = sorted(zip(rerank_scores, candidates), key=lambda x: x[0], reverse=True)
    context = "\n---\n".join(doc["transcription"] for _, doc in reranked)
    prompt = f"""
PATIENT SYMPTOMS:
{symptom_description}
SIMILAR CLINICAL CASES:
{context}
QUESTION:
What is the likely diagnosis based on these symptoms and similar cases?
"""
    response = gemini_diag.generate_content(prompt)
    return response.text.strip()
# === GENERAL Q/A method to ask GEMINI ===
def ask_gemini(query, docs):
    """Answer ``query`` with Gemini using the transcriptions in ``docs`` as context.

    Args:
        query: The user's question.
        docs: Records whose ``"transcription"`` values are joined into the prompt.

    Returns:
        The model's reply as stripped text, or a fixed notice when ``docs``
        is empty.
    """
    if not docs:
        return "❌ No documents available for Gemini."

    separator = "\n\n---\n\n"
    notes_block = separator.join(record["transcription"] for record in docs)
    prompt = f"""
You are a clinical assistant. Below are anonymized clinical notes from patient records.
Your job is to answer the user's question **using only the information contained in the notes**.
DO NOT guess or refer to patients not mentioned in the notes.
CLINICAL NOTES:
{notes_block}
QUESTION:
{query}
INSTRUCTIONS:
- Use only facts stated in the notes.
- If insufficient info, respond: "Not enough information in the notes."
"""
    return gemini.generate_content(prompt).text.strip()
# === VECTOR SEARCH + RERANKING ===
def search_with_reranking(query, top_k=10):
    """Retrieve records similar to ``query`` and order them by reranker score.

    Embeds the query, runs the ``rich_vec_index`` search on the collection and
    reranks the first ``top_k`` hits with the cross-encoder.

    Args:
        query: Free-text search query.
        top_k: Number of retrieved records to rerank and return.

    Returns:
        The projected records, best reranker score first. An empty list when
        the search produced no usable records.
    """
    query_vec = embed_model.encode(query, convert_to_tensor=True).tolist()
    pipeline = [
        {
            "$search": {
                "index": "rich_vec_index",
                "knnBeta": {"vector": query_vec, "path": "embedding", "k": top_k * 2},
            }
        },
        {
            "$project": {
                "transcription": 1,
                "sample_name": 1,
                "description": 1,
                "medical_specialty": 1,
                "age": 1,
                "gender": 1,
                "parsed_entities": 1,
                "score": {"$meta": "searchScore"},
            }
        },
    ]
    raw_results = list(collection.aggregate(pipeline))

    # NOTE(review): the search asks for top_k * 2 neighbours but only the first
    # top_k are reranked, as in the original code -- confirm the intent.
    # Records without transcription text cannot be scored, and downstream code
    # reads doc["transcription"], so skip them.
    candidates = [doc for doc in raw_results[:top_k] if doc.get("transcription")]
    if not candidates:
        # Avoid handing the reranker an empty batch; ask_gemini() already
        # handles an empty document list.
        return []

    pairs = [(query, doc["transcription"]) for doc in candidates]
    rerank_scores = reranker.predict(pairs)
    # Sort on the score only, so tied scores never fall back to comparing dicts.
    reranked = sorted(zip(rerank_scores, candidates), key=lambda x: x[0], reverse=True)
    return [doc for _, doc in reranked]
# Handles research-based queries by retrieving and summarizing relevant clinical records using Gemini
def research_ui(question):
    """Return a ``(header, body)`` pair for a research question.

    Blank or whitespace-only input yields a prompt to enter a question with
    ``None`` as the body. Otherwise the question is searched, and the body is
    Gemini's answer over the retrieved records.
    """
    if question.strip():
        retrieved = search_with_reranking(question)
        return "💬", ask_gemini(question, retrieved)
    return "Please enter a question.", None
# Analyzes patient symptoms and suggests a probable diagnosis using similar clinical records and Gemini
def diagnose_ui(symptoms):
    """Return a ``(header, body)`` pair for a symptom description.

    Blank or whitespace-only input yields a prompt to enter symptoms with
    ``None`` as the body. Otherwise the body is the result of
    ``suggest_diagnosis`` for the given text.
    """
    if symptoms.strip():
        return "🧠Here is the Diagnosis Explanation:", suggest_diagnosis(symptoms)
    return "Please enter symptoms.", None
# Updates chatbot with research-based response by calling research_ui and appending formatted messages to chat history
def handle_research(user_input, chat_history):
    """Run a research query and append the exchange to the chat history.

    Args:
        user_input: Text taken from the message box.
        chat_history: List of ``{"role", "content"}`` message dicts. It is
            extended in place; ``None`` is treated as an empty history.

    Returns:
        ``("", chat_history)``: the empty string is the new textbox value and
        the list is the updated conversation.
    """
    if chat_history is None:
        chat_history = []
    header, response = research_ui(user_input)
    # research_ui returns None as the body for blank input; show only the
    # header then, instead of rendering the literal text "None" in the chat.
    reply = header if response is None else f"{header}\n\n{response}"
    chat_history.append({"role": "user", "content": user_input})
    chat_history.append({"role": "assistant", "content": reply})
    return "", chat_history
# Updates chatbot with diagnosis explanation by calling diagnose_ui and appending formatted messages to chat history
def handle_diagnose(user_input, chat_history):
    """Run a diagnosis request and append the exchange to the chat history.

    Args:
        user_input: Symptom text taken from the message box.
        chat_history: List of ``{"role", "content"}`` message dicts. It is
            extended in place; ``None`` is treated as an empty history.

    Returns:
        ``("", chat_history)``: the empty string is the new textbox value and
        the list is the updated conversation.
    """
    if chat_history is None:
        chat_history = []
    header, response = diagnose_ui(user_input)
    # diagnose_ui returns None as the body for blank input; show only the
    # header then, instead of rendering the literal text "None" in the chat.
    reply = header if response is None else f"{header}\n\n{response}"
    chat_history.append({"role": "user", "content": user_input})
    chat_history.append({"role": "assistant", "content": reply})
    return "", chat_history
# User Interface code linked with appropriate methods.
# Builds the Gradio Blocks app: a main column (header, message box, action
# buttons, chat view) next to a sidebar (logo and usage notes), then wires the
# three buttons to their handlers.
# NOTE(review): the nesting of the `with` blocks below is reconstructed from
# the element ids and CSS selectors -- confirm against the rendered layout.
with gr.Blocks(
    title="Clinical Assistant",
    theme=gr.themes.Base(),
    fill_height=True,
    # NOTE(review): inside `.header-center` the line starting with '#' is not
    # CSS comment syntax (CSS uses /* ... */); browsers presumably discard it as
    # an invalid declaration -- confirm and convert if it is meant as a comment.
    # NOTE(review): `div.styler.svelte-lngued` looks like a build-generated
    # class name and may change between Gradio versions -- verify.
    css="""
#custom-header-bar {
width: 100%;
display: flex;
align-items: center;
justify-content: space-between;
background: #e0f2f1;
border-radius: 10px 10px 0 0;
border-bottom: 1px solid #b2dfdb;
padding: 10px 0 8px 0;
margin-bottom: 0px;
}
.header-side {
flex: 1;
text-align: left;
font-size: 1rem;
color: #888;
font-weight: 400;
padding-left: 18px;
letter-spacing: 0.5px;
}
.header-center {
flex: 2;
text-align: center;
font-size: 2rem;
# font-weight: bold;
color: #222;
letter-spacing: 0.5px;
}
.header-center-upper-font {
font-weight: bold;
font-family: 'Magneto', 'Brush Script MT', cursive, sans-serif;
}
.header-side.right {
text-align: right;
padding-left: 0;
padding-right: 18px;
}
#sidebar_logo img {
border-radius: 50%;
object-fit: cover;
height: 120px;
width: 120px;
display: block;
margin: auto;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.15);
background-color: #e0f6fa;
}
div.styler.svelte-lngued {
background-color: transparent !important;
background: none !important;
box-shadow: none !important;
border: none !important;
outline: none !important;
padding-bottom: 0 !important;
}
#button_row {
padding-bottom: 1%;
display: flex;
justify-content: space-between;
gap: 12px;
margin-top: 10px;
background: none !important;
background-color: transparent !important;
box-shadow: none !important;
border: none !important;
outline: none !important;
}
body { background-color: #f2f7f9; }
#chatbox {
background-color: #ffffff;
padding: 20px;
border-radius: 12px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.03);
position: relative;
overflow: hidden !important;
}
#sidebar {
background-color: #e0f6fa;
padding: 20px;
border-radius: 12px;
font-family: 'Arial';
color: #007791;
}
textarea {
border-color: #007791;
padding-right: 100px;
padding-bottom: 20px;
position: relative;
resize: none;
}
.gr-textbox-container { position: relative; }
button {
background-color: #007791 !important;
color: white !important;
border-radius: 20px;
font-size: 14px;
padding: 6px 16px;
white-space: nowrap;
font-weight: bold;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
@media only screen and (max-width: 768px) {
#layout { flex-direction: column; }
#button_row { flex-direction: column; gap: 8px; justify-content: center; }
#custom-header-bar { font-size: 1rem; }
.header-center { font-size: 1.2rem; }
}
#chatbox_container button, #chatbox_container svg {
display: none !important;
}
footer,
footer * {
display: none !important;
}
body {
padding-bottom: 0px !important;
}
button[aria-label="Fullscreen"] {
display: none !important;
}
"""
) as iface:
    with gr.Row(elem_id="layout"):
        # Main column: header bar, input box, action buttons and chat view.
        with gr.Column(scale=3, elem_id="chatbox"):
            # --- Custom header bar: product name centred, "powered by" labels on both sides ---
            gr.HTML(
                '''
<div id="custom-header-bar">
<span class="header-side">powered by MongoDB</span>
<span class="header-center">
<span class="header-center-upper-font">NaviDoc</span> <br>Clinical AI Assistant
</span>
<span class="header-side right">powered by GCP</span>
</div>
'''
            )
            with gr.Group():
                with gr.Row():
                    # Shared input box: both Diagnose and Research read their text from here.
                    msg = gr.Textbox(placeholder="👋Welcome to NaviDoc - Clinical AI Assistant! How can I assist you today?", label="", lines=2, show_label=False, elem_id="message_input")
                with gr.Row(elem_id="button_row"):
                    diagnose_btn = gr.Button("🔍 Diagnose", elem_id="icon_button")
                    research_btn = gr.Button("🧠 Research", elem_id="submit_button")
                    clear_button = gr.Button("🧹 Clear Chat", elem_id="clear_button")
            # type='messages' matches the {"role", "content"} dicts appended by the handlers.
            chatbot = gr.Chatbot(type='messages', elem_id="chatbox_container", show_label=False, scale=1)
            # Auto-scroll script and web-font links for the page.
            # NOTE(review): Gradio documents gr.HTML as rendering static HTML only,
            # so this <script> likely never executes -- confirm, and consider the
            # Blocks `js`/`head` parameters if auto-scroll is needed.
            gr.HTML("""
<script>
// Wait for DOM to load
window.addEventListener('DOMContentLoaded', function() {
// Observe changes to the chatbox container
const chatbox = document.querySelector('#chatbox_container');
if (chatbox) {
const observer = new MutationObserver(() => {
chatbox.scrollTop = chatbox.scrollHeight;
});
observer.observe(chatbox, { childList: true, subtree: true });
}
});
</script>
<link href="https://fonts.cdnfonts.com/css/navi" rel="stylesheet">
<link href="https://fonts.googleapis.com/css?family=Open+Sans:400,600&display=swap" rel="stylesheet">
""")
        # Sidebar column: logo plus a description of the two modes and a disclaimer.
        with gr.Column(scale=1, elem_id="sidebar"):
            with gr.Row():
                # Logo is loaded from a relative path, so the image file must be
                # resolvable from the process's working directory.
                gr.Image(value="chatbot_logo_1.png", label="", height=120, width=120, elem_id="sidebar_logo", show_download_button=False, show_share_button=False, container=False)
            gr.Markdown("""
### 🩺 Welcome to the Clinical Assistant
This chatbot helps medical professionals retrieve accurate, evidence-based insights from existing clinical transcriptions.
**Diagnose:**
<br> Helps to diagnose the diseases or condition based on case record of the patient and provide initial provisional diagnosis. It can also suggest investigations that can be undertaken.
**Research:**
<br> Provides medical information based on the health records stored.
**Disclaimer:** <span style="font-size:1.2em;">🏥⚠️</span>
<br>
<span style="background-color:#0077911C; padding:6px 10px; border-radius:8px; display:inline-block;">
The responses are for informational purposes only and are not a substitute for professional medical advice, diagnosis, or treatment.
</span>
""")
    # Event wiring. Each handler receives (msg, chatbot) and returns
    # ("", chat_history), which clears the input box and refreshes the chat view.
    diagnose_btn.click(handle_diagnose, [msg, chatbot], [msg, chatbot])
    research_btn.click(handle_research, [msg, chatbot], [msg, chatbot])
    # Clear Chat resets the chat view to an empty history.
    clear_button.click(lambda: [], None, chatbot, queue=False)
if __name__ == "__main__":
    # Serve on every network interface at port 8080.
    # For a default local launch use: iface.launch(favicon_path="tab-favicon.svg")
    launch_options = {
        "favicon_path": "tab-favicon.svg",
        "server_name": "0.0.0.0",
        "server_port": 8080,
    }
    iface.launch(**launch_options)