Skip to content

Commit db78131

Browse files
committed
Initial commit for RAG pipeline scripts
Signed-off-by: hmumtazz <[email protected]>
1 parent 39284d6 commit db78131

File tree

8 files changed

+1099
-0
lines changed

8 files changed

+1099
-0
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Ignore data and ingestion directories
2+
ml_commons/rag_pipeline/data/
3+
ml_commons/rag_pipeline/ingestion/
4+
ml_commons/rag_pipeline/rag/config.ini
5+
# Ignore virtual environment
6+
.venv/
7+
# Or, ignore the environment by its machine-specific absolute path (prefer the relative .venv/ entry above)
8+
/Users/hmumtazz/.cursor-tutor/opensearch-py-ml/.venv/
9+
10+
# Ignore Python cache files
11+
__pycache__/
12+
*.pyc
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# ingest_class.py
2+
3+
import os
4+
import glob
5+
import json
6+
import tiktoken
7+
from tqdm import tqdm
8+
from colorama import Fore, Style, init
9+
from typing import List, Dict
10+
import csv
11+
import PyPDF2
12+
import boto3
13+
import botocore
14+
import time
15+
import random
16+
17+
18+
from opensearch_class import OpenSearchClass
19+
20+
init(autoreset=True)  # Initialize colorama so ANSI color codes reset automatically after each print
21+
22+
class IngestClass:
    """Ingestion pipeline: parses local CSV/TXT/PDF files, generates text
    embeddings via Amazon Bedrock (Titan), and bulk-indexes the resulting
    documents into OpenSearch through OpenSearchClass.
    """

    # Bedrock model used to produce the text embeddings.
    EMBEDDING_MODEL_ID = 'amazon.titan-embed-text-v1'

    def __init__(self, config):
        """Read connection settings from *config* (dict-like) and build the
        OpenSearch wrapper. AWS clients are created lazily in
        initialize_clients().
        """
        self.config = config
        self.aws_region = config.get('region')
        self.index_name = config.get('index_name')
        self.bedrock_client = None  # created in initialize_clients()
        self.opensearch = OpenSearchClass(config)

    def initialize_clients(self):
        """Create the Bedrock runtime client and the OpenSearch client.

        Returns:
            bool: True when both clients are ready, False otherwise.
        """
        try:
            self.bedrock_client = boto3.client('bedrock-runtime', region_name=self.aws_region)
            if self.opensearch.initialize_opensearch_client():
                print("Clients initialized successfully.")
                return True
            else:
                print("Failed to initialize OpenSearch client.")
                return False
        except Exception as e:
            print(f"Failed to initialize clients: {e}")
            return False

    def process_file(self, file_path: str) -> List[Dict[str, str]]:
        """Dispatch *file_path* to the matching parser by extension.

        Returns:
            A list of ``{"text": ...}`` documents; empty for unsupported types.
        """
        _, file_extension = os.path.splitext(file_path)
        ext = file_extension.lower()  # normalize once; extensions may be upper-case

        handlers = {
            '.csv': self.process_csv,
            '.txt': self.process_txt,
            '.pdf': self.process_pdf,
        }
        handler = handlers.get(ext)
        if handler is None:
            print(f"Unsupported file type: {file_extension}")
            return []
        return handler(file_path)

    def process_csv(self, file_path: str) -> List[Dict[str, str]]:
        """Convert each CSV row into a one-sentence document.

        Expects 'name', 'category' and 'film' columns plus an optional
        'winner' column. newline='' is required by the csv module when
        handing it a file object; utf-8 makes decoding deterministic
        across platforms.
        """
        documents = []
        with open(file_path, 'r', newline='', encoding='utf-8') as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                text = f"{row['name']} got nominated under the category, {row['category']}, for the film {row['film']}"
                if row.get('winner', '').lower() != 'true':
                    text += " but did not win"
                documents.append({"text": text})
        return documents

    def process_txt(self, file_path: str) -> List[Dict[str, str]]:
        """Read a whole text file (utf-8) as a single document."""
        with open(file_path, 'r', encoding='utf-8') as txtfile:
            content = txtfile.read()
        return [{"text": content}]

    def process_pdf(self, file_path: str) -> List[Dict[str, str]]:
        """Extract one document per PDF page.

        Pages from which PyPDF2 extracts no text are skipped.
        """
        documents = []
        with open(file_path, 'rb') as pdffile:
            pdf_reader = PyPDF2.PdfReader(pdffile)
            for page in pdf_reader.pages:
                extracted_text = page.extract_text()
                if extracted_text:  # ensure that text was extracted
                    documents.append({"text": extracted_text})
        return documents

    def text_embedding(self, text, max_retries=5, initial_delay=1, backoff_factor=2):
        """Return the Titan embedding vector for *text*.

        Retries with exponential backoff plus jitter on throttling and on
        unexpected transient errors. Returns None when the Bedrock client is
        not initialized or the model returns no embedding; re-raises after
        *max_retries* failed attempts, and immediately for non-retryable
        ClientErrors.
        """
        if self.bedrock_client is None:
            print("Bedrock client is not initialized. Please run setup first.")
            return None

        delay = initial_delay
        for attempt in range(max_retries):
            try:
                payload = {"inputText": text}
                response = self.bedrock_client.invoke_model(modelId=self.EMBEDDING_MODEL_ID, body=json.dumps(payload))
                response_body = json.loads(response['body'].read())
                embedding = response_body.get('embedding')
                if embedding is None:
                    print(f"No embedding returned for text: {text}")
                    print(f"Response body: {response_body}")
                    return None
                return embedding
            except botocore.exceptions.ClientError as e:
                error_code = e.response['Error']['Code']
                error_message = e.response['Error']['Message']
                print(f"ClientError on attempt {attempt + 1}: {error_code} - {error_message}")
                if error_code != 'ThrottlingException':
                    raise  # non-retryable service error
                if attempt == max_retries - 1:
                    raise
                time.sleep(delay + random.uniform(0, 1))  # jitter avoids retry storms
                delay *= backoff_factor
            except Exception as ex:
                print(f"Unexpected error on attempt {attempt + 1}: {ex}")
                if attempt == max_retries - 1:
                    raise
                # Back off on unexpected errors too instead of hot-looping.
                time.sleep(delay + random.uniform(0, 1))
                delay *= backoff_factor
        return None

    def _embed_documents(self, documents: List[Dict[str, str]]):
        """Attach an 'embedding' key to each document in place.

        Returns:
            tuple[int, int]: (success_count, error_count).
        """
        success_count = 0
        error_count = 0
        with tqdm(total=len(documents), bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]') as pbar:
            for doc in documents:
                try:
                    embedding = self.text_embedding(doc['text'])
                    if embedding is not None:
                        doc['embedding'] = embedding
                        success_count += 1
                    else:
                        error_count += 1
                        print(f"{Fore.RED}Error generating embedding for document: {doc['text'][:50]}...{Style.RESET_ALL}")
                except Exception as e:
                    error_count += 1
                    print(f"{Fore.RED}Error processing document: {str(e)}{Style.RESET_ALL}")
                pbar.update(1)
                pbar.set_postfix({'Success': success_count, 'Errors': error_count})
        return success_count, error_count

    def process_and_ingest_data(self, file_paths: List[str]):
        """End-to-end ingestion: parse *file_paths*, embed every document,
        and bulk-index the successfully embedded ones into OpenSearch.
        """
        if not self.initialize_clients():
            print("Failed to initialize clients. Aborting ingestion.")
            return

        all_documents = []
        for file_path in file_paths:
            print(f"Processing file: {file_path}")
            all_documents.extend(self.process_file(file_path))

        total_documents = len(all_documents)
        print(f"Total documents to process: {total_documents}")

        print("Generating embeddings for the documents...")
        success_count, error_count = self._embed_documents(all_documents)

        print(f"\n{Fore.GREEN}Documents with successful embeddings: {success_count}{Style.RESET_ALL}")
        print(f"{Fore.RED}Documents with failed embeddings: {error_count}{Style.RESET_ALL}")

        if success_count == 0:
            print(f"{Fore.RED}No documents to ingest. Aborting ingestion.{Style.RESET_ALL}")
            return

        print(f"{Fore.YELLOW}Ingesting data into OpenSearch...{Style.RESET_ALL}")
        actions = [
            {
                "_index": self.index_name,
                "_source": {
                    "nominee_text": doc['text'],
                    "nominee_vector": doc['embedding'],
                },
            }
            for doc in all_documents
            if doc.get('embedding') is not None
        ]

        success, failed = self.opensearch.bulk_index(actions)
        print(f"{Fore.GREEN}Successfully ingested {success} documents.{Style.RESET_ALL}")
        print(f"{Fore.RED}Failed to ingest {failed} documents.{Style.RESET_ALL}")

    def ingest_command(self, paths: List[str]):
        """Expand *paths* (files or directories, non-recursive), keep only
        supported file types, and run the full ingestion pipeline on them.
        """
        all_files = []
        for path in paths:
            if os.path.isfile(path):
                all_files.append(path)
            elif os.path.isdir(path):
                all_files.extend(glob.glob(os.path.join(path, '*')))
            else:
                print(f"{Fore.YELLOW}Invalid path: {path}{Style.RESET_ALL}")

        # str.endswith accepts a tuple — one call instead of an any() chain.
        supported_extensions = ('.csv', '.txt', '.pdf')
        valid_files = [f for f in all_files if f.lower().endswith(supported_extensions)]

        if not valid_files:
            print(f"{Fore.RED}No valid files found for ingestion.{Style.RESET_ALL}")
            return

        print(f"{Fore.GREEN}Found {len(valid_files)} valid files for ingestion.{Style.RESET_ALL}")

        self.process_and_ingest_data(valid_files)
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# opensearch_class.py
2+
3+
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth, exceptions as opensearch_exceptions
4+
import boto3
5+
from urllib.parse import urlparse
6+
from opensearchpy import helpers as opensearch_helpers
7+
8+
class OpenSearchClass:
    """Thin wrapper around opensearch-py: client construction (managed or
    serverless), KNN index management, bulk indexing and vector search.
    """

    def __init__(self, config):
        """Cache connection settings from *config* (dict-like); the client
        itself is built in initialize_opensearch_client().
        """
        self.config = config
        self.opensearch_client = None  # set by initialize_opensearch_client()
        self.aws_region = config.get('region')
        self.index_name = config.get('index_name')
        # Config stores booleans as strings; only the literal 'True' enables serverless.
        self.is_serverless = config.get('is_serverless', 'False') == 'True'
        self.opensearch_endpoint = config.get('opensearch_endpoint')
        self.opensearch_username = config.get('opensearch_username')
        self.opensearch_password = config.get('opensearch_password')

    def initialize_opensearch_client(self):
        """Build the OpenSearch client.

        Uses SigV4 auth (service 'aoss') for serverless collections and HTTP
        basic auth otherwise. Returns True on success, False when the
        configuration is incomplete or construction fails.
        """
        if not self.opensearch_endpoint:
            print("OpenSearch endpoint not set. Please run setup first.")
            return False

        parsed_url = urlparse(self.opensearch_endpoint)
        host = parsed_url.hostname
        port = parsed_url.port or 443  # default to the HTTPS port

        if self.is_serverless:
            credentials = boto3.Session().get_credentials()
            auth = AWSV4SignerAuth(credentials, self.aws_region, 'aoss')
        else:
            if not self.opensearch_username or not self.opensearch_password:
                print("OpenSearch username or password not set. Please run setup first.")
                return False
            auth = (self.opensearch_username, self.opensearch_password)

        try:
            self.opensearch_client = OpenSearch(
                hosts=[{'host': host, 'port': port}],
                http_auth=auth,
                use_ssl=True,
                verify_certs=True,
                connection_class=RequestsHttpConnection,
                pool_maxsize=20
            )
            print(f"Initialized OpenSearch client with host: {host} and port: {port}")
            return True
        except Exception as ex:
            print(f"Error initializing OpenSearch client: {ex}")
            return False

    def create_index(self, embedding_dimension, space_type):
        """Create the KNN index used for nominee documents.

        Args:
            embedding_dimension: Length of the vectors to be indexed.
            space_type: Distance metric for HNSW (e.g. 'l2', 'cosinesimil').

        An already-existing index is reported but not treated as an error.
        """
        index_body = {
            "mappings": {
                "properties": {
                    "nominee_text": {"type": "text"},
                    "nominee_vector": {
                        "type": "knn_vector",
                        "dimension": embedding_dimension,
                        "method": {
                            "name": "hnsw",
                            "space_type": space_type,
                            "engine": "nmslib",
                            "parameters": {"ef_construction": 512, "m": 16},
                        },
                    },
                }
            },
            "settings": {
                "index": {
                    "number_of_shards": 2,
                    "knn.algo_param": {"ef_search": 512},
                    "knn": True,
                }
            },
        }
        try:
            self.opensearch_client.indices.create(index=self.index_name, body=index_body)
            print(f"KNN index '{self.index_name}' created successfully with dimension {embedding_dimension} and space type {space_type}.")
        except opensearch_exceptions.RequestError as e:
            # Racing another creator is benign; anything else is a real error.
            if 'resource_already_exists_exception' in str(e).lower():
                print(f"Index '{self.index_name}' already exists.")
            else:
                print(f"Error creating index '{self.index_name}': {e}")

    def verify_and_create_index(self, embedding_dimension, space_type):
        """Ensure the KNN index exists, creating it if needed.

        Returns:
            bool: True when the index exists (or was created), False on error.
        """
        try:
            if self.opensearch_client.indices.exists(index=self.index_name):
                print(f"KNN index '{self.index_name}' already exists.")
            else:
                self.create_index(embedding_dimension, space_type)
            return True
        except Exception as ex:
            print(f"Error verifying or creating index: {ex}")
            return False

    def bulk_index(self, actions):
        """Bulk-index *actions* into OpenSearch.

        Returns:
            tuple[int, int]: (successfully indexed count, failed count).
        """
        try:
            success, errors = opensearch_helpers.bulk(self.opensearch_client, actions)
            # helpers.bulk returns a *list* of per-item errors as its second
            # element (not a count) — convert before reporting.
            failed = len(errors) if isinstance(errors, list) else errors
            print(f"Indexed {success} documents successfully. Failed to index {failed} documents.")
            return success, failed
        except Exception as e:
            print(f"Error during bulk indexing: {e}")
            return 0, len(actions)

    def search(self, vector, k=5):
        """Run a KNN search for *vector* and return up to *k* hits.

        Only the 'nominee_text' source field is fetched; an empty list is
        returned on any search error.
        """
        try:
            response = self.opensearch_client.search(
                index=self.index_name,
                body={
                    "size": k,
                    "_source": ["nominee_text"],
                    "query": {
                        "knn": {
                            "nominee_vector": {
                                "vector": vector,
                                "k": k
                            }
                        }
                    }
                }
            )
            return response['hits']['hits']
        except Exception as e:
            print(f"Error during search: {e}")
            return []

0 commit comments

Comments
 (0)