11from django .http import JsonResponse
22from django .views .decorators .http import require_POST
3+ from langchain import QAWithSourcesChain
34
45from api .utils import get_vector_store
5- from api .utils .make_chain import getConversationRetrievalChain
6+ from api .utils .make_chain import getConversationRetrievalChain , getRetrievalQAWithSourcesChain
67import json
78from django .views .decorators .csrf import csrf_exempt
89from api .interfaces import StoreOptions
1314import logging
1415import traceback
1516from web .services .chat_history_service import get_chat_history_for_retrieval_chain
17+ import os
18+
19+ from dotenv import load_dotenv
20+ load_dotenv ()
1621
1722logger = logging .getLogger (__name__ )
1823
@@ -36,12 +41,8 @@ def chat(request):
3641 sanitized_question = question .strip ().replace ('\n ' , ' ' )
3742
3843 vector_store = get_vector_store (StoreOptions (namespace = namespace ))
39- chain = getConversationRetrievalChain (vector_store , mode , initial_prompt , memory_key = session_id )
4044
41- # To avoid fetching an excessively large amount of history data from the database, set a limit on the maximum number of records that can be retrieved in a single query.
42- chat_history = get_chat_history_for_retrieval_chain (session_id , limit = 40 )
43- response = chain ({"question" : sanitized_question , "chat_history" : chat_history }, return_only_outputs = True )
44- response_text = response ['answer' ]
45+ response_text = get_completion_response (vector_store = vector_store , initial_prompt = initial_prompt ,mode = mode , sanitized_question = sanitized_question , session_id = session_id )
4546
4647 ChatHistory .objects .bulk_create ([
4748 ChatHistory (
@@ -68,4 +69,19 @@ def chat(request):
6869 except Exception as e :
6970 logger .error (str (e ))
7071 logger .error (traceback .format_exc ())
71- return JsonResponse ({'error' : 'An error occurred' }, status = 500 )
72+ return JsonResponse ({'error' : 'An error occurred' }, status = 500 )
73+
74+
75+ def get_completion_response (vector_store , mode , initial_prompt , sanitized_question , session_id ):
76+ chain_type = os .getenv ("CHAIN_TYPE" , "conversation_retrieval" )
77+ chain : QAWithSourcesChain
78+ if chain_type == 'retrieval_qa' :
79+ chain = getRetrievalQAWithSourcesChain (vector_store , mode , initial_prompt )
80+ response = chain ({"question" : sanitized_question }, return_only_outputs = True )
81+ response_text = response ['answer' ]
82+ elif chain_type == 'conversation_retrieval' :
83+ chain = getConversationRetrievalChain (vector_store , mode , initial_prompt )
84+ chat_history = get_chat_history_for_retrieval_chain (session_id , limit = 40 )
85+ response = chain ({"question" : sanitized_question , "chat_history" : chat_history }, return_only_outputs = True )
86+ response_text = response ['answer' ]
87+ return response_text
0 commit comments