-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathingestion.py
More file actions
177 lines (129 loc) · 7.32 KB
/
ingestion.py
File metadata and controls
177 lines (129 loc) · 7.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import os
import time
import argparse
from langchain_groq import ChatGroq
from dotenv import load_dotenv
from langchain_core.prompts import PromptTemplate
from util import ( encode_pdf, show_context ,
build_knowledge_graph,rerank_documents, find_node_by_content,
expand_context_via_graph,visualize_graph)
load_dotenv()

# Fail fast with a clear error when the key is absent. The previous
# `os.environ[...] = os.getenv(...)` raised an opaque TypeError if
# GROQ_API_KEY was unset, because os.environ values must be strings.
_groq_api_key = os.getenv("GROQ_API_KEY")
if _groq_api_key is None:
    raise EnvironmentError(
        "GROQ_API_KEY is not set; add it to your environment or .env file."
    )
os.environ["GROQ_API_KEY"] = _groq_api_key
class GraphRAG:
    """Graph-augmented RAG pipeline over a single PDF.

    Ingestion builds (or loads) a FAISS vector index and a knowledge graph;
    ``run`` answers a query via query rewriting, vector retrieval,
    cross-encoder reranking, graph expansion, and LLM answer generation.
    Per-stage wall-clock timings are accumulated in ``self.time_records``.
    """

    def __init__(self, path, chunk_size=1000, chunk_overlap=200, n_retrieved=10, force_rebuild=False):
        """Ingest *path* and prepare the retriever.

        Args:
            path: Path to the PDF file to encode.
            chunk_size: Characters per text chunk.
            chunk_overlap: Overlap between consecutive chunks.
            n_retrieved: Number of chunks the FAISS retriever returns per query.
            force_rebuild: Rebuild persisted indexes even if they already exist.
        """
        print("INGESTION PHASE")
        self.llm = ChatGroq(
            model="llama-3.3-70b-versatile",
            temperature=0
        )

        # Stage 1: chunk the PDF and build/load the FAISS index.
        start_time = time.time()
        self.vector_store, self.splits, self.embedding_model = encode_pdf(
            path, chunk_size=chunk_size, chunk_overlap=chunk_overlap, force_rebuild=force_rebuild
        )
        self.time_records = {'FAISS Indexing': time.time() - start_time}
        print(f"FAISS time: {self.time_records['FAISS Indexing']:.2f} seconds")

        # Stage 2: build/load the knowledge graph from the same chunks.
        start_time = time.time()
        self.knowledge_graph = build_knowledge_graph(
            self.splits, self.llm, self.embedding_model, force_rebuild=force_rebuild
        )
        self.time_records['graph building'] = time.time() - start_time
        print(f" Graph time: {self.time_records['graph building']:.2f}s")

        self.n_retrieved = n_retrieved
        self.chunks_query_retriever = self.vector_store.as_retriever(search_kwargs={"k": n_retrieved})
        print("\nIngestion complete. Ready for queries")

    def run(self, query):
        """Answer *query* end-to-end, printing the result and per-stage timings."""
        # [1] Rewrite the query to improve retrieval recall.
        print("\n[1] Query Rewriting")
        start_time = time.time()
        query_rewrite_template = """You are an AI assistant tasked with reformulating user queries to improve retrieval in a RAG system.
Given the original query, rewrite it to be more specific, detailed, and likely to retrieve relevant information.Dont give anything else except the rewritten query
Original query: {query}
Rewritten query:"""
        query_rewrite_prompt = PromptTemplate(
            input_variables=["query"],
            template=query_rewrite_template
        )
        query_rewriter = query_rewrite_prompt | self.llm
        changed_query = query_rewriter.invoke(query).content
        print(f"changed query :{changed_query}")
        self.time_records['query rewrite'] = time.time() - start_time
        print(f"query rewriting time:{self.time_records['query rewrite']:.2f}s")

        # [2] Dense retrieval from FAISS using the rewritten query.
        print(f"\n[2] vector retrieval (top {self.n_retrieved} from FAISS)")
        start_time = time.time()
        retrieved_docs = self.chunks_query_retriever.invoke(changed_query)
        print(f"Retrieved {len(retrieved_docs)} chunks")
        self.time_records['vector retrieval'] = time.time() - start_time
        print(f"vector retrieval time:{self.time_records['vector retrieval']:.2f}s")

        # [3] Cross-encoder reranking; keep at most 5 (or fewer if retrieval
        # returned fewer).
        n_rerank = min(5, len(retrieved_docs))
        print(f"\n[3] cross-encoder reranking (top {n_rerank})")
        start_time = time.time()
        ranked_results = rerank_documents(changed_query, retrieved_docs, n_retrieved=n_rerank)
        self.time_records['reranking'] = time.time() - start_time
        print(f"reranking time:{self.time_records['reranking']:.2f}s")

        # [4] Map reranked chunks back to graph nodes and expand context via
        # graph traversal; fall back to reranked docs if no node matches.
        print(f"\n[4] graph expansion (dijkstra traversal)")
        start_time = time.time()
        seed_nodes = []
        for doc, score in ranked_results:
            node_idx = find_node_by_content(self.knowledge_graph.graph, doc.page_content)
            if node_idx is not None:
                seed_nodes.append((node_idx, score))
        if seed_nodes:
            context_texts, traversal_path = expand_context_via_graph(
                self.knowledge_graph, seed_nodes, max_nodes=8
            )
            if traversal_path:
                visualize_graph(self.knowledge_graph, traversal_path)
        else:
            # Fixed typo in user-facing message: "reraked" -> "reranked".
            print("no graph nodes matched. Using reranked docs directly")
            context_texts = [doc.page_content for doc, _ in ranked_results]
            traversal_path = []
        self.time_records['graph expansion'] = time.time() - start_time
        print(f"graph expansion time:{self.time_records['graph expansion']:.2f}s")

        # print(f"\n[5] final context({len(context_texts)} chunks)")
        # show_context(context_texts)

        # [6] Generate the final answer from the assembled context.
        print(f"\n[6] answer generation")
        start_time = time.time()
        context_text = "\n\n".join(context_texts)
        prompt = f"Based on the following context, answer the question.\n\nContext:\n{context_text}\n\nQuestion: {changed_query}\n\nAnswer:"
        response = self.llm.invoke(prompt).content
        self.time_records['answer generation'] = time.time() - start_time
        print(f"answer generation time:{self.time_records['answer generation']:.2f}s")
        print(f"\nAnswer: {response}")
def validate_args(args):
    """Validate parsed CLI arguments.

    Args:
        args: An ``argparse.Namespace`` with ``chunk_size``,
            ``chunk_overlap``, and ``n_retrieved`` attributes.

    Returns:
        The same namespace, unchanged, when all checks pass.

    Raises:
        ValueError: If any argument is out of its valid range.
    """
    if args.chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")
    if args.chunk_overlap < 0:
        raise ValueError("chunk_overlap must be a non-negative integer.")
    # An overlap >= chunk_size is nonsensical for a text splitter and would
    # only fail (confusingly) later during ingestion — reject it up front.
    if args.chunk_overlap >= args.chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size.")
    if args.n_retrieved <= 0:
        raise ValueError("n_retrieved must be a positive integer.")
    return args
# argparse is a module and ArgumentParser is a class defined in it, so
# `parser` below is an instance of that class.
# add_argument is a method of the ArgumentParser class.
def parse_args():
    """Build the CLI parser, parse ``sys.argv``, and return validated args.

    Returns:
        The validated ``argparse.Namespace`` (see ``validate_args``).
    """
    parser = argparse.ArgumentParser(description="RAG with knowledge graph")
    parser.add_argument("--path", type=str, default="",
                        help="path to the pdf file to encode")
    parser.add_argument("--chunk_size", type=int, default=1000,
                        help="size of each text chunk (default: 1000).")
    # Fixed typo in help text: "chuncks" -> "chunks".
    parser.add_argument("--chunk_overlap", type=int, default=200,
                        help="overlap between consecutive chunks (default: 200).")
    # Help text previously claimed "default: 2" while the actual default is 10.
    parser.add_argument("--n_retrieved", type=int, default=10,
                        help="number of chunks to retrieve for each query (default: 10).")
    parser.add_argument("--query", type=str,
                        default="what is the main cause of climate change?",
                        help="query to test the retriever (default: 'what is the main cause of climate change?').")
    parser.add_argument("--evaluate", action="store_true",
                        help="whether to evaluate the retriever's performance (default: false)")
    parser.add_argument("--rebuild", action="store_true",
                        help="Force rebuild indexes even if persisted ones exist")
    return validate_args(parser.parse_args())
# Entry flow: main(parse_args()) first parses and validates the CLI
# arguments, then main() constructs a GraphRAG instance and calls its
# run() method with the user's query.
def main(args):
    """Construct the GraphRAG pipeline from CLI args and answer the query."""
    if args.rebuild:
        print("Force rebuild enabled. Will rebuild indexes from scratch.")
    rag_pipeline = GraphRAG(
        path=args.path,
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        n_retrieved=args.n_retrieved,
        force_rebuild=args.rebuild,
    )
    rag_pipeline.run(args.query)
    # Evaluation hook (currently disabled):
    # if args.evaluate:
    #     evaluate_rag(rag_pipeline.chunks_query_retriever)
# Script entry point: parse/validate CLI arguments, then run the pipeline.
if __name__ == '__main__':
    main(parse_args())