Skip to content

Commit e3d7e1d

Browse files
author
Omer Haim
authored
Merge pull request #8 from aws-samples/may-updates
May updates
2 parents a3e26bb + 3fd7a33 commit e3d7e1d

File tree

2 files changed

+28
-32
lines changed

2 files changed

+28
-32
lines changed

ask-bedrock-with-rag.py

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
from utils import opensearch, secret
33
from langchain_community.embeddings import BedrockEmbeddings
44
from langchain_community.vectorstores import OpenSearchVectorSearch
5-
from langchain.chains import RetrievalQA
6-
from langchain.prompts import PromptTemplate
7-
from langchain.llms.bedrock import Bedrock
5+
from langchain.chains.combine_documents import create_stuff_documents_chain
6+
from langchain.chains import create_retrieval_chain
7+
from langchain_core.prompts import ChatPromptTemplate
8+
from langchain_community.chat_models import BedrockChat
89
import boto3
910
from loguru import logger
1011
import sys
@@ -18,10 +19,10 @@
1819

1920
def parse_args():
2021
parser = argparse.ArgumentParser()
21-
parser.add_argument("--ask", type=str, default="What is <3?")
22+
parser.add_argument("--ask", type=str, default="What is the meaning of <3?")
2223
parser.add_argument("--index", type=str, default="rag")
2324
parser.add_argument("--region", type=str, default="us-east-1")
24-
parser.add_argument("--bedrock-model-id", type=str, default="anthropic.claude-v2")
25+
parser.add_argument("--bedrock-model-id", type=str, default="anthropic.claude-3-sonnet-20240229-v1:0")
2526
parser.add_argument("--bedrock-embedding-model-id", type=str, default="amazon.titan-embed-text-v1")
2627

2728
return parser.parse_known_args()
@@ -51,7 +52,7 @@ def create_opensearch_vector_search_client(index_name, opensearch_password, bedr
5152

5253

5354
def create_bedrock_llm(bedrock_client, model_version_id):
54-
bedrock_llm = Bedrock(
55+
bedrock_llm = BedrockChat(
5556
model_id=model_version_id,
5657
client=bedrock_client,
5758
model_kwargs={'temperature': 0}
@@ -60,12 +61,14 @@ def create_bedrock_llm(bedrock_client, model_version_id):
6061

6162

6263
def main():
63-
logger.info("Starting")
64+
logger.info("Starting...")
6465
args, _ = parse_args()
6566
region = args.region
6667
index_name = args.index
6768
bedrock_model_id = args.bedrock_model_id
6869
bedrock_embedding_model_id = args.bedrock_embedding_model_id
70+
question = args.ask
71+
logger.info(f"Question provided: {question}")
6972

7073
# Creating all clients for chain
7174
bedrock_client = get_bedrock_client(region)
@@ -76,39 +79,31 @@ def main():
7679
opensearch_vector_search_client = create_opensearch_vector_search_client(index_name, opensearch_password, bedrock_embeddings_client, opensearch_endpoint)
7780

7881
# LangChain prompt template
79-
if len(args.ask) > 0:
80-
question = args.ask
81-
else:
82-
question = "what is the meaning of <3?"
83-
logger.info(f"No question provided, using default question {question}")
84-
85-
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. don't include harmful content
82+
prompt = ChatPromptTemplate.from_template("""Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. don't include harmful content
8683
8784
{context}
8885
89-
Question: {question}
90-
Answer:"""
91-
PROMPT = PromptTemplate(
92-
template=prompt_template, input_variables=["context", "question"]
93-
)
86+
Question: {input}
87+
Answer:""")
9488

95-
logger.info(f"Starting the chain with KNN similarity using OpenSearch, Bedrock FM {bedrock_model_id}, and Bedrock embeddings with {bedrock_embedding_model_id}")
96-
qa = RetrievalQA.from_chain_type(llm=bedrock_llm,
97-
chain_type="stuff",
98-
retriever=opensearch_vector_search_client.as_retriever(),
99-
return_source_documents=True,
100-
chain_type_kwargs={"prompt": PROMPT, "verbose": True},
101-
verbose=True)
89+
docs_chain = create_stuff_documents_chain(bedrock_llm, prompt)
90+
retrieval_chain = create_retrieval_chain(
91+
retriever=opensearch_vector_search_client.as_retriever(),
92+
combine_docs_chain = docs_chain
93+
)
10294

103-
response = qa.invoke(question, return_only_outputs=False)
95+
logger.info(f"Invoking the chain with KNN similarity using OpenSearch, Bedrock FM {bedrock_model_id}, and Bedrock embeddings with {bedrock_embedding_model_id}")
96+
response = retrieval_chain.invoke({"input": question})
10497

105-
logger.info("This are the similar documents from OpenSearch based on the provided query")
106-
source_documents = response.get('source_documents')
98+
print("")
99+
logger.info("These are the similar documents from OpenSearch based on the provided query:")
100+
source_documents = response.get('context')
107101
for d in source_documents:
108-
logger.info(f"With the following similar content from OpenSearch:\n{d.page_content}\n")
109-
logger.info(f"Text: {d.metadata['text']}")
102+
print("")
103+
logger.info(f"Text: {d.page_content}")
110104

111-
logger.info(f"The answer from Bedrock {bedrock_model_id} is: {response.get('result')}")
105+
print("")
106+
logger.info(f"The answer from Bedrock {bedrock_model_id} is: {response.get('answer')}")
112107

113108

114109
if __name__ == "__main__":

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
boto3>=1.34.79
22
langchain==0.1.14
33
langchain-community==0.0.31
4+
langchain-core==0.1.50
45
coloredlogs>=15.0.1
56
jq==1.7.0
67
opensearch-py==2.5.0

0 commit comments

Comments
 (0)