Commit 2134996

Add all prompts to a centralized prompt service
1 parent 681dec0 commit 2134996

File tree: 3 files changed (+108, -112 lines)


server/api/views/conversations/views.py

Lines changed: 70 additions & 48 deletions
@@ -16,6 +16,7 @@
 from .models import Conversation, Message
 from .serializers import ConversationSerializer
 from ...services.tools.tools import tools, execute_tool
+from ...services.prompt_services import PromptTemplates
 
 
 @csrf_exempt
@@ -47,7 +48,7 @@ def extract_text(request: str) -> JsonResponse:
         messages=[
             {
                 "role": "system",
-                "content": "Give a brief description of this medicine: %s" % tokens,
+                "content": PromptTemplates.get_medicine_description_prompt(tokens),
             }
         ],
         max_tokens=500,
@@ -64,8 +65,10 @@ def get_tokens(string: str, encoding_name: str) -> str:
     output_string = encoding.decode(tokens)
     return output_string
 
+
 class OpenAIAPIException(APIException):
     """Custom exception for OpenAI API errors."""
+
     status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
     default_detail = "An error occurred while communicating with the OpenAI API."
     default_code = "openai_api_error"
@@ -77,6 +80,7 @@ def __init__(self, detail=None, code=None):
         self.detail = {"error": self.default_detail}
         self.status_code = code or self.status_code
 
+
 class ConversationViewSet(viewsets.ModelViewSet):
     serializer_class = ConversationSerializer
     permission_classes = [IsAuthenticated]
@@ -93,26 +97,29 @@ def destroy(self, request, *args, **kwargs):
         self.perform_destroy(instance)
         return Response(status=status.HTTP_204_NO_CONTENT)
 
-    @action(detail=True, methods=['post'])
+    @action(detail=True, methods=["post"])
     def continue_conversation(self, request, pk=None):
         conversation = self.get_object()
-        user_message = request.data.get('message')
-        page_context = request.data.get('page_context')
+        user_message = request.data.get("message")
+        page_context = request.data.get("page_context")
 
         if not user_message:
             return Response({"error": "Message is required"}, status=400)
 
         # Save user message
-        Message.objects.create(conversation=conversation,
-                               content=user_message, is_user=True)
+        Message.objects.create(
+            conversation=conversation, content=user_message, is_user=True
+        )
 
         # Get ChatGPT response
         chatgpt_response = self.get_chatgpt_response(
-            conversation, user_message, page_context)
+            conversation, user_message, page_context
+        )
 
         # Save ChatGPT response
-        Message.objects.create(conversation=conversation,
-                               content=chatgpt_response, is_user=False)
+        Message.objects.create(
+            conversation=conversation, content=chatgpt_response, is_user=False
+        )
 
         # Generate or update title if it's the first message or empty
         if conversation.messages.count() <= 2 or not conversation.title:
@@ -121,27 +128,35 @@ def continue_conversation(self, request, pk=None):
 
         return Response({"response": chatgpt_response, "title": conversation.title})
 
-    @action(detail=True, methods=['patch'])
+    @action(detail=True, methods=["patch"])
     def update_title(self, request, pk=None):
         conversation = self.get_object()
-        new_title = request.data.get('title')
+        new_title = request.data.get("title")
 
         if not new_title:
-            return Response({"error": "New title is required"}, status=status.HTTP_400_BAD_REQUEST)
+            return Response(
+                {"error": "New title is required"}, status=status.HTTP_400_BAD_REQUEST
+            )
 
         conversation.title = new_title
         conversation.save()
 
-        return Response({"status": "Title updated successfully", "title": conversation.title})
+        return Response(
+            {"status": "Title updated successfully", "title": conversation.title}
+        )
 
     def get_chatgpt_response(self, conversation, user_message, page_context=None):
-        messages = [{
-            "role": "system",
-            "content": "You are a knowledgeable assistant. Balancer is a powerful tool for selecting bipolar medication for patients. We are open-source and available for free use. Your primary role is to assist licensed clinical professionals with information related to Balancer and bipolar medication selection. If applicable, use the supplied tools to assist the professional."
-        }]
+        messages = [
+            {
+                "role": "system",
+                "content": PromptTemplates.get_conversation_system_prompt(),
+            }
+        ]
 
         if page_context:
-            context_message = f"If applicable, please use the following content to ask questions. If not applicable, please answer to the best of your ability: {page_context}"
+            context_message = PromptTemplates.get_conversation_page_context_prompt(
+                page_context
+            )
             messages.append({"role": "system", "content": context_message})
 
         for msg in conversation.messages.all():
@@ -155,46 +170,50 @@ def get_chatgpt_response(self, conversation, user_message, page_context=None):
                 model="gpt-3.5-turbo",
                 messages=messages,
                 tools=tools,
-                tool_choice="auto"
+                tool_choice="auto",
             )
 
             response_message = response.choices[0].message
-            tool_calls = response_message.get('tool_calls', [])
+            tool_calls = response_message.get("tool_calls", [])
 
             if not tool_calls:
-                return response_message['content']
-
+                return response_message["content"]
 
             # Handle tool calls
             # Add the assistant's message with tool calls to the conversation
-            messages.append({
-                "role": "assistant",
-                "content": response_message.get('content', ''),
-                "tool_calls": tool_calls
-            })
-
+            messages.append(
+                {
+                    "role": "assistant",
+                    "content": response_message.get("content", ""),
+                    "tool_calls": tool_calls,
+                }
+            )
+
             # Process each tool call
             for tool_call in tool_calls:
-                tool_call_id = tool_call['id']
-                tool_function_name = tool_call['function']['name']
-                tool_arguments = json.loads(tool_call['function'].get('arguments', '{}'))
-
+                tool_call_id = tool_call["id"]
+                tool_function_name = tool_call["function"]["name"]
+                tool_arguments = json.loads(
+                    tool_call["function"].get("arguments", "{}")
+                )
+
                 # Execute the tool
                 results = execute_tool(tool_function_name, tool_arguments)
-
+
                 # Add the tool response message
-                messages.append({
-                    "role": "tool",
-                    "content": str(results),  # Convert results to string
-                    "tool_call_id": tool_call_id
-                })
-
+                messages.append(
+                    {
+                        "role": "tool",
+                        "content": str(results),  # Convert results to string
+                        "tool_call_id": tool_call_id,
+                    }
+                )
+
             # Final API call with tool results
             final_response = openai.ChatCompletion.create(
-                model="gpt-3.5-turbo",
-                messages=messages
-            )
-            return final_response.choices[0].message['content']
+                model="gpt-3.5-turbo", messages=messages
+            )
+            return final_response.choices[0].message["content"]
         except openai.error.OpenAIError as e:
             logging.error("OpenAI API Error: %s", str(e))
             raise OpenAIAPIException(detail=str(e))
@@ -206,14 +225,17 @@ def generate_title(self, conversation):
         # Get the first two messages
         messages = conversation.messages.all()[:2]
         context = "\n".join([msg.content for msg in messages])
-        prompt = f"Based on the following conversation, generate a short, descriptive title (max 6 words):\n\n{context}"
+        prompt = PromptTemplates.get_title_generation_user_prompt(context)
 
         response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
-                {"role": "system", "content": "You are a helpful assistant that generates short, descriptive titles."},
-                {"role": "user", "content": prompt}
-            ]
+                {
+                    "role": "system",
+                    "content": PromptTemplates.get_title_generation_system_prompt(),
+                },
+                {"role": "user", "content": prompt},
+            ],
        )
 
-        return response.choices[0].message['content'].strip()
+        return response.choices[0].message["content"].strip()
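
Note: the new prompt_services module itself is not among the three files shown in this commit, so the templates are only visible through their call sites. Below is a minimal sketch of the conversation-related half of the service, assuming the file lives at server/api/services/prompt_services.py and that each method simply returns the string it replaced. The method names come from the diff; the bodies are the removed inline strings, so treat this as a reconstruction rather than the actual file contents.

    # prompt_services.py (hypothetical reconstruction from the call sites above)
    class PromptTemplates:
        @staticmethod
        def get_medicine_description_prompt(tokens: str) -> str:
            # Formerly inlined in extract_text()
            return "Give a brief description of this medicine: %s" % tokens

        @staticmethod
        def get_conversation_system_prompt() -> str:
            # Formerly inlined in get_chatgpt_response()
            return (
                "You are a knowledgeable assistant. Balancer is a powerful tool "
                "for selecting bipolar medication for patients. We are open-source "
                "and available for free use. Your primary role is to assist "
                "licensed clinical professionals with information related to "
                "Balancer and bipolar medication selection. If applicable, use "
                "the supplied tools to assist the professional."
            )

        @staticmethod
        def get_conversation_page_context_prompt(page_context: str) -> str:
            return (
                "If applicable, please use the following content to ask questions. "
                "If not applicable, please answer to the best of your ability: "
                f"{page_context}"
            )

        @staticmethod
        def get_title_generation_system_prompt() -> str:
            return (
                "You are a helpful assistant that generates short, "
                "descriptive titles."
            )

        @staticmethod
        def get_title_generation_user_prompt(context: str) -> str:
            return (
                "Based on the following conversation, generate a short, "
                f"descriptive title (max 6 words):\n\n{context}"
            )

With the strings collected in one module, a prompt wording change no longer touches view logic, and all three views share a single import.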

server/api/views/embeddings/embeddingsView.py

Lines changed: 36 additions & 41 deletions
@@ -6,59 +6,52 @@
 from ...services.embedding_services import get_closest_embeddings
 from ...services.conversions_services import convert_uuids
 from ...services.openai_services import openAIServices
+from ...services.prompt_services import PromptTemplates
 from django.utils.decorators import method_decorator
 from django.views.decorators.csrf import csrf_exempt
 import json
 
 
-@method_decorator(csrf_exempt, name='dispatch')
+@method_decorator(csrf_exempt, name="dispatch")
 class AskEmbeddingsAPIView(APIView):
     permission_classes = [IsAuthenticated]
 
     def post(self, request, *args, **kwargs):
         try:
             user = request.user
-            guid = request.query_params.get('guid')
-            stream = request.query_params.get(
-                'stream', 'false').lower() == 'true'
+            guid = request.query_params.get("guid")
+            stream = request.query_params.get("stream", "false").lower() == "true"
 
-            request_data = request.data.get('message', None)
+            request_data = request.data.get("message", None)
             if not request_data:
-                return Response({"error": "Message data is required."}, status=status.HTTP_400_BAD_REQUEST)
+                return Response(
+                    {"error": "Message data is required."},
+                    status=status.HTTP_400_BAD_REQUEST,
+                )
             message = str(request_data)
 
             embeddings_results = get_closest_embeddings(
-                user=user, message_data=message, guid=guid)
+                user=user, message_data=message, guid=guid
+            )
             embeddings_results = convert_uuids(embeddings_results)
 
             prompt_texts = [
-                f"[Start of INFO {i+1} === GUID: {obj['file_id']}, Page Number: {obj['page_number']}, Chunk Number: {obj['chunk_number']}, Text: {obj['text']} === End of INFO {i+1} ]" for i, obj in enumerate(embeddings_results)]
+                f"[Start of INFO {i + 1} === GUID: {obj['file_id']}, Page Number: {obj['page_number']}, Chunk Number: {obj['chunk_number']}, Text: {obj['text']} === End of INFO {i + 1} ]"
+                for i, obj in enumerate(embeddings_results)
+            ]
 
             listOfEmbeddings = " ".join(prompt_texts)
 
-            prompt_text = (
-                f"""You are an AI assistant tasked with providing detailed, well-structured responses based on the information provided in [PROVIDED-INFO]. Follow these guidelines strictly:
-                1. Content: Use information contained within [PROVIDED-INFO] to answer the question.
-                2. Organization: Structure your response with clear sections and paragraphs.
-                3. Citations: After EACH sentence that uses information from [PROVIDED-INFO], include a citation in this exact format:***[{{file_id}}], Page {{page_number}}, Chunk {{chunk_number}}*** . Only use citations that correspond to the information you're presenting.
-                4. Clarity: Ensure your answer is well-structured and easy to follow.
-                5. Direct Response: Answer the user's question directly without unnecessary introductions or filler phrases.
-                Here's an example of the required response format:
-                ________________________________________
-                See's Candy in the context of sales during a specific event. The candy counters rang up 2,690 individual sales on a Friday, and an additional 3,931 transactions on a Saturday ***[16s848as-vcc1-85sd-r196-7f820a4s9de1, Page 5, Chunk 26]***.
-                People like the consumption of fudge and peanut brittle the most ***[130714d7-b9c1-4sdf-b146-fdsf854cad4f, Page 9, Chunk 19]***.
-                Here is the history of See's Candy: the company was purchased in 1972, and its products have not been materially altered in 101 years ***[895sdsae-b7v5-416f-c84v-7f9784dc01e1, Page 2, Chunk 13]***.
-                Bipolar disorder treatment often involves mood stabilizers. Lithium is a commonly prescribed mood stabilizer effective in reducing manic episodes ***[b99988ac-e3b0-4d22-b978-215e814807f4, Page 29, Chunk 122]***. For acute hypomania or mild to moderate mania, initial treatment with risperidone or olanzapine monotherapy is suggested ***[b99988ac-e3b0-4d22-b978-215e814807f4, Page 24, Chunk 101]***.
-                ________________________________________
-                Please provide your response to the user's question following these guidelines precisely.
-                [PROVIDED-INFO] = {listOfEmbeddings}"""
-            )
+            prompt_text = PromptTemplates.get_embeddings_query_prompt(listOfEmbeddings)
 
             if stream:
+
                 def stream_generator():
                     try:
                         last_chunk = ""
-                        for chunk in openAIServices.openAI(message, prompt_text, stream=True, raw_stream=False):
+                        for chunk in openAIServices.openAI(
+                            message, prompt_text, stream=True, raw_stream=False
+                        ):
                             # Format as Server-Sent Events for better client handling
                             if chunk and chunk != last_chunk:
                                 last_chunk = chunk
@@ -72,27 +65,29 @@ def stream_generator():
                         yield f"data: {error_data}\n\n"
 
                 response = StreamingHttpResponse(
-                    stream_generator(),
-                    content_type='text/event-stream'
+                    stream_generator(), content_type="text/event-stream"
                 )
                 # Add CORS and caching headers for streaming
-                response['Cache-Control'] = 'no-cache'
-                response['Access-Control-Allow-Origin'] = '*'
+                response["Cache-Control"] = "no-cache"
+                response["Access-Control-Allow-Origin"] = "*"
                 # Disable nginx buffering if behind nginx
-                response['X-Accel-Buffering'] = 'no'
+                response["X-Accel-Buffering"] = "no"
                 return response
             # Non-streaming response
             answer = openAIServices.openAI(
-                userMessage=message,
-                prompt=prompt_text,
-                stream=False
+                userMessage=message, prompt=prompt_text, stream=False
+            )
+            return Response(
+                {
+                    "question": message,
+                    "llm_response": answer,
+                    "embeddings_info": embeddings_results,
+                    "sent_to_llm": prompt_text,
+                },
+                status=status.HTTP_200_OK,
             )
-            return Response({
-                "question": message,
-                "llm_response": answer,
-                "embeddings_info": embeddings_results,
-                "sent_to_llm": prompt_text,
-            }, status=status.HTTP_200_OK)
 
         except Exception as e:
-            return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+            return Response(
+                {"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
+            )
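
The embeddings view swaps the largest inline prompt in the codebase for the same service. A plausible counterpart in prompt_services.py, again reconstructed from the removed f-string rather than taken from the actual module; the worked See's Candy example block is elided here because it appears in full in the removed lines above:

    # Continuation of the hypothetical PromptTemplates sketch
    class PromptTemplates:
        @staticmethod
        def get_embeddings_query_prompt(list_of_embeddings: str) -> str:
            # Double braces keep {file_id} et al. literal inside the f-string,
            # exactly as in the removed inline prompt.
            return f"""You are an AI assistant tasked with providing detailed, well-structured responses based on the information provided in [PROVIDED-INFO]. Follow these guidelines strictly:
    1. Content: Use information contained within [PROVIDED-INFO] to answer the question.
    2. Organization: Structure your response with clear sections and paragraphs.
    3. Citations: After EACH sentence that uses information from [PROVIDED-INFO], include a citation in this exact format:***[{{file_id}}], Page {{page_number}}, Chunk {{chunk_number}}*** . Only use citations that correspond to the information you're presenting.
    4. Clarity: Ensure your answer is well-structured and easy to follow.
    5. Direct Response: Answer the user's question directly without unnecessary introductions or filler phrases.
    (worked example block elided in this sketch; see the removed lines above)
    Please provide your response to the user's question following these guidelines precisely.
    [PROVIDED-INFO] = {list_of_embeddings}"""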

server/api/views/text_extraction/views.py

Lines changed: 2 additions & 23 deletions
@@ -9,6 +9,7 @@
 from django.views.decorators.csrf import csrf_exempt
 
 from ...services.openai_services import openAIServices
+from ...services.prompt_services import PromptTemplates
 from api.models.model_embeddings import Embeddings
 
 # This is to use openai to extract the rules to save cost
@@ -37,29 +38,7 @@ class RuleExtractionAPIOpenAIView(APIView):
 
     def get(self, request):
         try:
-            user_prompt = """
-            You're analyzing medical text from multiple sources. Each chunk is labeled [chunk-X].
-
-            Act as a seasoned physician or medical professional who treats patients with bipolar disorder.
-
-            Identify rules for medication inclusion or exclusion based on medical history or concerns.
-
-            For each rule you find, return a JSON object using the following format:
-
-            {
-                "rule": "<condition or concern>",
-                "type": "INCLUDE" or "EXCLUDE",
-                "reason": "<short explanation for why this rule applies>",
-                "medications": ["<medication 1>", "<medication 2>", ...],
-                "source": "<chunk-X>"
-            }
-
-            Only include rules that are explicitly stated or strongly implied in the chunk.
-
-            Only use the chunks provided. If no rule is found in a chunk, skip it.
-
-            Return the entire output as a JSON array.
-            """
+            user_prompt = PromptTemplates.get_text_extraction_prompt()
 
             guid = request.query_params.get("guid")
             query = Embeddings.objects.filter(upload_file__guid=guid)
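
The rule-extraction prompt moves the same way. A minimal sketch of the corresponding method, assuming it returns the removed triple-quoted string unchanged; the body is abridged here to avoid repeating the full text shown in the diff:

    # Continuation of the hypothetical PromptTemplates sketch
    class PromptTemplates:
        @staticmethod
        def get_text_extraction_prompt() -> str:
            # Returns the rule-extraction prompt removed above, verbatim; only
            # the first and last lines are reproduced in this sketch.
            return """
            You're analyzing medical text from multiple sources. Each chunk is labeled [chunk-X].
            ...
            Return the entire output as a JSON array.
            """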
