|  | 
|  | 1 | +# ingest_class.py | 
|  | 2 | + | 
|  | 3 | +import os | 
|  | 4 | +import glob | 
|  | 5 | +import json | 
|  | 6 | +import tiktoken | 
|  | 7 | +from tqdm import tqdm | 
|  | 8 | +from colorama import Fore, Style, init | 
|  | 9 | +from typing import List, Dict | 
|  | 10 | +import csv | 
|  | 11 | +import PyPDF2 | 
|  | 12 | +import boto3 | 
|  | 13 | +import botocore | 
|  | 14 | +import time | 
|  | 15 | +import random | 
|  | 16 | + | 
|  | 17 | + | 
|  | 18 | +from opensearch_class import OpenSearchClass | 
|  | 19 | + | 
|  | 20 | +init(autoreset=True)  # Initialize colorama | 
|  | 21 | + | 
class IngestClass:
    """Reads local CSV/TXT/PDF files, embeds their text with Amazon Bedrock
    (Titan), and bulk-indexes the results into OpenSearch.

    Expected ``config`` keys: ``region`` (AWS region for the Bedrock client)
    and ``index_name`` (target OpenSearch index). The CSV schema is assumed to
    be Academy Award nominations (``name``/``category``/``film``/``winner``)
    — TODO confirm against the data source.
    """

    EMBEDDING_MODEL_ID = 'amazon.titan-embed-text-v1'
    # Single source of truth for the file types process_file() can parse;
    # ingest_command() filters candidate files against the same tuple.
    SUPPORTED_EXTENSIONS = ('.csv', '.txt', '.pdf')

    def __init__(self, config):
        self.config = config
        self.aws_region = config.get('region')
        self.index_name = config.get('index_name')
        self.bedrock_client = None  # created lazily by initialize_clients()
        self.opensearch = OpenSearchClass(config)

    def initialize_clients(self):
        """Create the Bedrock runtime client and initialize the OpenSearch client.

        Returns:
            bool: True when both clients are ready, False otherwise.
        """
        try:
            self.bedrock_client = boto3.client('bedrock-runtime', region_name=self.aws_region)
            if self.opensearch.initialize_opensearch_client():
                print("Clients initialized successfully.")
                return True
            print("Failed to initialize OpenSearch client.")
            return False
        except Exception as e:
            print(f"Failed to initialize clients: {e}")
            return False

    def process_file(self, file_path: str) -> List[Dict[str, str]]:
        """Dispatch *file_path* to the parser matching its extension.

        Returns:
            A list of ``{"text": ...}`` documents; empty list for
            unsupported file types.
        """
        _, file_extension = os.path.splitext(file_path)
        ext = file_extension.lower()  # normalize once instead of per-branch

        handlers = {
            '.csv': self.process_csv,
            '.txt': self.process_txt,
            '.pdf': self.process_pdf,
        }
        handler = handlers.get(ext)
        if handler is None:
            print(f"Unsupported file type: {file_extension}")
            return []
        return handler(file_path)

    def process_csv(self, file_path: str) -> List[Dict[str, str]]:
        """Turn each CSV row into one nomination sentence document.

        Rows must provide ``name``, ``category`` and ``film`` columns; a
        missing/false ``winner`` column appends " but did not win".
        """
        documents = []
        # newline='' is required by the csv module docs for correct quoting;
        # explicit encoding avoids platform-dependent text decoding.
        with open(file_path, 'r', newline='', encoding='utf-8') as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                text = f"{row['name']} got nominated under the category, {row['category']}, for the film {row['film']}"
                if row.get('winner', '').lower() != 'true':
                    text += " but did not win"
                documents.append({"text": text})
        return documents

    def process_txt(self, file_path: str) -> List[Dict[str, str]]:
        """Read a whole text file as a single document."""
        with open(file_path, 'r', encoding='utf-8') as txtfile:
            content = txtfile.read()
        return [{"text": content}]

    def process_pdf(self, file_path: str) -> List[Dict[str, str]]:
        """Extract one document per PDF page; pages with no text are skipped."""
        documents = []
        with open(file_path, 'rb') as pdffile:
            pdf_reader = PyPDF2.PdfReader(pdffile)
            for page in pdf_reader.pages:
                extracted_text = page.extract_text()
                if extracted_text:  # extract_text() may return '' / None
                    documents.append({"text": extracted_text})
        return documents

    def text_embedding(self, text, max_retries=5, initial_delay=1, backoff_factor=2):
        """Return the Titan embedding vector for *text*, or None on failure.

        Retries with exponential backoff plus jitter on ThrottlingException;
        any other ClientError (or exhaustion of retries) is re-raised.

        Args:
            text: The text to embed.
            max_retries: Maximum number of invoke attempts.
            initial_delay: Seconds to sleep before the first retry.
            backoff_factor: Multiplier applied to the delay after each retry.
        """
        if self.bedrock_client is None:
            print("Bedrock client is not initialized. Please run setup first.")
            return None

        delay = initial_delay
        for attempt in range(max_retries):
            try:
                payload = {"inputText": text}
                # contentType/accept make the JSON exchange explicit, per the
                # Bedrock InvokeModel API reference.
                response = self.bedrock_client.invoke_model(
                    modelId=self.EMBEDDING_MODEL_ID,
                    body=json.dumps(payload),
                    contentType='application/json',
                    accept='application/json',
                )
                response_body = json.loads(response['body'].read())
                embedding = response_body.get('embedding')
                if embedding is None:
                    print(f"No embedding returned for text: {text}")
                    print(f"Response body: {response_body}")
                    return None
                return embedding
            except botocore.exceptions.ClientError as e:
                error_code = e.response['Error']['Code']
                error_message = e.response['Error']['Message']
                print(f"ClientError on attempt {attempt + 1}: {error_code} - {error_message}")
                if error_code != 'ThrottlingException':
                    raise
                if attempt == max_retries - 1:
                    raise
                # Jitter avoids synchronized retries across concurrent workers.
                time.sleep(delay + random.uniform(0, 1))
                delay *= backoff_factor
            except Exception as ex:
                print(f"Unexpected error on attempt {attempt + 1}: {ex}")
                if attempt == max_retries - 1:
                    raise
        return None

    def _build_bulk_actions(self, documents: List[Dict]) -> List[Dict]:
        """Build OpenSearch bulk actions from documents that have an embedding."""
        return [
            {
                "_index": self.index_name,
                "_source": {
                    "nominee_text": doc['text'],
                    "nominee_vector": doc['embedding'],
                },
            }
            for doc in documents
            if doc.get('embedding') is not None
        ]

    def process_and_ingest_data(self, file_paths: List[str]):
        """Parse *file_paths*, embed every document, and bulk-index the results.

        Documents whose embedding fails are reported and skipped; ingestion is
        aborted entirely when no document embeds successfully.
        """
        if not self.initialize_clients():
            print("Failed to initialize clients. Aborting ingestion.")
            return

        all_documents = []
        for file_path in file_paths:
            print(f"Processing file: {file_path}")
            all_documents.extend(self.process_file(file_path))

        total_documents = len(all_documents)
        print(f"Total documents to process: {total_documents}")

        print("Generating embeddings for the documents...")
        success_count = 0
        error_count = 0
        with tqdm(total=total_documents, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]') as pbar:
            for doc in all_documents:
                try:
                    embedding = self.text_embedding(doc['text'])
                    if embedding is not None:
                        doc['embedding'] = embedding
                        success_count += 1
                    else:
                        error_count += 1
                        print(f"{Fore.RED}Error generating embedding for document: {doc['text'][:50]}...{Style.RESET_ALL}")
                except Exception as e:
                    error_count += 1
                    print(f"{Fore.RED}Error processing document: {str(e)}{Style.RESET_ALL}")
                pbar.update(1)
                pbar.set_postfix({'Success': success_count, 'Errors': error_count})

        print(f"\n{Fore.GREEN}Documents with successful embeddings: {success_count}{Style.RESET_ALL}")
        print(f"{Fore.RED}Documents with failed embeddings: {error_count}{Style.RESET_ALL}")

        if success_count == 0:
            print(f"{Fore.RED}No documents to ingest. Aborting ingestion.{Style.RESET_ALL}")
            return

        print(f"{Fore.YELLOW}Ingesting data into OpenSearch...{Style.RESET_ALL}")
        actions = self._build_bulk_actions(all_documents)

        success, failed = self.opensearch.bulk_index(actions)
        print(f"{Fore.GREEN}Successfully ingested {success} documents.{Style.RESET_ALL}")
        print(f"{Fore.RED}Failed to ingest {failed} documents.{Style.RESET_ALL}")

    def ingest_command(self, paths: List[str]):
        """Resolve *paths* (files or directories) to supported files and ingest them.

        Directories are scanned one level deep (non-recursive glob); paths that
        are neither a file nor a directory are reported and skipped.
        """
        all_files = []
        for path in paths:
            if os.path.isfile(path):
                all_files.append(path)
            elif os.path.isdir(path):
                all_files.extend(glob.glob(os.path.join(path, '*')))
            else:
                print(f"{Fore.YELLOW}Invalid path: {path}{Style.RESET_ALL}")

        # str.endswith accepts a tuple — one call covers every extension.
        valid_files = [f for f in all_files if f.lower().endswith(self.SUPPORTED_EXTENSIONS)]

        if not valid_files:
            print(f"{Fore.RED}No valid files found for ingestion.{Style.RESET_ALL}")
            return

        print(f"{Fore.GREEN}Found {len(valid_files)} valid files for ingestion.{Style.RESET_ALL}")

        self.process_and_ingest_data(valid_files)
0 commit comments