diff --git a/server/api/services/prompt_services.py b/server/api/services/prompt_services.py
new file mode 100644
index 00000000..73f58707
--- /dev/null
+++ b/server/api/services/prompt_services.py
@@ -0,0 +1,97 @@
+"""
+Centralized prompt management for the application.
+Contains all prompts used across different services.
+"""
+
+
+class PromptTemplates:
+    """Central repository for all prompt templates used in the application."""
+
+    TEXT_EXTRACTION_RULE_EXTRACTION = """
+You're analyzing medical text from multiple sources. Each chunk is labeled [chunk-X].
+
+Act as a seasoned physician or medical professional who treats patients with bipolar disorder.
+
+Identify rules for medication inclusion or exclusion based on medical history or concerns.
+
+For each rule you find, return a JSON object using the following format:
+
+{
+  "rule": "",
+  "type": "INCLUDE" or "EXCLUDE",
+  "reason": "",
+  "medications": ["", "", ...],
+  "source": ""
+}
+
+Only include rules that are explicitly stated or strongly implied in the chunk.
+
+Only use the chunks provided. If no rule is found in a chunk, skip it.
+
+Return the entire output as a JSON array.
+"""
+
+    EMBEDDINGS_QUERY_RESPONSE = """You are an AI assistant tasked with providing detailed, well-structured responses based on the information provided in [PROVIDED-INFO]. Follow these guidelines strictly:
+1. Content: Use information contained within [PROVIDED-INFO] to answer the question.
+2. Organization: Structure your response with clear sections and paragraphs.
+3. Citations: After EACH sentence that uses information from [PROVIDED-INFO], include a citation in this exact format:***[{{file_id}}], Page {{page_number}}, Chunk {{chunk_number}}*** . Only use citations that correspond to the information you're presenting.
+4. Clarity: Ensure your answer is well-structured and easy to follow.
+5. Direct Response: Answer the user's question directly without unnecessary introductions or filler phrases.
+Here's an example of the required response format:
+________________________________________
+See's Candy in the context of sales during a specific event. The candy counters rang up 2,690 individual sales on a Friday, and an additional 3,931 transactions on a Saturday ***[16s848as-vcc1-85sd-r196-7f820a4s9de1, Page 5, Chunk 26]***.
+People like the consumption of fudge and peanut brittle the most ***[130714d7-b9c1-4sdf-b146-fdsf854cad4f, Page 9, Chunk 19]***.
+Here is the history of See's Candy: the company was purchased in 1972, and its products have not been materially altered in 101 years ***[895sdsae-b7v5-416f-c84v-7f9784dc01e1, Page 2, Chunk 13]***.
+Bipolar disorder treatment often involves mood stabilizers. Lithium is a commonly prescribed mood stabilizer effective in reducing manic episodes ***[b99988ac-e3b0-4d22-b978-215e814807f4, Page 29, Chunk 122]***. For acute hypomania or mild to moderate mania, initial treatment with risperidone or olanzapine monotherapy is suggested ***[b99988ac-e3b0-4d22-b978-215e814807f4, Page 24, Chunk 101]***.
+________________________________________
+Please provide your response to the user's question following these guidelines precisely.
+[PROVIDED-INFO] = {listOfEmbeddings}"""
+
+    CONVERSATION_SYSTEM_PROMPT = """You are a knowledgeable assistant. Balancer is a powerful tool for selecting bipolar medication for patients. We are open-source and available for free use. Your primary role is to assist licensed clinical professionals with information related to Balancer and bipolar medication selection.
+If applicable, use the supplied tools to assist the professional."""
+
+    CONVERSATION_PAGE_CONTEXT_PROMPT = """If applicable, please use the following content to answer questions. If not applicable, please answer to the best of your ability: {page_context}"""
+
+    MEDICINE_DESCRIPTION_PROMPT = """Give a brief description of this medicine: %s"""
+
+    TITLE_GENERATION_SYSTEM_PROMPT = (
+        """You are a helpful assistant that generates short, descriptive titles."""
+    )
+
+    TITLE_GENERATION_USER_PROMPT = """Based on the following conversation, generate a short, descriptive title (max 6 words):
+
+{context}"""
+
+    @classmethod
+    def get_text_extraction_prompt(cls):
+        """Get the text extraction rule extraction prompt."""
+        return cls.TEXT_EXTRACTION_RULE_EXTRACTION
+
+    @classmethod
+    def get_embeddings_query_prompt(cls, list_of_embeddings):
+        """Get the embeddings query response prompt with embedded data."""
+        return cls.EMBEDDINGS_QUERY_RESPONSE.format(listOfEmbeddings=list_of_embeddings)
+
+    @classmethod
+    def get_conversation_system_prompt(cls):
+        """Get the conversation system prompt."""
+        return cls.CONVERSATION_SYSTEM_PROMPT
+
+    @classmethod
+    def get_conversation_page_context_prompt(cls, page_context):
+        """Get the conversation page context prompt."""
+        return cls.CONVERSATION_PAGE_CONTEXT_PROMPT.format(page_context=page_context)
+
+    @classmethod
+    def get_medicine_description_prompt(cls, tokens):
+        """Get the medicine description prompt."""
+        return cls.MEDICINE_DESCRIPTION_PROMPT % tokens
+
+    @classmethod
+    def get_title_generation_system_prompt(cls):
+        """Get the title generation system prompt."""
+        return cls.TITLE_GENERATION_SYSTEM_PROMPT
+
+    @classmethod
+    def get_title_generation_user_prompt(cls, context):
+        """Get the title generation user prompt."""
+        return cls.TITLE_GENERATION_USER_PROMPT.format(context=context)
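For reviewers, a minimal sketch of how the new class is meant to be consumed. The absolute import path is an assumption inferred from the file's location under server/api/services/ (the views below use relative imports), and the argument values are hypothetical:

    # Illustrative only -- mirrors the getters defined in the diff above.
    from api.services.prompt_services import PromptTemplates

    system_prompt = PromptTemplates.get_conversation_system_prompt()
    context_prompt = PromptTemplates.get_conversation_page_context_prompt(
        "Patient page text here"  # hypothetical page_context value
    )
    medicine_prompt = PromptTemplates.get_medicine_description_prompt("lithium")
    # Note the mixed conventions: most templates are filled via str.format(),
    # while the medicine prompt keeps the original %-interpolation.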
diff --git a/server/api/views/conversations/views.py b/server/api/views/conversations/views.py
index d46f8222..3c1efef3 100644
--- a/server/api/views/conversations/views.py
+++ b/server/api/views/conversations/views.py
@@ -16,6 +16,7 @@
 from .models import Conversation, Message
 from .serializers import ConversationSerializer
 from ...services.tools.tools import tools, execute_tool
+from ...services.prompt_services import PromptTemplates
 
 
 @csrf_exempt
@@ -47,7 +48,7 @@ def extract_text(request: str) -> JsonResponse:
         messages=[
             {
                 "role": "system",
-                "content": "Give a brief description of this medicine: %s" % tokens,
+                "content": PromptTemplates.get_medicine_description_prompt(tokens),
             }
         ],
         max_tokens=500,
@@ -64,8 +65,10 @@ def get_tokens(string: str, encoding_name: str) -> str:
     output_string = encoding.decode(tokens)
     return output_string
 
+
 class OpenAIAPIException(APIException):
     """Custom exception for OpenAI API errors."""
+
     status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
     default_detail = "An error occurred while communicating with the OpenAI API."
     default_code = "openai_api_error"
@@ -77,6 +80,7 @@ def __init__(self, detail=None, code=None):
             self.detail = {"error": self.default_detail}
         self.status_code = code or self.status_code
 
+
 class ConversationViewSet(viewsets.ModelViewSet):
     serializer_class = ConversationSerializer
     permission_classes = [IsAuthenticated]
@@ -93,26 +97,29 @@ def destroy(self, request, *args, **kwargs):
         self.perform_destroy(instance)
         return Response(status=status.HTTP_204_NO_CONTENT)
 
-    @action(detail=True, methods=['post'])
+    @action(detail=True, methods=["post"])
     def continue_conversation(self, request, pk=None):
         conversation = self.get_object()
-        user_message = request.data.get('message')
-        page_context = request.data.get('page_context')
+        user_message = request.data.get("message")
+        page_context = request.data.get("page_context")
 
         if not user_message:
             return Response({"error": "Message is required"}, status=400)
 
         # Save user message
-        Message.objects.create(conversation=conversation,
-                               content=user_message, is_user=True)
+        Message.objects.create(
+            conversation=conversation, content=user_message, is_user=True
+        )
 
         # Get ChatGPT response
         chatgpt_response = self.get_chatgpt_response(
-            conversation, user_message, page_context)
+            conversation, user_message, page_context
+        )
 
         # Save ChatGPT response
-        Message.objects.create(conversation=conversation,
-                               content=chatgpt_response, is_user=False)
+        Message.objects.create(
+            conversation=conversation, content=chatgpt_response, is_user=False
+        )
 
         # Generate or update title if it's the first message or empty
         if conversation.messages.count() <= 2 or not conversation.title:
@@ -121,27 +128,35 @@ def continue_conversation(self, request, pk=None):
 
         return Response({"response": chatgpt_response, "title": conversation.title})
 
-    @action(detail=True, methods=['patch'])
+    @action(detail=True, methods=["patch"])
     def update_title(self, request, pk=None):
         conversation = self.get_object()
-        new_title = request.data.get('title')
+        new_title = request.data.get("title")
 
         if not new_title:
-            return Response({"error": "New title is required"}, status=status.HTTP_400_BAD_REQUEST)
+            return Response(
+                {"error": "New title is required"}, status=status.HTTP_400_BAD_REQUEST
+            )
 
         conversation.title = new_title
         conversation.save()
-        return Response({"status": "Title updated successfully", "title": conversation.title})
+        return Response(
+            {"status": "Title updated successfully", "title": conversation.title}
+        )
 
     def get_chatgpt_response(self, conversation, user_message, page_context=None):
-        messages = [{
-            "role": "system",
-            "content": "You are a knowledgeable assistant. Balancer is a powerful tool for selecting bipolar medication for patients. We are open-source and available for free use. Your primary role is to assist licensed clinical professionals with information related to Balancer and bipolar medication selection. If applicable, use the supplied tools to assist the professional."
-        }]
+        messages = [
+            {
+                "role": "system",
+                "content": PromptTemplates.get_conversation_system_prompt(),
+            }
+        ]
 
         if page_context:
-            context_message = f"If applicable, please use the following content to ask questions. If not applicable, please answer to the best of your ability: {page_context}"
+            context_message = PromptTemplates.get_conversation_page_context_prompt(
+                page_context
+            )
             messages.append({"role": "system", "content": context_message})
 
         for msg in conversation.messages.all():
@@ -155,46 +170,50 @@ def get_chatgpt_response(self, conversation, user_message, page_context=None):
                 model="gpt-3.5-turbo",
                 messages=messages,
                 tools=tools,
-                tool_choice="auto"
+                tool_choice="auto",
             )
 
             response_message = response.choices[0].message
-            tool_calls = response_message.get('tool_calls', [])
+            tool_calls = response_message.get("tool_calls", [])
 
             if not tool_calls:
-                return response_message['content']
-
+                return response_message["content"]
 
             # Handle tool calls
             # Add the assistant's message with tool calls to the conversation
-            messages.append({
-                "role": "assistant",
-                "content": response_message.get('content', ''),
-                "tool_calls": tool_calls
-            })
-
+            messages.append(
+                {
+                    "role": "assistant",
+                    "content": response_message.get("content", ""),
+                    "tool_calls": tool_calls,
+                }
+            )
+
             # Process each tool call
             for tool_call in tool_calls:
-                tool_call_id = tool_call['id']
-                tool_function_name = tool_call['function']['name']
-                tool_arguments = json.loads(tool_call['function'].get('arguments', '{}'))
-
+                tool_call_id = tool_call["id"]
+                tool_function_name = tool_call["function"]["name"]
+                tool_arguments = json.loads(
+                    tool_call["function"].get("arguments", "{}")
+                )
+
                 # Execute the tool
                 results = execute_tool(tool_function_name, tool_arguments)
-
+
                 # Add the tool response message
-                messages.append({
-                    "role": "tool",
-                    "content": str(results),  # Convert results to string
-                    "tool_call_id": tool_call_id
-                })
-
+                messages.append(
+                    {
+                        "role": "tool",
+                        "content": str(results),  # Convert results to string
+                        "tool_call_id": tool_call_id,
+                    }
+                )
+
             # Final API call with tool results
             final_response = openai.ChatCompletion.create(
-                model="gpt-3.5-turbo",
-                messages=messages
-            )
-            return final_response.choices[0].message['content']
+                model="gpt-3.5-turbo", messages=messages
+            )
+            return final_response.choices[0].message["content"]
 
         except openai.error.OpenAIError as e:
             logging.error("OpenAI API Error: %s", str(e))
             raise OpenAIAPIException(detail=str(e))
@@ -206,14 +225,17 @@ def generate_title(self, conversation):
         # Get the first two messages
         messages = conversation.messages.all()[:2]
         context = "\n".join([msg.content for msg in messages])
-        prompt = f"Based on the following conversation, generate a short, descriptive title (max 6 words):\n\n{context}"
+        prompt = PromptTemplates.get_title_generation_user_prompt(context)
 
         response = openai.ChatCompletion.create(
             model="gpt-3.5-turbo",
             messages=[
-                {"role": "system", "content": "You are a helpful assistant that generates short, descriptive titles."},
-                {"role": "user", "content": prompt}
-            ]
+                {
+                    "role": "system",
+                    "content": PromptTemplates.get_title_generation_system_prompt(),
+                },
+                {"role": "user", "content": prompt},
+            ],
         )
 
-        return response.choices[0].message['content'].strip()
+        return response.choices[0].message["content"].strip()
diff --git a/server/api/views/embeddings/embeddingsView.py b/server/api/views/embeddings/embeddingsView.py
index 9469bb09..fdaf6e9d 100644
--- a/server/api/views/embeddings/embeddingsView.py
+++ b/server/api/views/embeddings/embeddingsView.py
@@ -6,59 +6,52 @@
 from ...services.embedding_services import get_closest_embeddings
 from ...services.conversions_services import convert_uuids
 from ...services.openai_services import openAIServices
+from ...services.prompt_services import PromptTemplates
 from django.utils.decorators import method_decorator
 from django.views.decorators.csrf import csrf_exempt
 import json
 
 
-@method_decorator(csrf_exempt, name='dispatch')
+@method_decorator(csrf_exempt, name="dispatch")
 class AskEmbeddingsAPIView(APIView):
     permission_classes = [IsAuthenticated]
 
     def post(self, request, *args, **kwargs):
         try:
             user = request.user
-            guid = request.query_params.get('guid')
-            stream = request.query_params.get(
-                'stream', 'false').lower() == 'true'
+            guid = request.query_params.get("guid")
+            stream = request.query_params.get("stream", "false").lower() == "true"
 
-            request_data = request.data.get('message', None)
+            request_data = request.data.get("message", None)
 
             if not request_data:
-                return Response({"error": "Message data is required."}, status=status.HTTP_400_BAD_REQUEST)
+                return Response(
+                    {"error": "Message data is required."},
+                    status=status.HTTP_400_BAD_REQUEST,
+                )
 
             message = str(request_data)
 
             embeddings_results = get_closest_embeddings(
-                user=user, message_data=message, guid=guid)
+                user=user, message_data=message, guid=guid
+            )
 
             embeddings_results = convert_uuids(embeddings_results)
 
             prompt_texts = [
-                f"[Start of INFO {i+1} === GUID: {obj['file_id']}, Page Number: {obj['page_number']}, Chunk Number: {obj['chunk_number']}, Text: {obj['text']} === End of INFO {i+1} ]" for i, obj in enumerate(embeddings_results)]
+                f"[Start of INFO {i + 1} === GUID: {obj['file_id']}, Page Number: {obj['page_number']}, Chunk Number: {obj['chunk_number']}, Text: {obj['text']} === End of INFO {i + 1} ]"
+                for i, obj in enumerate(embeddings_results)
+            ]
 
             listOfEmbeddings = " ".join(prompt_texts)
 
-            prompt_text = (
-                f"""You are an AI assistant tasked with providing detailed, well-structured responses based on the information provided in [PROVIDED-INFO]. Follow these guidelines strictly:
-                1. Content: Use information contained within [PROVIDED-INFO] to answer the question.
-                2. Organization: Structure your response with clear sections and paragraphs.
-                3. Citations: After EACH sentence that uses information from [PROVIDED-INFO], include a citation in this exact format:***[{{file_id}}], Page {{page_number}}, Chunk {{chunk_number}}*** . Only use citations that correspond to the information you're presenting.
-                4. Clarity: Ensure your answer is well-structured and easy to follow.
-                5. Direct Response: Answer the user's question directly without unnecessary introductions or filler phrases.
-                Here's an example of the required response format:
-                ________________________________________
-                See's Candy in the context of sales during a specific event. The candy counters rang up 2,690 individual sales on a Friday, and an additional 3,931 transactions on a Saturday ***[16s848as-vcc1-85sd-r196-7f820a4s9de1, Page 5, Chunk 26]***.
-                People like the consumption of fudge and peanut brittle the most ***[130714d7-b9c1-4sdf-b146-fdsf854cad4f, Page 9, Chunk 19]***.
-                Here is the history of See's Candy: the company was purchased in 1972, and its products have not been materially altered in 101 years ***[895sdsae-b7v5-416f-c84v-7f9784dc01e1, Page 2, Chunk 13]***.
-                Bipolar disorder treatment often involves mood stabilizers. Lithium is a commonly prescribed mood stabilizer effective in reducing manic episodes ***[b99988ac-e3b0-4d22-b978-215e814807f4, Page 29, Chunk 122]***. For acute hypomania or mild to moderate mania, initial treatment with risperidone or olanzapine monotherapy is suggested ***[b99988ac-e3b0-4d22-b978-215e814807f4, Page 24, Chunk 101]***.
-                ________________________________________
-                Please provide your response to the user's question following these guidelines precisely.
-                [PROVIDED-INFO] = {listOfEmbeddings}"""
-            )
+            prompt_text = PromptTemplates.get_embeddings_query_prompt(listOfEmbeddings)
 
             if stream:
+
                 def stream_generator():
                     try:
                         last_chunk = ""
-                        for chunk in openAIServices.openAI(message, prompt_text, stream=True, raw_stream=False):
+                        for chunk in openAIServices.openAI(
+                            message, prompt_text, stream=True, raw_stream=False
+                        ):
                             # Format as Server-Sent Events for better client handling
                             if chunk and chunk != last_chunk:
                                 last_chunk = chunk
@@ -72,27 +65,29 @@ def stream_generator():
                         yield f"data: {error_data}\n\n"
 
                 response = StreamingHttpResponse(
-                    stream_generator(),
-                    content_type='text/event-stream'
+                    stream_generator(), content_type="text/event-stream"
                 )
                 # Add CORS and caching headers for streaming
-                response['Cache-Control'] = 'no-cache'
-                response['Access-Control-Allow-Origin'] = '*'
+                response["Cache-Control"] = "no-cache"
+                response["Access-Control-Allow-Origin"] = "*"
                 # Disable nginx buffering if behind nginx
-                response['X-Accel-Buffering'] = 'no'
+                response["X-Accel-Buffering"] = "no"
 
                 return response
 
             # Non-streaming response
             answer = openAIServices.openAI(
-                userMessage=message,
-                prompt=prompt_text,
-                stream=False
+                userMessage=message, prompt=prompt_text, stream=False
+            )
+            return Response(
+                {
+                    "question": message,
+                    "llm_response": answer,
+                    "embeddings_info": embeddings_results,
+                    "sent_to_llm": prompt_text,
+                },
+                status=status.HTTP_200_OK,
             )
-            return Response({
-                "question": message,
-                "llm_response": answer,
-                "embeddings_info": embeddings_results,
-                "sent_to_llm": prompt_text,
-            }, status=status.HTTP_200_OK)
 
         except Exception as e:
-            return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+            return Response(
+                {"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
+            )
diff --git a/server/api/views/text_extraction/views.py b/server/api/views/text_extraction/views.py
index e4122851..365e30cf 100644
--- a/server/api/views/text_extraction/views.py
+++ b/server/api/views/text_extraction/views.py
@@ -1,4 +1,3 @@
-import os
 import json
 import re
 
@@ -8,166 +7,40 @@
 from rest_framework import status
 from django.utils.decorators import method_decorator
 from django.views.decorators.csrf import csrf_exempt
-import anthropic
 
 from ...services.openai_services import openAIServices
+from ...services.prompt_services import PromptTemplates
 from api.models.model_embeddings import Embeddings
 
-USER_PROMPT = """
-I'm creating a system to analyze medical research. It processes peer-reviewed papers to extract key details
-
-Act as a seasoned physician or medical professional who treat patients with bipolar disorder
-
-Identify rules for medication inclusion or exclusion based on medical history or concerns
-
-Return an output with the same structure as these examples:
-
-The rule is history of suicide attempts. The type of rule is "INCLUDE". The reason is lithium is the
-only medication on the market that has been proven to reduce suicidality in patients with bipolar disorder.
-The medications for this rule are lithium.
-
-The rule is weight gain concerns. The type of rule is "EXCLUDE". The reason is Seroquel, Risperdal, Abilify, and
-Zyprexa are known for causing weight gain. The medications for this rule are Quetiapine, Aripiprazole, Olanzapine, Risperidone
-}
-"""
-
-
-def anthropic_citations(client: anthropic.Client, user_prompt: str, content_chunks: list) -> tuple:
-    """
-    Sends a message to Anthropic Citations and extract and format the response
-
-    Parameters
-    ----------
-    client: An instance of the Anthropic API client used to make the request
-    user_prompt: The user's question or instruction to be processed by the model
-    content_chunks: A list of text chunks that provide context for the model to use during generation
-
-    Returns
-    -------
-    tuple
-
-    """
-
-    message = client.messages.create(
-        model="claude-3-5-haiku-20241022",
-        max_tokens=1024,
-        messages=[
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "document",
-                        "source": {
-                            "type": "content",
-                            "content": content_chunks
-                        },
-                        "citations": {"enabled": True}
-                    },
-
-                    {
-                        "type": "text",
-                        "text": user_prompt
-                    }
-                ]
-            }
-        ],
-    )
-
-    # Response Structure: https://docs.anthropic.com/en/docs/build-with-claude/citations#response-structure
-
-    text = []
-    cited_text = []
-    for content in message.to_dict()['content']:
-        text.append(content['text'])
-        if 'citations' in content.keys():
-            text.append(" ".join(
-                [f"<{citation['start_block_index']} - {citation['end_block_index']}>" for citation in content['citations']]))
-            cited_text.append(" ".join(
-                [f"<{citation['start_block_index']} - {citation['end_block_index']}> {citation['cited_text']}" for citation in content['citations']]))
-
-    texts = " ".join(text)
-    cited_texts = " ".join(cited_text)
-
-    return texts, cited_texts
-
-
-@method_decorator(csrf_exempt, name='dispatch')
-class RuleExtractionAPIView(APIView):
-
-    permission_classes = [IsAuthenticated]
-
-    def get(self, request):
-        try:
-
-            client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
-
-            guid = request.query_params.get('guid')
-
-            query = Embeddings.objects.filter(upload_file__guid=guid)
-
-            # TODO: Format into the Anthropic API"s expected input format in the anthropic_citations function
-            chunks = [{"type": "text", "text": chunk.text} for chunk in query]
-
-            texts, cited_texts = anthropic_citations(client, USER_PROMPT, chunks)
-
-            return Response({"texts": texts, "cited_texts": cited_texts}, status=status.HTTP_200_OK)
-
-        except Exception as e:
-            return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
-
-    # This is to use openai to extract the rules to save cost
+
 def openai_extraction(content_chunks, user_prompt):
     """
     Prepares the OpenAI input and returns the extracted text.
     """
-    combined_text = "\n\n".join(chunk['text'] for chunk in content_chunks)
+    combined_text = "\n\n".join(chunk["text"] for chunk in content_chunks)
 
     result = openAIServices.openAI(
         userMessage=combined_text,
         prompt=user_prompt,
         model="gpt-4o-mini",
         temp=0.0,
-        stream=False
+        stream=False,
     )
 
     return result
 
 
-@method_decorator(csrf_exempt, name='dispatch')
+@method_decorator(csrf_exempt, name="dispatch")
 class RuleExtractionAPIOpenAIView(APIView):
     permission_classes = [IsAuthenticated]
 
     def get(self, request):
         try:
-            user_prompt = """
-            You're analyzing medical text from multiple sources. Each chunk is labeled [chunk-X].
-
-            Act as a seasoned physician or medical professional who treats patients with bipolar disorder.
-
-            Identify rules for medication inclusion or exclusion based on medical history or concerns.
-
-            For each rule you find, return a JSON object using the following format:
-
-            {
-                "rule": "",
-                "type": "INCLUDE" or "EXCLUDE",
-                "reason": "",
-                "medications": ["", "", ...],
-                "source": ""
-            }
-
-            Only include rules that are explicitly stated or strongly implied in the chunk.
-
-            Only use the chunks provided. If no rule is found in a chunk, skip it.
-
-            Return the entire output as a JSON array.
-            """
+            user_prompt = PromptTemplates.get_text_extraction_prompt()
 
-            guid = request.query_params.get('guid')
+            guid = request.query_params.get("guid")
 
             query = Embeddings.objects.filter(upload_file__guid=guid)
             chunks = [
                 {"type": "text", "text": f"[chunk-{i}] {chunk.text}"}
@@ -175,13 +48,11 @@ def get(self, request):
             ]
 
             output_text = openai_extraction(chunks, user_prompt)
-            cleaned_text = re.sub(r"^```json|```$", "",
-                                  output_text.strip()).strip()
+            cleaned_text = re.sub(r"^```json|```$", "", output_text.strip()).strip()
             rules = json.loads(cleaned_text)
 
             # Attach chunk_number and chunk_text to each rule
-            chunk_lookup = {f"chunk-{i}": chunk.text for i,
-                            chunk in enumerate(query)}
+            chunk_lookup = {f"chunk-{i}": chunk.text for i, chunk in enumerate(query)}
             for rule in rules:
                 source = rule.get("source", "").strip("[]")  # e.g. chunk-63
                 if source.startswith("chunk-"):
@@ -192,4 +63,6 @@ def get(self, request):
 
             return Response({"rules": rules}, status=status.HTTP_200_OK)
         except Exception as e:
-            return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+            return Response(
+                {"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
+            )
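Since the templates now mix str.format() placeholders, escaped {{...}} braces, and %-interpolation, a small sanity test could accompany this change. A pytest-style sketch (the import path api.services.prompt_services is assumed from the file's location; the test file name is hypothetical and not part of this diff):

    # test_prompt_services.py -- illustrative sketch only.
    from api.services.prompt_services import PromptTemplates

    def test_page_context_is_interpolated():
        rendered = PromptTemplates.get_conversation_page_context_prompt("page text")
        assert "page text" in rendered
        assert "{page_context}" not in rendered

    def test_embeddings_prompt_keeps_citation_placeholders():
        # The doubled braces {{file_id}} in EMBEDDINGS_QUERY_RESPONSE are
        # .format() escapes: after formatting they must survive as
        # single-brace placeholders for the model to fill in.
        rendered = PromptTemplates.get_embeddings_query_prompt("[Start of INFO 1]")
        assert "[Start of INFO 1]" in rendered
        assert "{file_id}" in rendered

    def test_medicine_prompt_uses_percent_interpolation():
        rendered = PromptTemplates.get_medicine_description_prompt("lithium")
        assert rendered.endswith("lithium")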