Skip to content
Closed

RAG #97

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions RAG/README
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Get GEMINI API KEY from google console
For Database query you can alter the prompts according to your schema
Binary file added RAG/VectorDB/__pycache__/client.cpython-312.pyc
Binary file not shown.
26 changes: 26 additions & 0 deletions RAG/VectorDB/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import chromadb
from chromadb.utils import embedding_functions
def get_chroma_client():
client = chromadb.PersistentClient(path="/media/tejas/b25dc664-2aec-424c-8f6c-f895bbec7e5d/Ericsson_RAG/data")
return client

def get_or_create_collections(client, collection_name, model_name='all-mpnet-base-v2'):
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)
collections = client.list_collections()
collection_names = [c.name for c in collections]

if collection_name not in collection_names:
collection = client.create_collection(name=collection_name)
else:
collection = client.get_collection(name=collection_name)
return collection

def delete_all_collections():
client = get_chroma_client()
collections = client.list_collections()
for collection in collections:
client.delete_collection(collection.name)
print(f"Deleted collection: {collection.name}")

if __name__ == '__main__':
delete_all_collections()
Binary file added RAG/__pycache__/database_handler.cpython-312.pyc
Binary file not shown.
Binary file added RAG/__pycache__/preprocessing.cpython-312.pyc
Binary file not shown.
Binary file added RAG/__pycache__/rag_pipeline.cpython-312.pyc
Binary file not shown.
112 changes: 112 additions & 0 deletions RAG/database_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# database_handler.py
import os
import sqlite3
import google.generativeai as genai
from dotenv import load_dotenv
load_dotenv()
api_key=os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=api_key)
model = genai.GenerativeModel('gemini-1.5-flash')

# System prompt for the Gemini model
system_prompt = """
I have an sqlite database with the following tables and columns:

Table name: RatePlan
Columns:
RatePlanId INTEGER PRIMARY KEY
Name VARCHAR(255)
MonthlyFee FLOAT
CallRate FLOAT
SmsRate FLOAT
DataRate FLOAT


Table name: Customer
Columns:
CustomerId INTEGER PRIMARY KEY
FirstName VARCHAR(255)
LastName VARCHAR(255)
Address VARCHAR(255)
City VARCHAR(255)
State VARCHAR(255)
Country VARCHAR(255)
PostalCode VARCHAR(255)
Phone VARCHAR(255)
Email VARCHAR(255)
RatePlanId INT
ContractStart DATE
ContractEnd DATE

Foreign Keys:
Foreign key: RatePlanId references RatePlanId(NO ACTION)

Table name: Phone
Columns:
PhoneId INTEGER PRIMARY KEY
Brand VARCHAR(255)
Model VARCHAR(255)
OS VARCHAR(255)
Price FLOAT

Table name: CustomerPhone
Columns:
CustomerPhoneId INTEGER PRIMARY KEY
CustomerId INT
PhoneId INT
PhoneAcquisitionDate DATE

Foreign Keys:
Foreign key: PhoneId references PhoneId(NO ACTION)
Foreign key: CustomerId references CustomerId(NO ACTION)

Table name: CDR
Columns:
CdrId INTEGER PRIMARY KEY
CustomerId INT
PhoneNumber VARCHAR(255)
CallDateTime DATETIME
CallType VARCHAR(255)
DurationInSeconds INT
DataUsageKb INT
SmsCount INT

Foreign Keys:
Foreign key: CustomerId references CustomerId(NO ACTION)

I will need you to help me generate SQL queries to get data from my database.
Please respond only with the query in simple text format. Do not provide any explanations or additional text.

If the user tries to modify the database respond with 'ERROR: cannot modify db'
"""

# Initialize chat with system prompt
model = genai.GenerativeModel('gemini-1.5-flash')
chat = model.start_chat(history=[])
chat.send_message(system_prompt)

def generate_sql_query(prompt):
response = chat.send_message(prompt)
sql_query = response.text.strip()
if(sql_query=="ERROR: cannot modify db"):return "ERROR: cannot modify db"
sql_query = sql_query.replace('```sql', '').replace('```', '').strip()
print(sql_query)
return sql_query

def fetch_data_from_db(sql_query, db_path='/media/tejas/b25dc664-2aec-424c-8f6c-f895bbec7e5d/Ericsson_RAG/call_db.sqlite'):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()


cursor.execute(sql_query)
results = cursor.fetchall()
columns = [description[0] for description in cursor.description]
conn.close()
return columns, results

def format_results_as_table(columns, results):
table = [columns]
table.extend(results)
return table


29 changes: 29 additions & 0 deletions RAG/preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import nltk
from nltk.tokenize import sent_tokenize
from PyPDF2 import PdfReader
from io import BytesIO
from docx import Document

nltk.download('punkt')

def extract_text_from_pdf(pdf_file):
reader = PdfReader(pdf_file)
text = ''
for page in reader.pages:
text += page.extract_text() + '\n'
return text

def extract_text_from_txt(txt_file_path):
with open(txt_file_path, 'r', encoding='utf-8') as f:
text = f.read()
return text

def extract_text_from_docx(docx_file):
doc = Document(docx_file)
paragraphs = [paragraph.text for paragraph in doc.paragraphs]
text = '\n'.join(paragraphs)
return text

def chunk_text(text, chunk_size=5):
sentences = sent_tokenize(text)
return [' '.join(sentences[i:i + chunk_size]) for i in range(0, len(sentences), chunk_size)]
64 changes: 64 additions & 0 deletions RAG/rag_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from VectorDB.client import get_chroma_client
import os
import google.generativeai as genai
from chromadb.utils import embedding_functions
import uuid
from preprocessing import extract_text_from_pdf, chunk_text,extract_text_from_docx,extract_text_from_txt
from VectorDB.client import get_chroma_client, get_or_create_collections
from dotenv import load_dotenv
load_dotenv()
api_key=os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=api_key)
model = genai.GenerativeModel('gemini-1.5-flash')
chat = model.start_chat(history=[])

def store_document_embeddings(document_path, collection_name='document', chunk_size=5, model_name='all-mpnet-base-v2'):
client = get_chroma_client()

collection = get_or_create_collections(client, collection_name, model_name)
sentence_trans_ef=embedding_functions.SentenceTransformerEmbeddingFunction(model_name='all-mpnet-base-v2')
_, file_extension = os.path.splitext(document_path)
if file_extension == '.pdf':
text = extract_text_from_pdf(document_path)
elif file_extension == '.docx':
text = extract_text_from_docx(document_path)
elif file_extension == '.txt':
text= extract_text_from_txt(document_path)
else:
raise ValueError(f"Unsupported file type: {file_extension}. Supported types are pdf, docx, txt.")

chunks = chunk_text(text, chunk_size)
ids = [str(uuid.uuid4()) for _ in range(len(chunks))]
metadata = [{"document": document_path, "chunk": i} for i in range(len(chunks))]

collection.add(ids=ids, documents=chunks, metadatas=metadata)


def retrieve_documents(query, collection):

sentence_trans_ef=embedding_functions.SentenceTransformerEmbeddingFunction(model_name='all-mpnet-base-v2')

results = collection.query(query_texts=[query], n_results=5,)
return results

def generate_response(query, collection_name='document'):
client = get_chroma_client()
collections=client.list_collections()
expansion_prompt = f"Expand or transform the following query to include related keywords and phrases that will improve the chances of finding relevant text in a document database:\n\nQuery: {query}\n\nExpanded query:"
expanded_query = model.generate_content(expansion_prompt).text.strip()
context = ""
for collection in collections:
documents = retrieve_documents(expanded_query,collection)
for document in documents["documents"]:
for i in document :
context+=i


prompt = f"User query: {query}\n\nRelevant information:\n{context}\n\nBased on the above information, please provide a detailed response.Dont add on your own anything just stick to the info given and answer the query without any suggestions or recommendations"

chats= chat.send_message(prompt)
# Generate the response in streaming mode
response = chats.text

print(type(chat.history[1]))
return response
58 changes: 58 additions & 0 deletions RAG/system_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import sqlite3

def get_database_schema(db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()

# Get all table names
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()

schema_details = []

for table in tables:
table_name = table[0]
cursor.execute(f"PRAGMA table_info({table_name});")
columns = cursor.fetchall()

column_details = []
for col in columns:
column_info = f"{col[1]} {col[2]}"
if col[5]: # Check if the column is a primary key
column_info += " PRIMARY KEY"
column_details.append(column_info)

# Get foreign key information
cursor.execute(f"PRAGMA foreign_key_list({table_name});")
foreign_keys = cursor.fetchall()
foreign_key_details = []
for fk in foreign_keys:
foreign_key_details.append(
f"Foreign key: {fk[3]} references {fk[4]}({fk[5]})"
)

schema_details.append(f"Table name: {table_name}\nColumns:\n" + "\n".join(column_details))
if foreign_key_details:
schema_details.append("Foreign Keys:\n" + "\n".join(foreign_key_details))

conn.close()

return "\n\n".join(schema_details)

def generate_system_prompt(db_path):
schema_details = get_database_schema(db_path)
system_prompt = f"""I have an sqlite database with the following tables and columns:

{schema_details}

I will need you to help me generate SQL queries to get data from my database.
Please respond only with the query. Do not provide any explanations or additional text.

If the user tries to modify the database respond with 'ERROR: cannot modify db'
"""
return system_prompt

# Example usage
db_path = 'call_db.sqlite' # Replace with your actual database path
system_prompt = generate_system_prompt(db_path)
print(system_prompt)
Loading
Loading