Skip to content
This repository was archived by the owner on Jan 5, 2025. It is now read-only.

Commit b8a8a1d

Browse files
authored
Merge pull request #142 from codebanesr/enhancement/chain_selector
Enhancement/chain selector
2 parents fb02a75 + 97848fa commit b8a8a1d

File tree

3 files changed

+32
-8
lines changed

3 files changed

+32
-8
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,14 @@ If you want to switch from Pinecone to Qdrant, you can set the following environ
108108
- `STORE`: The store to use to store embeddings. Can be `qdrant` or `pinecone`.
109109

110110

111+
#### Optional [To modify the chat behaviour]
112+
113+
`CHAIN_TYPE` = The type of chain to use: `conversation_retrieval` | `retrieval_qa`
114+
115+
- `retrieval_qa` -> [Learn more](https://python.langchain.com/docs/use_cases/question_answering/how_to/vector_db_qa)
116+
- `conversation_retrieval` -> [Learn more](https://python.langchain.com/docs/use_cases/question_answering/how_to/chat_vector_db)
117+
118+
111119
> Note: for pinecone db, make sure that the dimension is equal to 1536
112120
113121
- Navigate to the repository folder and run the following command (for MacOS or Linux):

dj_backend_server/api/utils/make_chain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def getRetrievalQAWithSourcesChain(vector_store: VectorStore, mode, initial_prom
3030
return chain
3131

3232

33-
def getConversationRetrievalChain(vector_store: VectorStore, mode, initial_prompt: str, memory_key: str):
33+
def getConversationRetrievalChain(vector_store: VectorStore, mode, initial_prompt: str):
3434
llm = get_llm()
3535
template = get_qa_prompt_by_mode(mode, initial_prompt=initial_prompt)
3636
prompt = PromptTemplate.from_template(template)

dj_backend_server/api/views/views_chat.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from django.http import JsonResponse
22
from django.views.decorators.http import require_POST
3+
from langchain import QAWithSourcesChain
34

45
from api.utils import get_vector_store
5-
from api.utils.make_chain import getConversationRetrievalChain
6+
from api.utils.make_chain import getConversationRetrievalChain, getRetrievalQAWithSourcesChain
67
import json
78
from django.views.decorators.csrf import csrf_exempt
89
from api.interfaces import StoreOptions
@@ -13,6 +14,10 @@
1314
import logging
1415
import traceback
1516
from web.services.chat_history_service import get_chat_history_for_retrieval_chain
17+
import os
18+
19+
from dotenv import load_dotenv
20+
load_dotenv()
1621

1722
logger = logging.getLogger(__name__)
1823

@@ -36,12 +41,8 @@ def chat(request):
3641
sanitized_question = question.strip().replace('\n', ' ')
3742

3843
vector_store = get_vector_store(StoreOptions(namespace=namespace))
39-
chain = getConversationRetrievalChain(vector_store, mode, initial_prompt, memory_key=session_id)
4044

41-
# To avoid fetching an excessively large amount of history data from the database, set a limit on the maximum number of records that can be retrieved in a single query.
42-
chat_history = get_chat_history_for_retrieval_chain(session_id, limit=40)
43-
response = chain({"question": sanitized_question, "chat_history": chat_history }, return_only_outputs=True)
44-
response_text = response['answer']
45+
response_text = get_completion_response(vector_store=vector_store, initial_prompt=initial_prompt,mode=mode, sanitized_question=sanitized_question, session_id=session_id)
4546

4647
ChatHistory.objects.bulk_create([
4748
ChatHistory(
@@ -68,4 +69,19 @@ def chat(request):
6869
except Exception as e:
6970
logger.error(str(e))
7071
logger.error(traceback.format_exc())
71-
return JsonResponse({'error': 'An error occurred'}, status=500)
72+
return JsonResponse({'error': 'An error occurred'}, status=500)
73+
74+
75+
def get_completion_response(vector_store, mode, initial_prompt, sanitized_question, session_id):
    """Run the retrieval chain selected by the CHAIN_TYPE env var and return the answer text.

    Args:
        vector_store: The vector store used to build the retrieval chain.
        mode: Chat mode passed through to the chain factory (prompt selection).
        initial_prompt: Initial/system prompt passed through to the chain factory.
        sanitized_question: The user's question, already stripped of newlines.
        session_id: Chat session identifier; used to fetch prior history for the
            conversational chain only.

    Returns:
        str: The chain's 'answer' output.

    Raises:
        ValueError: If CHAIN_TYPE is set to an unsupported value. (Previously an
            unrecognized value caused an UnboundLocalError on response_text.)
    """
    chain_type = os.getenv("CHAIN_TYPE", "conversation_retrieval")
    if chain_type == 'retrieval_qa':
        chain = getRetrievalQAWithSourcesChain(vector_store, mode, initial_prompt)
        chain_input = {"question": sanitized_question}
    elif chain_type == 'conversation_retrieval':
        chain = getConversationRetrievalChain(vector_store, mode, initial_prompt)
        # To avoid fetching an excessively large amount of history data from the
        # database, cap the number of records retrieved in a single query.
        chat_history = get_chat_history_for_retrieval_chain(session_id, limit=40)
        chain_input = {"question": sanitized_question, "chat_history": chat_history}
    else:
        raise ValueError(
            f"Unsupported CHAIN_TYPE {chain_type!r}; "
            "expected 'retrieval_qa' or 'conversation_retrieval'."
        )
    # Both chain types expose their result under the 'answer' key.
    response = chain(chain_input, return_only_outputs=True)
    return response['answer']

0 commit comments

Comments
 (0)