Skip to content

Commit f082c04

Browse files
committed
feat(rag): implement Chain of Thought prompting in RAG agents - Add use_cot parameter - Enhance prompt templates - Support CoT in both agents
1 parent 094c8c4 commit f082c04

File tree

2 files changed

+54
-18
lines changed

2 files changed

+54
-18
lines changed

agentic_rag/local_rag_agent.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@ class QueryAnalysis(BaseModel):
2929
)
3030

3131
class LocalRAGAgent:
32-
def __init__(self, vector_store: VectorStore, model_name: str = "mistralai/Mistral-7B-Instruct-v0.2"):
32+
def __init__(self, vector_store: VectorStore, model_name: str = "mistralai/Mistral-7B-Instruct-v0.2", use_cot: bool = False):
3333
"""Initialize local RAG agent with vector store and local LLM"""
3434
self.vector_store = vector_store
35+
self.use_cot = use_cot
3536

3637
# Load HuggingFace token from config
3738
try:
@@ -120,7 +121,15 @@ def _generate_direct_response(self, query: str) -> Dict[str, Any]:
120121
"""Generate a response directly from the LLM without context"""
121122
logger.info("Generating direct response from LLM without context...")
122123

123-
prompt = f"""You are a helpful AI assistant. Please answer the following query to the best of your ability.
124+
if self.use_cot:
125+
prompt = f"""You are a helpful AI assistant. Please answer the following query using chain of thought reasoning.
126+
First break down the problem into steps, then solve each step to arrive at the final answer.
127+
128+
Query: {query}
129+
130+
Let's think about this step by step:"""
131+
else:
132+
prompt = f"""You are a helpful AI assistant. Please answer the following query to the best of your ability.
124133
If you're not confident about the answer, please say so.
125134
126135
Query: {query}
@@ -199,7 +208,19 @@ def _generate_response(self, query: str, context: List[Dict[str, Any]]) -> Dict[
199208
for i, item in enumerate(context)])
200209

201210
logger.info("Building prompt with context...")
202-
prompt = f"""Answer the following query using the provided context.
211+
if self.use_cot:
212+
prompt = f"""Answer the following query using the provided context and chain of thought reasoning.
213+
First break down the problem into steps, then use the context to solve each step and arrive at the final answer.
214+
If the context doesn't contain enough information to answer accurately, say so explicitly.
215+
216+
Context:
217+
{context_str}
218+
219+
Query: {query}
220+
221+
Let's think about this step by step:"""
222+
else:
223+
prompt = f"""Answer the following query using the provided context.
203224
If the context doesn't contain enough information to answer accurately,
204225
say so explicitly.
205226
@@ -225,6 +246,7 @@ def main():
225246
parser.add_argument("--store-path", default="embeddings", help="Path to the vector store")
226247
parser.add_argument("--model", default="mistralai/Mistral-7B-Instruct-v0.2", help="Model to use")
227248
parser.add_argument("--quiet", action="store_true", help="Disable verbose logging")
249+
parser.add_argument("--use-cot", action="store_true", help="Enable Chain of Thought reasoning")
228250

229251
args = parser.parse_args()
230252

@@ -241,7 +263,7 @@ def main():
241263
logger.info(f"Initializing vector store from: {args.store_path}")
242264
store = VectorStore(persist_directory=args.store_path)
243265
logger.info("Initializing local RAG agent...")
244-
agent = LocalRAGAgent(store, model_name=args.model)
266+
agent = LocalRAGAgent(store, model_name=args.model, use_cot=args.use_cot)
245267

246268
print(f"\nProcessing query: {args.query}")
247269
print("=" * 50)

agentic_rag/rag_agent.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ class QueryAnalysis(BaseModel):
2121
)
2222

2323
class RAGAgent:
24-
def __init__(self, vector_store: VectorStore, openai_api_key: str):
24+
def __init__(self, vector_store: VectorStore, openai_api_key: str, use_cot: bool = False):
2525
"""Initialize RAG agent with vector store and LLM"""
2626
self.vector_store = vector_store
27+
self.use_cot = use_cot
2728
self.llm = ChatOpenAI(
2829
model="gpt-4-turbo-preview",
2930
temperature=0,
@@ -94,18 +95,30 @@ def _generate_response(self, query: str, context: List[Dict[str, Any]]) -> Dict[
9495
context_str = "\n\n".join([f"Context {i+1}:\n{item['content']}"
9596
for i, item in enumerate(context)])
9697

97-
prompt = ChatPromptTemplate.from_template(
98-
"""Answer the following query using the provided context.
99-
If the context doesn't contain enough information to answer accurately,
100-
say so explicitly.
101-
102-
Context:
103-
{context}
104-
105-
Query: {query}
106-
107-
Answer:"""
108-
)
98+
if self.use_cot:
99+
template = """Answer the following query using the provided context and chain of thought reasoning.
100+
First break down the problem into steps, then use the context to solve each step and arrive at the final answer.
101+
If the context doesn't contain enough information to answer accurately, say so explicitly.
102+
103+
Context:
104+
{context}
105+
106+
Query: {query}
107+
108+
Let's think about this step by step:"""
109+
else:
110+
template = """Answer the following query using the provided context.
111+
If the context doesn't contain enough information to answer accurately,
112+
say so explicitly.
113+
114+
Context:
115+
{context}
116+
117+
Query: {query}
118+
119+
Answer:"""
120+
121+
prompt = ChatPromptTemplate.from_template(template)
109122

110123
messages = prompt.format_messages(context=context_str, query=query)
111124
response = self.llm.invoke(messages)
@@ -119,6 +132,7 @@ def main():
119132
parser = argparse.ArgumentParser(description="Query documents using OpenAI GPT-4")
120133
parser.add_argument("--query", required=True, help="Query to process")
121134
parser.add_argument("--store-path", default="chroma_db", help="Path to the vector store")
135+
parser.add_argument("--use-cot", action="store_true", help="Enable Chain of Thought reasoning")
122136

123137
args = parser.parse_args()
124138

@@ -135,7 +149,7 @@ def main():
135149

136150
try:
137151
store = VectorStore(persist_directory=args.store_path)
138-
agent = RAGAgent(store, openai_api_key=os.getenv("OPENAI_API_KEY"))
152+
agent = RAGAgent(store, openai_api_key=os.getenv("OPENAI_API_KEY"), use_cot=args.use_cot)
139153

140154
print(f"\nProcessing query: {args.query}")
141155
print("=" * 50)

0 commit comments

Comments
 (0)