
Commit 96d37ea

Merge pull request #12 from aws-samples/multi-tenant
Multi tenant
2 parents 0b645e0 + 5c0e6c8 commit 96d37ea

File tree: 4 files changed, +41 -9 lines changed


README.md

Lines changed: 4 additions & 0 deletions

@@ -61,6 +61,8 @@ AI21 Labs:
 >>- `--early-stop` to load only 100 embedded documents into OpenSearch
 >>- `--index` to use a different index than the default **rag**
 >>- `--region` in case you are not using the default **us-east-1**
+>>- `--multi-tenant` to use multi tenancy, will load data with tenant IDs (1-5)
+

 3. Now that we have embedded text, into our OpenSearch cluster, we can start querying our LLM model Titan text in Amazon Bedrock with RAG

@@ -72,6 +74,8 @@ AI21 Labs:
 >>- `--index` to use a different index than the default **rag**
 >>- `--region` in case you are not using the default **us-east-1**
 >>- `--bedrock-model-id` to choose different models than Anthropic's Claude v2
+>>- `--tenant-id` to filter only a specific tenant ID
+

 ### Cleanup

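Taken together, the two new flags describe a simple multi-tenancy scheme: `--multi-tenant` stamps each document loaded into OpenSearch with a random tenant ID from 1 to 5, and `--tenant-id` restricts retrieval to one of those tenants by attaching a `term` filter on `tenant_id` to the vector search. As a rough illustration only (not necessarily the exact query the sample issues), a tenant-scoped k-NN search could be expressed in OpenSearch query DSL like this; the vector, `k`, and size values are placeholders:

# Illustrative only: one way a tenant-scoped k-NN search can be written in
# OpenSearch query DSL. Field names ("vector_field", "tenant_id") match the
# sample's documents; the vector, k, and size are placeholder values.
import json

tenant_id = "3"                  # value that --tenant-id would supply
query_vector = [0.1, 0.2, 0.3]   # stand-in for a Titan text embedding

query = {
    "size": 3,
    "query": {
        "bool": {
            "must": [
                {"knn": {"vector_field": {"vector": query_vector, "k": 3}}}
            ],
            "filter": [
                # Same shape as the filter built in ask-bedrock-with-rag.py
                {"term": {"tenant_id": tenant_id}}
            ],
        }
    },
}

print(json.dumps(query, indent=2))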

ask-bedrock-with-rag.py

Lines changed: 13 additions & 4 deletions

@@ -22,6 +22,7 @@ def parse_args():
     parser.add_argument("--ask", type=str, default="What is the meaning of <3?")
     parser.add_argument("--index", type=str, default="rag")
     parser.add_argument("--region", type=str, default="us-east-1")
+    parser.add_argument("--tenant-id", type=str, default=None)
     parser.add_argument("--bedrock-model-id", type=str, default="anthropic.claude-3-sonnet-20240229-v1:0")
     parser.add_argument("--bedrock-embedding-model-id", type=str, default="amazon.titan-embed-text-v1")

@@ -68,6 +69,7 @@ def main():
     bedrock_model_id = args.bedrock_model_id
     bedrock_embedding_model_id = args.bedrock_embedding_model_id
     question = args.ask
+    tenant_id = args.tenant_id
     logger.info(f"Question provided: {question}")

     # Creating all clients for chain
@@ -87,23 +89,30 @@ def main():
     Answer:""")

     docs_chain = create_stuff_documents_chain(bedrock_llm, prompt)
+
+    search_kwargs = {}
+    if tenant_id:
+        search_kwargs["filter"] = {
+            "term": {
+                "tenant_id": tenant_id
+            }
+        }
+
     retrieval_chain = create_retrieval_chain(
-        retriever=opensearch_vector_search_client.as_retriever(),
+        retriever=opensearch_vector_search_client.as_retriever(search_kwargs=search_kwargs),
         combine_docs_chain = docs_chain
     )

     logger.info(f"Invoking the chain with KNN similarity using OpenSearch, Bedrock FM {bedrock_model_id}, and Bedrock embeddings with {bedrock_embedding_model_id}")
     response = retrieval_chain.invoke({"input": question})

-    print("")
     logger.info("These are the similar documents from OpenSearch based on the provided query:")
     source_documents = response.get('context')
     for d in source_documents:
-        print("")
         logger.info(f"Text: {d.page_content}")

     print("")
-    logger.info(f"The answer from Bedrock {bedrock_model_id} is: {response.get('answer')}")
+    logger.info(f"The answer from Bedrock!!!!! {bedrock_model_id} is: {response.get('answer')}")


 if __name__ == "__main__":
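The retrieval-side change boils down to building a `search_kwargs` dict and handing it to `as_retriever()`. A minimal, dependency-free sketch of that logic follows; `build_search_kwargs` is a hypothetical helper (the script inlines the same few lines in `main()`), and it assumes, as the change does, that whatever lands in `search_kwargs` is forwarded to the underlying OpenSearch vector search.

# Dependency-free sketch of the filter construction above.
# build_search_kwargs is a hypothetical helper; the change inlines this logic in main().

def build_search_kwargs(tenant_id=None):
    """Return retriever kwargs, adding an OpenSearch term filter when a tenant is given."""
    search_kwargs = {}
    if tenant_id:
        search_kwargs["filter"] = {"term": {"tenant_id": tenant_id}}
    return search_kwargs


if __name__ == "__main__":
    # No tenant: empty kwargs, retrieval stays unfiltered across all tenants.
    assert build_search_kwargs() == {}
    # With --tenant-id 3: only documents indexed with tenant_id "3" should match.
    assert build_search_kwargs("3") == {"filter": {"term": {"tenant_id": "3"}}}
    print("search_kwargs examples OK")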

load-data-to-opensearch.py

Lines changed: 22 additions & 3 deletions

@@ -5,6 +5,8 @@
 from loguru import logger
 import sys
 import os
+import random
+


 # logger
@@ -18,6 +20,7 @@ def parse_args():
     parser.add_argument("--early-stop", type=bool, default=0)
     parser.add_argument("--index", type=str, default="rag")
     parser.add_argument("--region", type=str, default="us-east-1")
+    parser.add_argument("--multi-tenant", type=bool, default=0)

     return parser.parse_known_args()

@@ -33,16 +36,28 @@ def create_vector_embedding_with_bedrock(text, name, bedrock_client):
     modelId = "amazon.titan-embed-text-v1"
     accept = "application/json"
     contentType = "application/json"
+    args, _ = parse_args()
+    multi_tenant = args.multi_tenant

     response = bedrock_client.invoke_model(
         body=body, modelId=modelId, accept=accept, contentType=contentType
     )
     response_body = json.loads(response.get("body").read())

     embedding = response_body.get("embedding")
-    return {"_index": name, "text": text, "vector_field": embedding}

-
+    document = {
+        "_index": name,
+        "text": text,
+        "vector_field": embedding
+    }
+
+
+    if multi_tenant == 1:
+        document["tenant_id"] = random.randint(1, 5)
+
+    return document
+
 def main():
     logger.info("Starting")

@@ -52,9 +67,13 @@ def main():
     args, _ = parse_args()
     region = args.region
     name = args.index
+    multi_tenant = args.multi_tenant
+

     # Prepare OpenSearch index with vector embeddings index mapping
-    logger.info(f"recreating opensearch index: {args.recreate}, using early stop: {args.early_stop} to insert only {early_stop_record_count} records")
+    logger.info(f"Recreating opensearch index: {args.recreate}, using early stop: {args.early_stop} to insert only {early_stop_record_count} records")
+    if multi_tenant:
+        logger.info("Using multi tenant mode")
     logger.info("Preparing OpenSearch Index")
     opensearch_password = secret.get_secret(name, region)
     opensearch_client = opensearch.get_opensearch_cluster_client(name, opensearch_password, region)
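On the loading side, `--multi-tenant` only changes the shape of the document returned for bulk indexing. The sketch below reproduces that shape without calling Bedrock or OpenSearch: `fake_embedding` and `build_document` are illustrative stand-ins, whereas the real `create_vector_embedding_with_bedrock` obtains the vector from the `amazon.titan-embed-text-v1` model via `invoke_model`.

import random


def fake_embedding(text, dim=4):
    """Stand-in for the Titan embedding call made via bedrock_client.invoke_model."""
    random.seed(text)  # deterministic per text, purely for illustration
    return [random.random() for _ in range(dim)]


def build_document(text, index_name, multi_tenant=False):
    """Mirror the document shape produced by create_vector_embedding_with_bedrock."""
    document = {
        "_index": index_name,
        "text": text,
        "vector_field": fake_embedding(text),
    }
    if multi_tenant:
        # Same behaviour as the change: each document gets a random tenant ID from 1 to 5.
        document["tenant_id"] = random.randint(1, 5)
    return document


if __name__ == "__main__":
    print(build_document("hello world", "rag", multi_tenant=True))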

requirements.txt

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
-boto3>=1.34.79
+boto3>=1.35.73
 langchain==0.1.14
-langchain-community>=0.0.31
+langchain-community==0.0.36
 langchain-core==0.1.50
 coloredlogs>=15.0.1
 jq==1.7.0
