@@ -22,6 +22,7 @@ def parse_args():
2222 parser .add_argument ("--ask" , type = str , default = "What is the meaning of <3?" )
2323 parser .add_argument ("--index" , type = str , default = "rag" )
2424 parser .add_argument ("--region" , type = str , default = "us-east-1" )
25+ parser .add_argument ("--tenant-id" , type = str , default = None )
2526 parser .add_argument ("--bedrock-model-id" , type = str , default = "anthropic.claude-3-sonnet-20240229-v1:0" )
2627 parser .add_argument ("--bedrock-embedding-model-id" , type = str , default = "amazon.titan-embed-text-v1" )
2728
@@ -68,6 +69,7 @@ def main():
6869 bedrock_model_id = args .bedrock_model_id
6970 bedrock_embedding_model_id = args .bedrock_embedding_model_id
7071 question = args .ask
72+ tenant_id = args .tenant_id
7173 logger .info (f"Question provided: { question } " )
7274
7375 # Creating all clients for chain
@@ -87,23 +89,31 @@ def main():
8789 Answer:""" )
8890
8991 docs_chain = create_stuff_documents_chain (bedrock_llm , prompt )
92+
93+ search_kwargs = {}
94+ if tenant_id :
95+ search_kwargs ["filter" ] = {
96+ "term" : {
97+ "tenant_id" : tenant_id
98+ }
99+ }
100+
90101 retrieval_chain = create_retrieval_chain (
91- retriever = opensearch_vector_search_client .as_retriever (),
102+ retriever = opensearch_vector_search_client .as_retriever (search_kwargs = search_kwargs ),
92103 combine_docs_chain = docs_chain
93104 )
94105
95106 logger .info (f"Invoking the chain with KNN similarity using OpenSearch, Bedrock FM { bedrock_model_id } , and Bedrock embeddings with { bedrock_embedding_model_id } " )
96107 response = retrieval_chain .invoke ({"input" : question })
97108
98- print ("" )
99109 logger .info ("These are the similar documents from OpenSearch based on the provided query:" )
100110 source_documents = response .get ('context' )
101111 for d in source_documents :
102- print ( " " )
112+ print ( f"tenant_id= { tenant_id } " )
103113 logger .info (f"Text: { d .page_content } " )
104114
105115 print ("" )
106- logger .info (f"The answer from Bedrock { bedrock_model_id } is: { response .get ('answer' )} " )
116+ logger .info (f"The answer from Bedrock!!!!! { bedrock_model_id } is: { response .get ('answer' )} " )
107117
108118
109119if __name__ == "__main__" :
0 commit comments