Skip to content

Commit aabb015

Browse files
committed
RAG
1 parent 1347001 commit aabb015

File tree

11 files changed

+425
-0
lines changed

11 files changed

+425
-0
lines changed

RAG/README

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
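Get a Gemini API key from the Google console and make it available as GOOGLE_API_KEY (the code reads it from the environment, e.g. via a .env file).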
2+
For database queries, adjust the prompts so that they describe your own schema.
1.71 KB
Binary file not shown.

RAG/VectorDB/client.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import chromadb
2+
from chromadb.utils import embedding_functions
3+
def get_chroma_client(path="/media/tejas/b25dc664-2aec-424c-8f6c-f895bbec7e5d/Ericsson_RAG/data"):
    """Return a persistent Chroma client backed by the on-disk store at *path*.

    Args:
        path: Directory holding the Chroma data. The default is the location
            this project was originally hard-wired to, so existing callers
            behave exactly as before; pass another directory to run the code
            on a different machine.

    Returns:
        The ``chromadb.PersistentClient`` opened on that directory.
    """
    return chromadb.PersistentClient(path=path)
6+
7+
def get_or_create_collections(client, collection_name, model_name='all-mpnet-base-v2'):
    """Return the collection called *collection_name*, creating it if missing.

    Args:
        client: Chroma client on which the collection is looked up.
        collection_name: Name of the collection to fetch or create.
        model_name: Kept for backward compatibility; it is not used.

    Returns:
        The existing or newly created collection.
    """
    # The earlier version constructed a SentenceTransformerEmbeddingFunction
    # from model_name but never passed it to Chroma, so collections have always
    # been created with the client's default embedding function. Constructing
    # it had no effect on the result, so it is dropped. Wiring model_name in
    # for real would change the vectors of already-persisted collections and
    # must be done as a deliberate migration, not silently here.
    #
    # get_or_create_collection replaces the list / membership-check / create
    # sequence with a single call.
    return client.get_or_create_collection(name=collection_name)
17+
18+
def delete_all_collections():
    """Delete every collection in the Chroma store, printing each name removed."""
    chroma = get_chroma_client()
    for entry in chroma.list_collections():
        name = entry.name
        chroma.delete_collection(name)
        print(f"Deleted collection: {name}")
24+
25+
# Running this file as a script deletes every collection in the store.
if __name__ == '__main__':
    delete_all_collections()
3.58 KB
Binary file not shown.
1.83 KB
Binary file not shown.
4.05 KB
Binary file not shown.

RAG/database_handler.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# database_handler.py
2+
import os
3+
import sqlite3
4+
import google.generativeai as genai
5+
from dotenv import load_dotenv
6+
# Load variables from a .env file (if any) before the API key is read.
load_dotenv()
api_key=os.getenv("GOOGLE_API_KEY")  # None when GOOGLE_API_KEY is not set
genai.configure(api_key=api_key)
# Module-level Gemini model handle.
model = genai.GenerativeModel('gemini-1.5-flash')
10+
11+
# System prompt for the Gemini model
# It spells out the schema the model may query and asks for a bare SQL query,
# or the literal 'ERROR: cannot modify db' when the user tries to change data.
# NOTE(review): the "Foreign key:" lines below do not name the referenced
# table (e.g. "RatePlanId references RatePlanId(NO ACTION)") — confirm whether
# the model needs the table name to build correct joins.
system_prompt = """
I have an sqlite database with the following tables and columns:

Table name: RatePlan
Columns:
RatePlanId INTEGER PRIMARY KEY
Name VARCHAR(255)
MonthlyFee FLOAT
CallRate FLOAT
SmsRate FLOAT
DataRate FLOAT


Table name: Customer
Columns:
CustomerId INTEGER PRIMARY KEY
FirstName VARCHAR(255)
LastName VARCHAR(255)
Address VARCHAR(255)
City VARCHAR(255)
State VARCHAR(255)
Country VARCHAR(255)
PostalCode VARCHAR(255)
Phone VARCHAR(255)
Email VARCHAR(255)
RatePlanId INT
ContractStart DATE
ContractEnd DATE

Foreign Keys:
Foreign key: RatePlanId references RatePlanId(NO ACTION)

Table name: Phone
Columns:
PhoneId INTEGER PRIMARY KEY
Brand VARCHAR(255)
Model VARCHAR(255)
OS VARCHAR(255)
Price FLOAT

Table name: CustomerPhone
Columns:
CustomerPhoneId INTEGER PRIMARY KEY
CustomerId INT
PhoneId INT
PhoneAcquisitionDate DATE

Foreign Keys:
Foreign key: PhoneId references PhoneId(NO ACTION)
Foreign key: CustomerId references CustomerId(NO ACTION)

Table name: CDR
Columns:
CdrId INTEGER PRIMARY KEY
CustomerId INT
PhoneNumber VARCHAR(255)
CallDateTime DATETIME
CallType VARCHAR(255)
DurationInSeconds INT
DataUsageKb INT
SmsCount INT

Foreign Keys:
Foreign key: CustomerId references CustomerId(NO ACTION)

I will need you to help me generate SQL queries to get data from my database.
Please respond only with the query in simple text format. Do not provide any explanations or additional text.

If the user tries to modify the database respond with 'ERROR: cannot modify db'
"""
82+
83+
# Initialize the shared chat session and prime it with the system prompt.
# `model` is already built earlier in this module with the same model name, so
# it is reused here instead of being constructed a second time.
# NOTE: send_message runs when the module is imported, so importing this
# module sends the system prompt to the Gemini API.
chat = model.start_chat(history=[])
chat.send_message(system_prompt)
87+
88+
def generate_sql_query(prompt):
    """Ask the shared chat session for a SQL query answering *prompt*.

    Args:
        prompt: Natural-language request from the user.

    Returns:
        The SQL text with Markdown code fences removed, or the literal
        ``"ERROR: cannot modify db"`` when the model refuses the request.
    """
    refusal = "ERROR: cannot modify db"
    response = chat.send_message(prompt)

    # Remove code fences BEFORE looking for the refusal marker: the model may
    # wrap the refusal in ``` fences or echo the quotes used in the system
    # prompt, and an exact comparison against the raw text would then let the
    # refusal through as if it were SQL.
    sql_query = response.text.replace('```sql', '').replace('```', '').strip()
    if sql_query.strip("'\"`. ").lower().startswith(refusal.lower()):
        return refusal

    print(sql_query)
    return sql_query
95+
96+
def fetch_data_from_db(sql_query, db_path='/media/tejas/b25dc664-2aec-424c-8f6c-f895bbec7e5d/Ericsson_RAG/call_db.sqlite'):
    """Run *sql_query* against the SQLite database and return its result set.

    Args:
        sql_query: SQL text to execute. It is produced by a language model, so
            it is treated as untrusted.
        db_path: Path of the SQLite database file.

    Returns:
        A ``(columns, results)`` tuple: the list of column names and the list
        of row tuples. ``columns`` is empty for a statement that yields no
        result set.

    Raises:
        sqlite3.Error: If the statement is invalid or tries to modify the
            database.
    """
    conn = sqlite3.connect(db_path)
    try:
        # The query text comes from the model; query_only makes SQLite reject
        # any statement that would change the database, so a stray DROP or
        # UPDATE cannot get through.
        conn.execute("PRAGMA query_only = ON")
        cursor = conn.cursor()
        cursor.execute(sql_query)
        results = cursor.fetchall()
        # description is None for statements that return no rows.
        columns = [description[0] for description in cursor.description or ()]
        return columns, results
    finally:
        # Close the connection even when execute() raises.
        conn.close()
106+
107+
def format_results_as_table(columns, results):
    """Return a list whose first entry is *columns*, followed by every row in *results*."""
    return [columns, *results]
111+
112+

RAG/preprocessing.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import nltk
2+
from nltk.tokenize import sent_tokenize
3+
from PyPDF2 import PdfReader
4+
from io import BytesIO
5+
from docx import Document
6+
7+
# Downloads the 'punkt' data every time this module is imported
# (presumably the tokenizer data sent_tokenize relies on — confirm for the NLTK version in use).
nltk.download('punkt')
8+
9+
def extract_text_from_pdf(pdf_file):
    """Return the extracted text of every page in *pdf_file*, each page followed by a newline."""
    pdf = PdfReader(pdf_file)
    return ''.join(page.extract_text() + '\n' for page in pdf.pages)
15+
16+
def extract_text_from_txt(txt_file_path):
    """Read the UTF-8 text file at *txt_file_path* and return its whole content."""
    with open(txt_file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()
20+
21+
def extract_text_from_docx(docx_file):
    """Return the text of all paragraphs in *docx_file*, joined by newlines."""
    return '\n'.join(para.text for para in Document(docx_file).paragraphs)
26+
27+
def chunk_text(text, chunk_size=5):
    """Split *text* into sentences and group them into chunks.

    Each chunk is up to *chunk_size* consecutive sentences joined by single
    spaces; the last chunk may hold fewer.
    """
    sentences = sent_tokenize(text)
    chunks = []
    for start in range(0, len(sentences), chunk_size):
        chunks.append(' '.join(sentences[start:start + chunk_size]))
    return chunks

RAG/rag_pipeline.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from VectorDB.client import get_chroma_client
2+
import os
3+
import google.generativeai as genai
4+
from chromadb.utils import embedding_functions
5+
import uuid
6+
from preprocessing import extract_text_from_pdf, chunk_text,extract_text_from_docx,extract_text_from_txt
7+
from VectorDB.client import get_chroma_client, get_or_create_collections
8+
from dotenv import load_dotenv
9+
# Load .env, configure the Gemini client with GOOGLE_API_KEY, and create the
# model handle used by generate_response.
load_dotenv()
api_key=os.getenv("GOOGLE_API_KEY")  # None when GOOGLE_API_KEY is not set
genai.configure(api_key=api_key)
model = genai.GenerativeModel('gemini-1.5-flash')
# One chat session for the whole module: every generate_response call sends
# its prompt through this same session.
chat = model.start_chat(history=[])
14+
15+
def store_document_embeddings(document_path, collection_name='document', chunk_size=5, model_name='all-mpnet-base-v2'):
    """Extract a document's text, chunk it and add the chunks to a collection.

    Args:
        document_path: Path of a .pdf, .docx or .txt file.
        collection_name: Collection that receives the chunks.
        chunk_size: Number of sentences per chunk.
        model_name: Forwarded to get_or_create_collections.

    Raises:
        ValueError: If the file extension is not pdf, docx or txt.
    """
    client = get_chroma_client()
    collection = get_or_create_collections(client, collection_name, model_name)
    # The embedding function that used to be constructed here was never used,
    # so it has been removed.

    _, file_extension = os.path.splitext(document_path)
    extractors = {
        '.pdf': extract_text_from_pdf,
        '.docx': extract_text_from_docx,
        '.txt': extract_text_from_txt,
    }
    # Compare in lower case so that e.g. 'REPORT.PDF' is accepted as well.
    extractor = extractors.get(file_extension.lower())
    if extractor is None:
        raise ValueError(f"Unsupported file type: {file_extension}. Supported types are pdf, docx, txt.")
    text = extractor(document_path)

    chunks = chunk_text(text, chunk_size)
    if not chunks:
        # Nothing to store (empty document); skip the add() call rather than
        # handing the collection empty lists.
        return

    ids = [str(uuid.uuid4()) for _ in chunks]
    metadata = [{"document": document_path, "chunk": i} for i in range(len(chunks))]
    collection.add(ids=ids, documents=chunks, metadatas=metadata)
35+
36+
37+
def retrieve_documents(query, collection):
    """Return the five best matches for *query* from *collection*.

    Args:
        query: Text to search for.
        collection: Object exposing ``query(query_texts=..., n_results=...)``.

    Returns:
        Whatever ``collection.query`` returns, unchanged.
    """
    # The embedding function that used to be constructed here on every call was
    # never passed to the query, so it has been removed.
    return collection.query(query_texts=[query], n_results=5)
43+
44+
def generate_response(query, collection_name='document'):
    """Answer *query* from the stored documents using the Gemini chat session.

    The query is first expanded by the model, the expanded text is used to
    retrieve passages from every collection in the store, and the passages are
    then sent with the original query to the shared chat session.

    Args:
        query: The user's question.
        collection_name: Currently unused; all collections are searched.

    Returns:
        The text of the model's reply.
    """
    client = get_chroma_client()
    collections = client.list_collections()

    expansion_prompt = f"Expand or transform the following query to include related keywords and phrases that will improve the chances of finding relevant text in a document database:\n\nQuery: {query}\n\nExpanded query:"
    expanded_query = model.generate_content(expansion_prompt).text.strip()

    passages = []
    for collection in collections:
        documents = retrieve_documents(expanded_query, collection)
        for hits in documents["documents"]:
            # Skip missing entries so one absent document cannot break the join.
            passages.extend(hit for hit in hits if hit)
    # Separate passages with a newline; joining them back-to-back made the last
    # sentence of one chunk run straight into the first sentence of the next.
    context = "\n".join(passages)

    prompt = f"User query: {query}\n\nRelevant information:\n{context}\n\nBased on the above information, please provide a detailed response.Dont add on your own anything just stick to the info given and answer the query without any suggestions or recommendations"

    # The leftover debug print of chat.history has been removed.
    return chat.send_message(prompt).text

RAG/system_prompt.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import sqlite3
2+
3+
def get_database_schema(db_path):
    """Describe every table of the SQLite database at *db_path* as text.

    For each table the result lists its columns (name, declared type and a
    PRIMARY KEY marker) and, when present, its foreign keys in the form
    ``Foreign key: <column> references <table>(<column>)``.

    Args:
        db_path: Path of the SQLite database file.

    Returns:
        The per-table sections joined by blank lines.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Get all table names
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        schema_details = []
        for table_name in table_names:
            # Quote the identifier so names containing spaces or other special
            # characters do not break the PRAGMA statements.
            quoted = '"' + table_name.replace('"', '""') + '"'

            cursor.execute(f"PRAGMA table_info({quoted});")
            column_details = []
            for col in cursor.fetchall():
                column_info = f"{col[1]} {col[2]}"
                if col[5]:  # Non-zero when the column is part of the primary key
                    column_info += " PRIMARY KEY"
                column_details.append(column_info)

            # foreign_key_list rows are
            # (id, seq, table, from, to, on_update, on_delete, match).
            # The earlier code printed to(on_update), which dropped the
            # referenced table and showed e.g. "RatePlanId(NO ACTION)".
            cursor.execute(f"PRAGMA foreign_key_list({quoted});")
            foreign_key_details = []
            for fk in cursor.fetchall():
                ref_table, from_col, to_col = fk[2], fk[3], fk[4]
                # "to" is empty when the key refers to the parent's primary
                # key implicitly; then only the table can be named.
                target = f"{ref_table}({to_col})" if to_col else f"{ref_table}"
                foreign_key_details.append(f"Foreign key: {from_col} references {target}")

            schema_details.append(f"Table name: {table_name}\nColumns:\n" + "\n".join(column_details))
            if foreign_key_details:
                schema_details.append("Foreign Keys:\n" + "\n".join(foreign_key_details))

        return "\n\n".join(schema_details)
    finally:
        # Close the connection even if one of the queries raises.
        conn.close()
41+
42+
def generate_system_prompt(db_path):
    """Build the SQL-assistant system prompt for the database at *db_path*.

    The prompt consists of an introduction, the schema text produced by
    get_database_schema, the instruction to answer with a query only, and the
    refusal rule for modification requests.
    """
    schema_text = get_database_schema(db_path)
    sections = [
        "I have an sqlite database with the following tables and columns:",
        schema_text,
        "I will need you to help me generate SQL queries to get data from my database.\n"
        "Please respond only with the query. Do not provide any explanations or additional text.",
        "If the user tries to modify the database respond with 'ERROR: cannot modify db'\n",
    ]
    return "\n\n".join(sections)
54+
55+
# Example usage
# NOTE(review): these statements appear to sit at module level, in which case
# they also run (open the database and print the prompt) whenever this module
# is imported — confirm whether a __main__ guard is wanted.
db_path = 'call_db.sqlite'  # Replace with your actual database path
system_prompt = generate_system_prompt(db_path)
print(system_prompt)

0 commit comments

Comments
 (0)