22from utils import opensearch , secret
33from langchain_community .embeddings import BedrockEmbeddings
44from langchain_community .vectorstores import OpenSearchVectorSearch
5- from langchain .chains import RetrievalQA
6- from langchain .prompts import PromptTemplate
7- from langchain .llms .bedrock import Bedrock
5+ from langchain .chains .combine_documents import create_stuff_documents_chain
6+ from langchain .chains import create_retrieval_chain
7+ from langchain_core .prompts import ChatPromptTemplate
8+ from langchain_community .chat_models import BedrockChat
89import boto3
910from loguru import logger
1011import sys
1819
def parse_args():
    """Build the CLI parser and return (known_args, unknown_args).

    Uses parse_known_args() so the script tolerates extra arguments
    injected by wrappers (e.g. notebook kernels or test runners).
    """
    parser = argparse.ArgumentParser()

    # Question to ask the RAG chain; a default is provided so the demo
    # runs out of the box.
    parser.add_argument("--ask", type=str, default="What is the meaning of <3?")
    # OpenSearch index holding the embedded documents.
    parser.add_argument("--index", type=str, default="rag")
    # AWS region for both Bedrock and OpenSearch clients.
    parser.add_argument("--region", type=str, default="us-east-1")
    # Bedrock chat model used for answer generation.
    parser.add_argument("--bedrock-model-id", type=str, default="anthropic.claude-3-sonnet-20240229-v1:0")
    # Bedrock model used to embed the query for KNN retrieval.
    parser.add_argument("--bedrock-embedding-model-id", type=str, default="amazon.titan-embed-text-v1")

    return parser.parse_known_args()
@@ -51,7 +52,7 @@ def create_opensearch_vector_search_client(index_name, opensearch_password, bedr
5152
5253
def create_bedrock_llm(bedrock_client, model_version_id):
    """Return a BedrockChat LLM bound to *bedrock_client*.

    Temperature is pinned to 0 for deterministic, repeatable answers.
    NOTE(review): the tail of this function is cut off in the diff view;
    the return statement is reconstructed — confirm against the full file.
    """
    llm = BedrockChat(
        model_id=model_version_id,
        client=bedrock_client,
        model_kwargs={"temperature": 0},
    )
    return llm
6061
6162
6263def main ():
63- logger .info ("Starting" )
64+ logger .info ("Starting... " )
6465 args , _ = parse_args ()
6566 region = args .region
6667 index_name = args .index
6768 bedrock_model_id = args .bedrock_model_id
6869 bedrock_embedding_model_id = args .bedrock_embedding_model_id
70+ question = args .ask
71+ logger .info (f"Question provided: { question } " )
6972
7073 # Creating all clients for chain
7174 bedrock_client = get_bedrock_client (region )
@@ -76,39 +79,31 @@ def main():
7679 opensearch_vector_search_client = create_opensearch_vector_search_client (index_name , opensearch_password , bedrock_embeddings_client , opensearch_endpoint )
7780
7881 # LangChain prompt template
79- if len (args .ask ) > 0 :
80- question = args .ask
81- else :
82- question = "what is the meaning of <3?"
83- logger .info (f"No question provided, using default question { question } " )
84-
85- prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. don't include harmful content
82+ prompt = ChatPromptTemplate .from_template ("""Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. don't include harmful content
8683
8784 {context}
8885
89- Question: {question}
90- Answer:"""
91- PROMPT = PromptTemplate (
92- template = prompt_template , input_variables = ["context" , "question" ]
93- )
86+ Question: {input}
87+ Answer:""" )
9488
95- logger .info (f"Starting the chain with KNN similarity using OpenSearch, Bedrock FM { bedrock_model_id } , and Bedrock embeddings with { bedrock_embedding_model_id } " )
96- qa = RetrievalQA .from_chain_type (llm = bedrock_llm ,
97- chain_type = "stuff" ,
98- retriever = opensearch_vector_search_client .as_retriever (),
99- return_source_documents = True ,
100- chain_type_kwargs = {"prompt" : PROMPT , "verbose" : True },
101- verbose = True )
89+ docs_chain = create_stuff_documents_chain (bedrock_llm , prompt )
90+ retrieval_chain = create_retrieval_chain (
91+ retriever = opensearch_vector_search_client .as_retriever (),
92+ combine_docs_chain = docs_chain
93+ )
10294
103- response = qa .invoke (question , return_only_outputs = False )
95+ logger .info (f"Invoking the chain with KNN similarity using OpenSearch, Bedrock FM { bedrock_model_id } , and Bedrock embeddings with { bedrock_embedding_model_id } " )
96+ response = retrieval_chain .invoke ({"input" : question })
10497
105- logger .info ("This are the similar documents from OpenSearch based on the provided query" )
106- source_documents = response .get ('source_documents' )
98+ print ("" )
99+ logger .info ("These are the similar documents from OpenSearch based on the provided query:" )
100+ source_documents = response .get ('context' )
107101 for d in source_documents :
108- logger . info ( f"With the following similar content from OpenSearch: \n { d . page_content } \n " )
109- logger .info (f"Text: { d .metadata [ 'text' ] } " )
102+ print ( " " )
103+ logger .info (f"Text: { d .page_content } " )
110104
111- logger .info (f"The answer from Bedrock { bedrock_model_id } is: { response .get ('result' )} " )
105+ print ("" )
106+ logger .info (f"The answer from Bedrock { bedrock_model_id } is: { response .get ('answer' )} " )
112107
113108
114109if __name__ == "__main__" :
0 commit comments