Skip to content

Commit a3e26bb

Browse files
author
omerh
committed
update versions, using langchain community, added loguru
1 parent bc0e0ed commit a3e26bb

File tree

6 files changed

+75
-57
lines changed

ask-bedrock-with-rag.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
1-
import coloredlogs
2-
import logging
31
import argparse
42
from utils import opensearch, secret
5-
from langchain.embeddings import BedrockEmbeddings
6-
from langchain.vectorstores import OpenSearchVectorSearch
3+
from langchain_community.embeddings import BedrockEmbeddings
4+
from langchain_community.vectorstores import OpenSearchVectorSearch
75
from langchain.chains import RetrievalQA
86
from langchain.prompts import PromptTemplate
97
from langchain.llms.bedrock import Bedrock
108
import boto3
9+
from loguru import logger
10+
import sys
11+
import os
1112

1213

13-
coloredlogs.install(fmt='%(asctime)s %(levelname)s %(message)s', datefmt='%H:%M:%S', level='INFO')
14-
logging.basicConfig(level=logging.INFO)
15-
logger = logging.getLogger(__name__)
14+
# logger
15+
logger.remove()
16+
logger.add(sys.stdout, level=os.getenv("LOG_LEVEL", "INFO"))
17+
1618

1719
def parse_args():
1820
parser = argparse.ArgumentParser()
@@ -58,7 +60,7 @@ def create_bedrock_llm(bedrock_client, model_version_id):
5860

5961

6062
def main():
61-
logging.info("Starting")
63+
logger.info("Starting")
6264
args, _ = parse_args()
6365
region = args.region
6466
index_name = args.index
@@ -78,7 +80,7 @@ def main():
7880
question = args.ask
7981
else:
8082
question = "what is the meaning of <3?"
81-
logging.info(f"No question provided, using default question {question}")
83+
logger.info(f"No question provided, using default question {question}")
8284

8385
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. don't include harmful content
8486
@@ -90,23 +92,23 @@ def main():
9092
template=prompt_template, input_variables=["context", "question"]
9193
)
9294

93-
logging.info(f"Starting the chain with KNN similarity using OpenSearch, Bedrock FM {bedrock_model_id}, and Bedrock embeddings with {bedrock_embedding_model_id}")
95+
logger.info(f"Starting the chain with KNN similarity using OpenSearch, Bedrock FM {bedrock_model_id}, and Bedrock embeddings with {bedrock_embedding_model_id}")
9496
qa = RetrievalQA.from_chain_type(llm=bedrock_llm,
9597
chain_type="stuff",
9698
retriever=opensearch_vector_search_client.as_retriever(),
9799
return_source_documents=True,
98100
chain_type_kwargs={"prompt": PROMPT, "verbose": True},
99101
verbose=True)
100102

101-
response = qa(question, return_only_outputs=False)
103+
response = qa.invoke(question, return_only_outputs=False)
102104

103-
logging.info("This are the similar documents from OpenSearch based on the provided query")
105+
logger.info("This are the similar documents from OpenSearch based on the provided query")
104106
source_documents = response.get('source_documents')
105107
for d in source_documents:
106-
logging.info(f"With the following similar content from OpenSearch:\n{d.page_content}\n")
107-
logging.info(f"Text: {d.metadata['text']}")
108+
logger.info(f"With the following similar content from OpenSearch:\n{d.page_content}\n")
109+
logger.info(f"Text: {d.metadata['text']}")
108110

109-
logging.info(f"The answer from Bedrock {bedrock_model_id} is: {response.get('result')}")
111+
logger.info(f"The answer from Bedrock {bedrock_model_id} is: {response.get('result')}")
110112

111113

112114
if __name__ == "__main__":

load-data-to-opensearch.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
import logging
2-
import coloredlogs
31
import json
42
import argparse
53
import boto3
64
from utils import dataset, secret, opensearch
5+
from loguru import logger
6+
import sys
7+
import os
78

8-
coloredlogs.install(fmt='%(asctime)s %(levelname)s %(message)s', datefmt='%H:%M:%S', level='INFO')
9-
logging.basicConfig(level=logging.INFO)
10-
logger = logging.getLogger(__name__)
9+
10+
# logger
11+
logger.remove()
12+
logger.add(sys.stdout, level=os.getenv("LOG_LEVEL", "INFO"))
1113

1214

1315
def parse_args():
@@ -42,7 +44,7 @@ def create_vector_embedding_with_bedrock(text, name, bedrock_client):
4244

4345

4446
def main():
45-
logging.info("Starting")
47+
logger.info("Starting")
4648

4749
dataset_url = "https://huggingface.co/datasets/sentence-transformers/embedding-training-data/resolve/main/gooaq_pairs.jsonl.gz"
4850
early_stop_record_count = 100
@@ -52,29 +54,29 @@ def main():
5254
name = args.index
5355

5456
# Prepare OpenSearch index with vector embeddings index mapping
55-
logging.info(f"recreating opensearch index: {args.recreate}, using early stop: {args.early_stop} to insert only {early_stop_record_count} records")
56-
logging.info("Preparing OpenSearch Index")
57+
logger.info(f"recreating opensearch index: {args.recreate}, using early stop: {args.early_stop} to insert only {early_stop_record_count} records")
58+
logger.info("Preparing OpenSearch Index")
5759
opensearch_password = secret.get_secret(name, region)
5860
opensearch_client = opensearch.get_opensearch_cluster_client(name, opensearch_password, region)
5961

6062
# Check if to delete OpenSearch index with the argument passed to the script --recreate 1
6163
if args.recreate:
6264
response = opensearch.delete_opensearch_index(opensearch_client, name)
6365
if response:
64-
logging.info("OpenSearch index successfully deleted")
66+
logger.info("OpenSearch index successfully deleted")
6567

66-
logging.info(f"Checking if index {name} exists in OpenSearch cluster")
68+
logger.info(f"Checking if index {name} exists in OpenSearch cluster")
6769
exists = opensearch.check_opensearch_index(opensearch_client, name)
6870
if not exists:
69-
logging.info("Creating OpenSearch index")
71+
logger.info("Creating OpenSearch index")
7072
success = opensearch.create_index(opensearch_client, name)
7173
if success:
72-
logging.info("Creating OpenSearch index mapping")
74+
logger.info("Creating OpenSearch index mapping")
7375
success = opensearch.create_index_mapping(opensearch_client, name)
74-
logging.info(f"OpenSearch Index mapping created")
76+
logger.info(f"OpenSearch Index mapping created")
7577

7678
# Download sample dataset from HuggingFace
77-
logging.info("Downloading dataset from HuggingFace")
79+
logger.info("Downloading dataset from HuggingFace")
7880
compressed_file_path = dataset.download_dataset(dataset_url)
7981
if compressed_file_path is not None:
8082
file_path = dataset.decompress_dataset(compressed_file_path)
@@ -86,7 +88,7 @@ def main():
8688

8789
# Vector embedding using Amazon Bedrock Titan text embedding
8890
all_json_records = []
89-
logging.info(f"Creating embeddings for records")
91+
logger.info(f"Creating embeddings for records")
9092

9193
# using the arg --early-stop
9294
i = 0
@@ -96,24 +98,24 @@ def main():
9698
if i > early_stop_record_count:
9799
# Bulk put all records to OpenSearch
98100
success, failed = opensearch.put_bulk_in_opensearch(all_json_records, opensearch_client)
99-
logging.info(f"Documents saved {success}, documents failed to save {failed}")
101+
logger.info(f"Documents saved {success}, documents failed to save {failed}")
100102
break
101103
records_with_embedding = create_vector_embedding_with_bedrock(record, name, bedrock_client)
102-
logging.info(f"Embedding for record {i} created")
104+
logger.info(f"Embedding for record {i} created")
103105
all_json_records.append(records_with_embedding)
104106
if i % 500 == 0 or i == len(all_records)-1:
105107
# Bulk put all records to OpenSearch
106108
success, failed = opensearch.put_bulk_in_opensearch(all_json_records, opensearch_client)
107109
all_json_records = []
108-
logging.info(f"Documents saved {success}, documents failed to save {failed}")
110+
logger.info(f"Documents saved {success}, documents failed to save {failed}")
109111

110-
logging.info("Finished creating records using Amazon Bedrock Titan text embedding")
112+
logger.info("Finished creating records using Amazon Bedrock Titan text embedding")
111113

112-
logging.info("Cleaning up")
114+
logger.info("Cleaning up")
113115
dataset.delete_file(compressed_file_path)
114116
dataset.delete_file(file_path)
115117

116-
logging.info("Finished")
118+
logger.info("Finished")
117119

118120
if __name__ == "__main__":
119121
main()

requirements.txt

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
boto3>=1.28.57
2-
langchain>=0.0.312
3-
coloredlogs==15.0.1
4-
jq==1.6.0
5-
opensearch-py==2.3.1
1+
boto3>=1.34.79
2+
langchain==0.1.14
3+
langchain-community==0.0.31
4+
coloredlogs>=15.0.1
5+
jq==1.7.0
6+
opensearch-py==2.5.0
7+
loguru==0.7.2

utils/dataset.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,53 @@
1-
import logging
21
import requests
32
import gzip
43
import json
54
import tempfile
5+
from loguru import logger
6+
import sys
67
import os
78

89

10+
# logger
11+
logger.remove()
12+
logger.add(sys.stdout, level=os.getenv("LOG_LEVEL", "INFO"))
13+
14+
915
def download_dataset(url):
1016
# download the dataset and store it to tmp dir
1117
try:
12-
logging.info("Downloading dataset")
18+
logger.info("Downloading dataset")
1319
response = requests.get(url)
1420
if response.status_code == 200:
1521
temp = tempfile.NamedTemporaryFile(delete=False)
1622
temp.write(response.content)
1723
temp.close()
1824
return temp.name
1925
else:
20-
logging.error("Failed to download dataset")
26+
logger.error("Failed to download dataset")
2127
return None
2228
except Exception as e:
23-
logging.error(e)
29+
logger.error(e)
2430
return None
2531

2632

2733

2834
def decompress_dataset(file_path):
2935
# decompress the dataset
3036
temp_fd, temp_path = tempfile.mkstemp()
31-
logging.info(f"Decompressing dataset {file_path} to new file {temp_path}")
37+
logger.info(f"Decompressing dataset {file_path} to new file {temp_path}")
3238
try:
3339
with gzip.open(file_path, 'rb') as compressed:
3440
with open(temp_path, 'wb') as decompressed:
3541
decompressed.write(compressed.read())
36-
logging.info("Decompression complete")
42+
logger.info("Decompression complete")
3743
return temp_path
3844
except Exception as e:
39-
logging.error(e)
45+
logger.error(e)
4046
return None
4147

4248

4349
def prep_for_put(file_path):
44-
logging.info(f"Loading file {file_path}")
50+
logger.info(f"Loading file {file_path}")
4551
all_records = []
4652
with open(file_path, 'r') as f:
4753
for line in f:
@@ -52,8 +58,8 @@ def prep_for_put(file_path):
5258

5359

5460
def delete_file(file_path):
55-
logging.info(f"Deleting file {file_path}")
61+
logger.info(f"Deleting file {file_path}")
5662
try:
5763
os.remove(file_path)
5864
except Exception as e:
59-
logging.error(e)
65+
logger.error(e)

utils/opensearch.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import boto3
22
from opensearchpy import OpenSearch, RequestsHttpConnection
33
from opensearchpy.helpers import bulk
4-
import logging
4+
from loguru import logger
5+
import sys
6+
import os
7+
8+
# logger
9+
logger.remove()
10+
logger.add(sys.stdout, level=os.getenv("LOG_LEVEL", "INFO"))
511

612

713
def get_opensearch_cluster_client(name, password, region):
@@ -30,7 +36,7 @@ def get_opensearch_endpoint(name, region):
3036

3137

3238
def put_bulk_in_opensearch(list, client):
33-
logging.info(f"Putting {len(list)} documents in OpenSearch")
39+
logger.info(f"Putting {len(list)} documents in OpenSearch")
3440
success, failed = bulk(client, list)
3541
return success, failed
3642

@@ -70,12 +76,12 @@ def create_index_mapping(opensearch_client, index_name):
7076

7177

7278
def delete_opensearch_index(opensearch_client, index_name):
73-
logging.info(f"Trying to delete index {index_name}")
79+
logger.info(f"Trying to delete index {index_name}")
7480
try:
7581
response = opensearch_client.indices.delete(index=index_name)
76-
logging.info(f"Index {index_name} deleted")
82+
logger.info(f"Index {index_name} deleted")
7783
return response['acknowledged']
7884
except Exception as e:
79-
logging.info(f"Index {index_name} not found, nothing to delete")
85+
logger.info(f"Index {index_name} not found, nothing to delete")
8086
return True
8187

utils/secret.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import boto3
2-
import logging
2+
33

44
def get_secret(secret_prefix, region):
55
client = boto3.client('secretsmanager', region_name=region)

Comments (0)