Skip to content

Commit c4c43eb

Browse files
committed
test(rag): add Chain of Thought comparison tests - Create test suite to compare standard vs CoT responses - Add diverse test cases for different reasoning types - Implement rich console output for clear comparison
1 parent e26061c commit c4c43eb

File tree

1 file changed

+105
-0
lines changed

1 file changed

+105
-0
lines changed

agentic_rag/tests/test_cot.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import sys
2+
import os
3+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4+
5+
from local_rag_agent import LocalRAGAgent
6+
from rag_agent import RAGAgent
7+
from store import VectorStore
8+
from dotenv import load_dotenv
9+
import yaml
10+
import argparse
11+
from rich.console import Console
12+
from rich.panel import Panel
13+
from rich.table import Table
14+
15+
# Shared Rich console through which every function in this script prints.
console = Console()
16+
17+
def load_config():
    """Load credentials from config.yaml and the environment (.env).

    Reads ``config.yaml`` relative to the current working directory and then
    calls ``load_dotenv()`` so values from a ``.env`` file are visible through
    ``os.getenv``.

    Returns:
        dict: ``'hf_token'`` holds the ``HUGGING_FACE_HUB_TOKEN`` entry of
        config.yaml and ``'openai_key'`` holds the ``OPENAI_API_KEY``
        environment variable. Either value is ``None`` when it is not set.

    Exits the process with status 1, after printing the error, when the file
    cannot be read or parsed or when its top level is not a mapping.
    """
    # Imported locally so this function brings its own dependency.
    from rich.markup import escape

    try:
        with open('config.yaml', 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        # safe_load returns None for an empty document; treat that as
        # "no settings" so the caller can report the missing token itself.
        if config is None:
            config = {}
        if not isinstance(config, dict):
            raise ValueError("config.yaml must contain a top-level mapping")
        load_dotenv()
        return {
            'hf_token': config.get('HUGGING_FACE_HUB_TOKEN'),
            'openai_key': os.getenv('OPENAI_API_KEY')
        }
    except Exception as e:
        # Exception text is arbitrary; escape it so bracketed fragments are
        # printed literally instead of being parsed as Rich markup.
        console.print(f"[red]Error loading configuration: {escape(str(e))}")
        sys.exit(1)
30+
31+
def compare_responses(agent, query: str, description: str):
    """Run one query with and without Chain of Thought and print both answers.

    The query is sent twice through ``agent.process_query``: first with
    ``agent.use_cot`` set to ``False``, then with it set to ``True``. Each
    answer is printed in its own panel right after it is produced.

    Args:
        agent: Object exposing a writable ``use_cot`` attribute and a
            ``process_query(query)`` method whose result is indexed with
            ``"answer"``.
        query: The question sent to the agent.
        description: Short label printed as the test case heading.

    Note:
        ``agent.use_cot`` is overwritten and is left as ``True`` after a
        successful call.
    """
    # Imported locally so this function brings its own dependency.
    from rich.markup import escape

    # Caller-supplied and model-generated text is escaped so that bracketed
    # fragments (for example a stray "[/INST]" in model output) are shown
    # literally instead of being parsed as Rich markup, which can raise.
    console.print(f"\n[bold cyan]Test Case: {escape(description)}")
    console.print(Panel(f"Query: {escape(query)}", style="yellow"))

    # (use_cot flag, heading inside the panel, panel title, panel style)
    runs = (
        (False, "Standard Response:", "Without Chain of Thought", "blue"),
        (True, "Chain of Thought Response:", "With Chain of Thought", "green"),
    )
    for use_cot, heading, title, style in runs:
        agent.use_cot = use_cot
        response = agent.process_query(query)
        answer = escape(str(response["answer"]))
        console.print(Panel(
            f"[bold]{heading}[/bold]\n{answer}",
            title=title,
            style=style
        ))
53+
54+
def main():
    """Run the standard-vs-Chain-of-Thought comparison from the command line.

    Parses ``--model`` (``local`` or ``openai``, default ``local``), loads
    credentials with ``load_config()``, builds the matching agent on top of a
    ``VectorStore`` persisted in ``chroma_db``, and passes every built-in test
    case to ``compare_responses()``.

    Exits with status 1 when the credential required by the selected model is
    missing. A failure in one test case is reported and the remaining cases
    still run.
    """
    # Imported locally so this function brings its own dependency.
    from rich.markup import escape

    parser = argparse.ArgumentParser(description="Compare standard vs Chain of Thought prompting")
    parser.add_argument("--model", choices=['local', 'openai'], default='local',
                        help="Choose between local Mistral model or OpenAI")
    args = parser.parse_args()

    config = load_config()
    store = VectorStore(persist_directory="chroma_db")

    # Initialize appropriate agent
    if args.model == 'local':
        # The token is only checked for presence here; it is not passed on
        # to LocalRAGAgent.
        if not config['hf_token']:
            console.print("[red]Error: HuggingFace token not found in config.yaml")
            sys.exit(1)
        agent = LocalRAGAgent(store)
        model_name = "Mistral-7B"
    else:
        if not config['openai_key']:
            console.print("[red]Error: OpenAI API key not found in .env")
            sys.exit(1)
        agent = RAGAgent(store, openai_api_key=config['openai_key'])
        model_name = "GPT-4"

    # model_name is only the label shown in the header below.
    console.print(f"\n[bold]Testing {model_name} Responses[/bold]")
    console.print("=" * 80)

    # Test cases that highlight CoT benefits
    test_cases = [
        {
            "query": "A train travels at 60 mph for 2.5 hours, then at 45 mph for 1.5 hours. What's the total distance covered?",
            "description": "Multi-step math problem"
        },
        {
            "query": "Compare and contrast REST and GraphQL APIs, considering their strengths and use cases.",
            "description": "Complex comparison requiring structured analysis"
        },
        {
            "query": "If a tree falls in a forest and no one is around to hear it, does it make a sound? Explain your reasoning.",
            "description": "Philosophical question requiring detailed reasoning"
        }
    ]

    for test_case in test_cases:
        try:
            compare_responses(agent, test_case["query"], test_case["description"])
        except Exception as e:
            # Exception text is arbitrary and may contain bracketed fragments;
            # unescaped, Rich would parse them as markup and could raise from
            # inside this handler, aborting the remaining test cases.
            console.print(f"[red]Error in test case '{test_case['description']}': {escape(str(e))}")

    console.print("\n[bold green]Testing complete!")
103+
104+
# Entry point: run the comparison only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)