"""Chroma vector-database helpers (RAG/VectorDB/client.py).

Repo setup notes (from README): get a GEMINI API key from the Google
console; for database queries, alter the prompts to match your schema.
"""
import chromadb
from chromadb.utils import embedding_functions

# NOTE(review): hard-coded absolute path — consider making this configurable.
CHROMA_DB_PATH = "/media/tejas/b25dc664-2aec-424c-8f6c-f895bbec7e5d/Ericsson_RAG/data"


def get_chroma_client():
    """Return a persistent Chroma client rooted at CHROMA_DB_PATH."""
    return chromadb.PersistentClient(path=CHROMA_DB_PATH)


def get_or_create_collections(client, collection_name, model_name='all-mpnet-base-v2'):
    """Fetch the named collection, creating it if it does not exist.

    Fix: the SentenceTransformer embedding function was previously built
    but never passed to create/get, so Chroma silently fell back to its
    default embedder instead of *model_name*.
    """
    sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=model_name
    )
    existing_names = {c.name for c in client.list_collections()}
    if collection_name not in existing_names:
        return client.create_collection(
            name=collection_name, embedding_function=sentence_transformer_ef
        )
    return client.get_collection(
        name=collection_name, embedding_function=sentence_transformer_ef
    )


def delete_all_collections():
    """Delete every collection in the persistent store (destructive)."""
    client = get_chroma_client()
    for collection in client.list_collections():
        client.delete_collection(collection.name)
        print(f"Deleted collection: {collection.name}")


if __name__ == '__main__':
    delete_all_collections()
# database_handler.py
"""Translate natural-language questions into SQL via Gemini and execute
them against the local SQLite call database."""
import os
import sqlite3
import google.generativeai as genai
from dotenv import load_dotenv

load_dotenv()
api_key = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=api_key)

# System prompt for the Gemini model
system_prompt = """
I have an sqlite database with the following tables and columns:

Table name: RatePlan
Columns:
RatePlanId INTEGER PRIMARY KEY
Name VARCHAR(255)
MonthlyFee FLOAT
CallRate FLOAT
SmsRate FLOAT
DataRate FLOAT


Table name: Customer
Columns:
CustomerId INTEGER PRIMARY KEY
FirstName VARCHAR(255)
LastName VARCHAR(255)
Address VARCHAR(255)
City VARCHAR(255)
State VARCHAR(255)
Country VARCHAR(255)
PostalCode VARCHAR(255)
Phone VARCHAR(255)
Email VARCHAR(255)
RatePlanId INT
ContractStart DATE
ContractEnd DATE

Foreign Keys:
Foreign key: RatePlanId references RatePlanId(NO ACTION)

Table name: Phone
Columns:
PhoneId INTEGER PRIMARY KEY
Brand VARCHAR(255)
Model VARCHAR(255)
OS VARCHAR(255)
Price FLOAT

Table name: CustomerPhone
Columns:
CustomerPhoneId INTEGER PRIMARY KEY
CustomerId INT
PhoneId INT
PhoneAcquisitionDate DATE

Foreign Keys:
Foreign key: PhoneId references PhoneId(NO ACTION)
Foreign key: CustomerId references CustomerId(NO ACTION)

Table name: CDR
Columns:
CdrId INTEGER PRIMARY KEY
CustomerId INT
PhoneNumber VARCHAR(255)
CallDateTime DATETIME
CallType VARCHAR(255)
DurationInSeconds INT
DataUsageKb INT
SmsCount INT

Foreign Keys:
Foreign key: CustomerId references CustomerId(NO ACTION)

I will need you to help me generate SQL queries to get data from my database.
Please respond only with the query in simple text format. Do not provide any explanations or additional text.

If the user tries to modify the database respond with 'ERROR: cannot modify db'
"""

# Initialize chat with system prompt.
# Fix: the GenerativeModel was previously constructed twice; one instance
# is enough.
model = genai.GenerativeModel('gemini-1.5-flash')
chat = model.start_chat(history=[])
chat.send_message(system_prompt)


def generate_sql_query(prompt):
    """Ask Gemini for a SQL query answering *prompt*.

    Returns the bare SQL text, or the sentinel string
    'ERROR: cannot modify db' when the model refuses a mutating request.
    """
    response = chat.send_message(prompt)
    sql_query = response.text.strip()
    if sql_query == "ERROR: cannot modify db":
        return "ERROR: cannot modify db"
    # Strip markdown code fences the model sometimes wraps around the query.
    sql_query = sql_query.replace('```sql', '').replace('```', '').strip()
    print(sql_query)
    return sql_query


def fetch_data_from_db(sql_query, db_path='/media/tejas/b25dc664-2aec-424c-8f6c-f895bbec7e5d/Ericsson_RAG/call_db.sqlite'):
    """Execute *sql_query* against the SQLite db and return (columns, rows).

    Fix: the connection is now closed even when execution raises
    (previously it leaked on any SQL error).

    NOTE(review): the SQL text comes from an LLM — it is executed as-is,
    relying on the prompt's "read-only" instruction. Consider opening the
    database in read-only mode (``file:...?mode=ro``) for defense in depth.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(sql_query)
        results = cursor.fetchall()
        columns = [description[0] for description in cursor.description]
    finally:
        conn.close()
    return columns, results


def format_results_as_table(columns, results):
    """Return a list-of-rows table with *columns* as the header row."""
    return [columns, *results]
"""Text extraction and chunking helpers for the RAG ingest pipeline."""
import nltk
from nltk.tokenize import sent_tokenize
from PyPDF2 import PdfReader
from io import BytesIO
from docx import Document

# Sentence-tokenizer model; a no-op when already present.
# NOTE(review): newer NLTK releases also require 'punkt_tab' — confirm
# against the installed NLTK version.
nltk.download('punkt')


def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page in *pdf_file*.

    Fix: PyPDF2's ``extract_text()`` may return None for image-only
    pages; previously that raised TypeError on string concatenation.
    """
    reader = PdfReader(pdf_file)
    return ''.join((page.extract_text() or '') + '\n' for page in reader.pages)


def extract_text_from_txt(txt_file_path):
    """Read and return the whole UTF-8 text file at *txt_file_path*."""
    with open(txt_file_path, 'r', encoding='utf-8') as f:
        return f.read()


def extract_text_from_docx(docx_file):
    """Return all paragraph text of a .docx file, newline-joined."""
    doc = Document(docx_file)
    return '\n'.join(paragraph.text for paragraph in doc.paragraphs)


def chunk_text(text, chunk_size=5):
    """Split *text* into chunks of *chunk_size* consecutive sentences.

    Returns a list of strings; the final chunk may be shorter.
    """
    sentences = sent_tokenize(text)
    return [' '.join(sentences[i:i + chunk_size])
            for i in range(0, len(sentences), chunk_size)]
"""End-to-end RAG pipeline: ingest documents into Chroma and answer user
queries with Gemini using retrieved context."""
import os
import uuid
import google.generativeai as genai
from dotenv import load_dotenv
from preprocessing import extract_text_from_pdf, chunk_text, extract_text_from_docx, extract_text_from_txt
# Fix: get_chroma_client was imported twice; a single combined import is kept.
from VectorDB.client import get_chroma_client, get_or_create_collections

load_dotenv()
api_key = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=api_key)
model = genai.GenerativeModel('gemini-1.5-flash')
chat = model.start_chat(history=[])


def store_document_embeddings(document_path, collection_name='document', chunk_size=5, model_name='all-mpnet-base-v2'):
    """Extract text from *document_path*, chunk it, and store the chunks
    in the named Chroma collection.

    Raises ValueError for unsupported file extensions.
    Fix: removed a dead SentenceTransformer embedding-function local that
    was built here but never used (the collection's embedder is handled
    by get_or_create_collections).
    """
    client = get_chroma_client()
    collection = get_or_create_collections(client, collection_name, model_name)

    _, file_extension = os.path.splitext(document_path)
    extractors = {
        '.pdf': extract_text_from_pdf,
        '.docx': extract_text_from_docx,
        '.txt': extract_text_from_txt,
    }
    if file_extension not in extractors:
        raise ValueError(f"Unsupported file type: {file_extension}. Supported types are pdf, docx, txt.")
    text = extractors[file_extension](document_path)

    chunks = chunk_text(text, chunk_size)
    ids = [str(uuid.uuid4()) for _ in range(len(chunks))]
    metadata = [{"document": document_path, "chunk": i} for i in range(len(chunks))]
    collection.add(ids=ids, documents=chunks, metadatas=metadata)


def retrieve_documents(query, collection):
    """Return the top-5 matches for *query* from *collection*.

    Fix: removed another dead embedding-function local.
    """
    return collection.query(query_texts=[query], n_results=5)


def generate_response(query, collection_name='document'):
    """Answer *query* from context retrieved across every collection.

    The query is first expanded by the LLM to improve recall; matching
    chunks are concatenated into a context block and Gemini is asked to
    answer strictly from that context.
    """
    client = get_chroma_client()
    collections = client.list_collections()

    expansion_prompt = f"Expand or transform the following query to include related keywords and phrases that will improve the chances of finding relevant text in a document database:\n\nQuery: {query}\n\nExpanded query:"
    expanded_query = model.generate_content(expansion_prompt).text.strip()

    context = ""
    for collection in collections:
        documents = retrieve_documents(expanded_query, collection)
        for document in documents["documents"]:
            for chunk in document:
                context += chunk

    prompt = f"User query: {query}\n\nRelevant information:\n{context}\n\nBased on the above information, please provide a detailed response.Dont add on your own anything just stick to the info given and answer the query without any suggestions or recommendations"

    # Fix: removed leftover debug print of the chat-history entry type.
    chats = chat.send_message(prompt)
    return chats.text
import sqlite3


def get_database_schema(db_path):
    """Return a human-readable description of every table in the SQLite
    database at *db_path*: columns (with PRIMARY KEY markers) plus
    foreign-key relationships.

    Fixes: the connection now closes even if a query raises, and table
    identifiers are quoted before being interpolated into PRAGMA calls.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()

        schema_details = []
        for (table_name,) in tables:
            cursor.execute(f'PRAGMA table_info("{table_name}");')
            columns = cursor.fetchall()

            column_details = []
            for col in columns:
                # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
                column_info = f"{col[1]} {col[2]}"
                if col[5]:  # non-zero pk flag marks a primary-key column
                    column_info += " PRIMARY KEY"
                column_details.append(column_info)

            cursor.execute(f'PRAGMA foreign_key_list("{table_name}");')
            foreign_keys = cursor.fetchall()
            # fk row: (id, seq, table, from, to, on_update, on_delete, match)
            foreign_key_details = [
                f"Foreign key: {fk[3]} references {fk[4]}({fk[5]})"
                for fk in foreign_keys
            ]

            schema_details.append(
                f"Table name: {table_name}\nColumns:\n" + "\n".join(column_details)
            )
            if foreign_key_details:
                schema_details.append("Foreign Keys:\n" + "\n".join(foreign_key_details))
    finally:
        conn.close()

    return "\n\n".join(schema_details)


def generate_system_prompt(db_path):
    """Build the LLM system prompt embedding the live database schema."""
    schema_details = get_database_schema(db_path)
    system_prompt = f"""I have an sqlite database with the following tables and columns:

{schema_details}

I will need you to help me generate SQL queries to get data from my database.
Please respond only with the query. Do not provide any explanations or additional text.

If the user tries to modify the database respond with 'ERROR: cannot modify db'
"""
    return system_prompt


if __name__ == '__main__':
    # Example usage. Fix: previously ran unconditionally on import,
    # connecting to 'call_db.sqlite' (creating it if absent) and printing
    # as a module-level side effect.
    db_path = 'call_db.sqlite'  # Replace with your actual database path
    print(generate_system_prompt(db_path))
"docx", "txt"]) + if uploaded_file is not None: + # Create a temporary directory if it doesn't exist + temp_dir = os.path.join('data', 'temp') + os.makedirs(temp_dir, exist_ok=True) + + # Save uploaded file to temporary directory + collection_name=sanitize_collection_name(uploaded_file.name) + temp_file_path = os.path.join(temp_dir, uploaded_file.name) + with open(temp_file_path, 'wb') as f: + f.write(uploaded_file.getbuffer()) + st.success("File uploaded successfully!") + + # Update session state with uploaded file + st.session_state.uploaded_file = temp_file_path + db_query=st.checkbox("DB Query") + +# Automatically process and store embeddings when a file is uploaded +if st.session_state.uploaded_file: + + store_document_embeddings(st.session_state.uploaded_file,collection_name=collection_name, chunk_size=5) + + os.remove(st.session_state.uploaded_file) + st.session_state.uploaded_file = None + +# Container for the conversation +conversation_area = st.empty() + +# User input for query (moved to sidebar) +with st.sidebar: + query = st.chat_input('Enter your query:', key='query_input') + + + # Handle user query submission + +if db_query and query: + + query_sql=generate_sql_query(query) + if(query_sql=="ERROR: cannot modify db"): + + with st.chat_message("user"): + st.markdown(query) + + # Add bot response to history + st.session_state.messages.append({'role': 'user', 'content': query}) + with st.chat_message("assistant"): + st.markdown(query_sql) + st.session_state.messages.append({"role":"assistant","content":query_sql}) + else: + columns,results=fetch_data_from_db(query_sql) + table=format_results_as_table(columns,results) + + + with st.chat_message("user"): + st.markdown(query) + + # Add bot response to history + st.session_state.messages.append({'role': 'user', 'content': query}) + with st.chat_message("assistant"): + st.markdown('Table') + st.table(table) + # st.session_state.messages.append({"role":"assistant","content":results}) + + + +elif query: + # Add 
user query to history + # st.session_state.history.append({'role': 'user', 'content': query}) + # Generate bot response + with st.chat_message("user"): + st.markdown(query) + + # Add bot response to history + st.session_state.messages.append({'role': 'user', 'content': query}) + response = generate_response(query) + with st.chat_message("assistant"): + st.markdown(response) + st.session_state.messages.append({"role":"assistant","content":response}) + + +# # Display updated chat messages +# display_chat_messages(st.session_state.history)