
Commit f6c7aa8

add multi tenant support
1 parent 0b645e0 commit f6c7aa8

3 files changed: +44 −15 lines changed

ask-bedrock-with-rag.py

Lines changed: 14 additions & 4 deletions
@@ -22,6 +22,7 @@ def parse_args():
     parser.add_argument("--ask", type=str, default="What is the meaning of <3?")
     parser.add_argument("--index", type=str, default="rag")
     parser.add_argument("--region", type=str, default="us-east-1")
+    parser.add_argument("--tenant-id", type=str, default=None)
     parser.add_argument("--bedrock-model-id", type=str, default="anthropic.claude-3-sonnet-20240229-v1:0")
     parser.add_argument("--bedrock-embedding-model-id", type=str, default="amazon.titan-embed-text-v1")

@@ -68,6 +69,7 @@ def main():
     bedrock_model_id = args.bedrock_model_id
     bedrock_embedding_model_id = args.bedrock_embedding_model_id
     question = args.ask
+    tenant_id = args.tenant_id
     logger.info(f"Question provided: {question}")

     # Creating all clients for chain
@@ -87,23 +89,31 @@ def main():
     Answer:""")

     docs_chain = create_stuff_documents_chain(bedrock_llm, prompt)
+
+    search_kwargs = {}
+    if tenant_id:
+        search_kwargs["filter"] = {
+            "term": {
+                "tenant_id": tenant_id
+            }
+        }
+
     retrieval_chain = create_retrieval_chain(
-        retriever=opensearch_vector_search_client.as_retriever(),
+        retriever=opensearch_vector_search_client.as_retriever(search_kwargs=search_kwargs),
         combine_docs_chain = docs_chain
     )

     logger.info(f"Invoking the chain with KNN similarity using OpenSearch, Bedrock FM {bedrock_model_id}, and Bedrock embeddings with {bedrock_embedding_model_id}")
     response = retrieval_chain.invoke({"input": question})

-    print("")
     logger.info("These are the similar documents from OpenSearch based on the provided query:")
     source_documents = response.get('context')
     for d in source_documents:
-        print("")
+        print(f"tenant_id={tenant_id}")
         logger.info(f"Text: {d.page_content}")

     print("")
-    logger.info(f"The answer from Bedrock {bedrock_model_id} is: {response.get('answer')}")
+    logger.info(f"The answer from Bedrock!!!!! {bedrock_model_id} is: {response.get('answer')}")


 if __name__ == "__main__":
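
For reference, the tenant filter wired into the retriever above can be sketched in isolation. The helper below is hypothetical (it is not part of this commit) and only mirrors the term-filter structure that the change passes to as_retriever() via search_kwargs:

# Hypothetical helper mirroring the tenant filter added in this commit.
# An OpenSearch "term" clause matches documents whose tenant_id field equals
# the given value exactly; an empty dict leaves retrieval unfiltered.
def build_search_kwargs(tenant_id=None):
    search_kwargs = {}
    if tenant_id:
        search_kwargs["filter"] = {"term": {"tenant_id": tenant_id}}
    return search_kwargs


if __name__ == "__main__":
    print(build_search_kwargs())    # {}
    print(build_search_kwargs(3))   # {'filter': {'term': {'tenant_id': 3}}}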

load-data-to-opensearch.py

Lines changed: 22 additions & 3 deletions
@@ -5,6 +5,8 @@
 from loguru import logger
 import sys
 import os
+import random
+


 # logger
@@ -18,6 +20,7 @@ def parse_args():
     parser.add_argument("--early-stop", type=bool, default=0)
     parser.add_argument("--index", type=str, default="rag")
     parser.add_argument("--region", type=str, default="us-east-1")
+    parser.add_argument("--multi-tenant", type=bool, default=0)

     return parser.parse_known_args()

@@ -33,16 +36,28 @@ def create_vector_embedding_with_bedrock(text, name, bedrock_client):
     modelId = "amazon.titan-embed-text-v1"
     accept = "application/json"
     contentType = "application/json"
+    args, _ = parse_args()
+    multi_tenant = args.multi_tenant

     response = bedrock_client.invoke_model(
         body=body, modelId=modelId, accept=accept, contentType=contentType
     )
     response_body = json.loads(response.get("body").read())

     embedding = response_body.get("embedding")
-    return {"_index": name, "text": text, "vector_field": embedding}

-
+    document = {
+        "_index": name,
+        "text": text,
+        "vector_field": embedding
+    }
+
+
+    if multi_tenant == 1:
+        document["tenant_id"] = random.randint(1, 5)
+
+    return document
+
 def main():
     logger.info("Starting")

@@ -52,9 +67,13 @@ def main():
     args, _ = parse_args()
     region = args.region
     name = args.index
+    multi_tenant = args.multi_tenant
+

     # Prepare OpenSearch index with vector embeddings index mapping
-    logger.info(f"recreating opensearch index: {args.recreate}, using early stop: {args.early_stop} to insert only {early_stop_record_count} records")
+    logger.info(f"Recreating opensearch index: {args.recreate}, using early stop: {args.early_stop} to insert only {early_stop_record_count} records")
+    if multi_tenant:
+        logger.info("Using multi tenant mode")
     logger.info("Preparing OpenSearch Index")
     opensearch_password = secret.get_secret(name, region)
     opensearch_client = opensearch.get_opensearch_cluster_client(name, opensearch_password, region)
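
For reference, the document shape produced when --multi-tenant is enabled can be sketched in isolation. build_document below is a hypothetical stand-in (it is not part of this commit); the embedding would normally come from the Bedrock Titan call shown in the diff:

import random

# Hypothetical stand-in mirroring the bulk-indexing document built above;
# with multi_tenant enabled, each record is tagged with one of five synthetic tenants.
def build_document(name, text, embedding, multi_tenant=False):
    document = {
        "_index": name,
        "text": text,
        "vector_field": embedding,
    }
    if multi_tenant:
        document["tenant_id"] = random.randint(1, 5)  # random tenant between 1 and 5
    return document


if __name__ == "__main__":
    # tenant_id in the output will be a random integer between 1 and 5
    print(build_document("rag", "hello", [0.1, 0.2], multi_tenant=True))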

requirements.txt

Lines changed: 8 additions & 8 deletions
@@ -1,8 +1,8 @@
-boto3>=1.34.79
-langchain==0.1.14
-langchain-community>=0.0.31
-langchain-core==0.1.50
-coloredlogs>=15.0.1
-jq==1.7.0
-opensearch-py==2.5.0
-loguru==0.7.2
+boto3
+langchain
+langchain-community
+langchain-core
+coloredlogs
+jq
+opensearch-py
+loguru
