-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathRAG.py
More file actions
123 lines (102 loc) · 4.36 KB
/
RAG.py
File metadata and controls
123 lines (102 loc) · 4.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from pathlib import Path
from sentence_transformers import SentenceTransformer
import faiss
import pickle
import fitz # PyMuPDF
def load_text_files(folder_path):
    """Recursively read every supported document under *folder_path*.

    Supported types (suffix compared case-insensitively):
      * ``.txt`` / ``.md`` -- read as UTF-8, undecodable bytes are dropped.
      * ``.pdf`` -- text of every page extracted with PyMuPDF and concatenated.
    Any other file is skipped.

    Args:
        folder_path: Root directory to search (str or Path).

    Returns:
        A list of ``(path_string, text)`` tuples, one per file, in the order
        ``Path.rglob`` yields them.

    Raises:
        Whatever ``open`` / ``fitz.open`` raise for an unreadable file; a single
        bad file aborts the whole scan (no per-file error handling here).
    """
    text_data = []
    for file_path in Path(folder_path).rglob("*"):
        # rglob also yields directories; a directory named e.g. "notes.md"
        # would otherwise reach open() and raise IsADirectoryError.
        if not file_path.is_file():
            continue
        suffix = file_path.suffix.lower()
        if suffix in ('.txt', '.md'):
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                text_data.append((str(file_path), f.read()))
        elif suffix == '.pdf':
            with fitz.open(file_path) as doc:
                content = "".join(page.get_text() for page in doc)
            text_data.append((str(file_path), content))
        # TODO: support .png images of text (receipts, etc.), e.g. via OCR;
        # retrieve_relevant_chunks must then be taught the same file type.
    return text_data
def chunk_text(text, chunk_size=500, overlap=50):
    """Split *text* into overlapping windows of whitespace-separated words.

    Consecutive chunks start ``chunk_size - overlap`` words apart, so each one
    repeats the last ``overlap`` words of its predecessor. Chunking stops as
    soon as a window reaches the end of the text, so no trailing chunk is
    produced that is entirely contained in the previous one.

    Args:
        text: Input string; split on any whitespace (runs of whitespace,
            including newlines, become single spaces in the output).
        chunk_size: Maximum number of words per chunk; must be positive.
        overlap: Words shared by consecutive chunks; must be < chunk_size
            (a negative value leaves gaps between chunks).

    Returns:
        List of chunk strings; empty if *text* contains no words.

    Raises:
        ValueError: If ``chunk_size <= 0`` or ``overlap >= chunk_size``, which
            would make the window step non-positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )
    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))
        # This window already covers the final word; a later start would only
        # emit words this chunk already contains.
        if end >= len(words):
            break
    return chunks
# Local SentenceTransformer model directory. The path is relative, so it is
# resolved against the current working directory, not this file's location.
# Original note: "." prefix when running from source, ".." for the build —
# presumably this line is edited by hand for packaged builds; confirm.
model_path = "./models/all-MiniLM-L6-v2" # . for source and .. for build
# Module-level model, loaded as a side effect of importing this module;
# embed_chunks encodes with this instance.
embedder = SentenceTransformer(model_path)
def embed_chunks(chunks):
    """Embed a batch of text chunks using the module-level ``embedder``.

    Args:
        chunks: List of strings to encode.

    Returns:
        The embeddings as a NumPy array (requested via ``convert_to_numpy``).
    """
    vectors = embedder.encode(chunks, convert_to_numpy=True)
    return vectors
def build_faiss_index(embeddings, chunks, metadata, *,
                      index_path="rag_index.faiss",
                      metadata_path="rag_metadata.pkl"):
    """Build an exact (flat) L2 FAISS index over *embeddings* and save it.

    Args:
        embeddings: 2-D array of chunk embeddings; ``embeddings.shape[1]`` is
            the vector dimension, and row ``i`` must describe ``metadata[i]``.
        chunks: The chunk texts. Unused here and not persisted: retrieval
            re-reads and re-chunks the source file instead, which keeps the
            metadata file small. Kept for interface compatibility.
        metadata: Sequence of ``(file_path, chunk_index)`` pairs, one per row.
        index_path: Destination of the serialized FAISS index.
        metadata_path: Destination of the pickled ``{'metadata': metadata}``.

    Raises:
        ValueError: If the number of embedding rows differs from the number of
            metadata entries (search hits would then map to the wrong chunks).
    """
    if embeddings.shape[0] != len(metadata):
        raise ValueError(
            f"embeddings has {embeddings.shape[0]} rows but metadata has "
            f"{len(metadata)} entries"
        )
    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)
    # faiss.write_index expects a str filename, so accept Path objects too.
    faiss.write_index(index, str(index_path))
    # Only save metadata with file positions, not full chunks.
    with open(metadata_path, "wb") as f:
        pickle.dump({'metadata': metadata}, f)
def build_embeddings(folder_path):
    """Ingest every supported file under *folder_path* into the FAISS index.

    Files are loaded, split into word chunks, embedded, and written to disk
    along with ``(file_path, chunk_index)`` metadata for each chunk.

    Args:
        folder_path: Root directory to ingest.

    Returns:
        True if an index was built; False (after printing a message) when no
        files were found or the files produced no chunks.
    """
    documents = load_text_files(folder_path)
    if not documents:
        print("No text files found!")
        return False

    chunk_list, position_list = [], []
    for source, body in documents:
        for position, piece in enumerate(chunk_text(body)):
            chunk_list.append(piece)
            position_list.append((source, position))

    if not chunk_list:
        print("No chunks created from files!")
        return False

    vectors = embed_chunks(chunk_list)
    build_faiss_index(vectors, chunk_list, position_list)
    return True
def _read_file_chunks(file_path):
    """Re-read *file_path* and return its chunk_text() chunks ([] if unsupported)."""
    # Lower-cased suffix, matching how load_text_files decides what to index.
    suffix = Path(file_path).suffix.lower()
    if suffix in ('.txt', '.md'):
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()
    elif suffix == '.pdf':
        with fitz.open(file_path) as doc:
            content = "".join(page.get_text() for page in doc)
    else:
        # TODO: Add the png support here
        return []
    return chunk_text(content)


def retrieve_relevant_chunks(query, embedder, index, data, top_k=2):
    """Return the text of the *top_k* chunks nearest to *query*.

    Chunk text is not stored in the index metadata, so each hit's source file
    is re-read and re-chunked; this only reproduces the indexed chunk if the
    file is unchanged and chunk_text uses the same parameters as at build time.

    Args:
        query: Question string to embed.
        embedder: Object whose ``encode([query])`` returns the query embedding
            (the same model used to build the index).
        index: FAISS index built over the chunk embeddings.
        data: Unpickled metadata dict; ``data['metadata'][i]`` is the
            ``(file_path, chunk_index)`` pair for index row ``i``.
        top_k: Number of neighbours to request.

    Returns:
        List of chunk strings in ranking order. Hits whose file cannot be read
        (an error message is printed), has an unsupported type, or no longer
        has that many chunks are skipped, so fewer than *top_k* may be returned.
    """
    query_embedding = embedder.encode([query])
    _, neighbor_ids = index.search(query_embedding, top_k)
    metadata = data['metadata']

    chunks = []
    # Several hits often come from the same file; read/chunk each file once.
    file_chunk_cache = {}
    for i in neighbor_ids[0]:
        # FAISS pads with -1 when the index holds fewer than top_k vectors;
        # metadata[-1] would silently return the last chunk instead.
        if i < 0:
            continue
        file_path, chunk_idx = metadata[i]
        if file_path not in file_chunk_cache:
            try:
                file_chunk_cache[file_path] = _read_file_chunks(file_path)
            except Exception as e:
                # Best effort: report and skip unreadable files (once per file).
                print(f"Error reading {file_path}: {e}")
                file_chunk_cache[file_path] = []
        file_chunks = file_chunk_cache[file_path]
        if chunk_idx < len(file_chunks):
            chunks.append(file_chunks[chunk_idx])
    return chunks
def build_prompt(retrieved_chunks, user_question):
    """Combine retrieved context chunks and the user's question into one prompt.

    Chunks are joined with blank lines and placed before the question, with a
    note that the context is optional background.

    Args:
        retrieved_chunks: Iterable of chunk strings (may be empty).
        user_question: Question text, inserted verbatim.

    Returns:
        The prompt string.
    """
    context = "\n\n".join(retrieved_chunks)
    return (
        f"Here is some context about the user:\n\n{context}. "
        "You may not need to use this information to answer the question; "
        "it's just to provide more context.\n\n"
        f"Here is the question: {user_question}"
    )
def load_faiss_index_and_metadata(*, index_path="rag_index.faiss",
                                  metadata_path="rag_metadata.pkl"):
    """Load the FAISS index and the pickled chunk metadata from disk.

    Args:
        index_path: Path of the serialized FAISS index.
        metadata_path: Path of the pickled metadata (a dict with a
            ``'metadata'`` key when written by build_faiss_index).

    Returns:
        ``(index, metadata)`` on success, otherwise ``(None, None)``. Failures
        are reported with a printed message rather than raised.
    """
    try:
        index = faiss.read_index(str(index_path))
        # SECURITY: pickle.load can execute arbitrary code. This is acceptable
        # only because the file is written locally by this pipeline; never
        # point metadata_path at untrusted data.
        with open(metadata_path, "rb") as f:
            metadata = pickle.load(f)
        return index, metadata
    except RuntimeError:
        # faiss.read_index raises RuntimeError when it cannot read the index,
        # typically because nothing has been ingested yet.
        print("Add some user context")
        return None, None
    except Exception as e:
        print(f"Error loading FAISS index or metadata: {e}")
        return None, None