Skip to content

Commit 389753e

Browse files
fix: small change to accelerag __init__ module file, | fix: fixed import error in cota_engine thought_actions file | fix: fixed imports across tests
1 parent acfe12f commit 389753e

File tree

6 files changed

+31
-24
lines changed

6 files changed

+31
-24
lines changed

cotarag/accelerag/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@
7070

7171
__all__ = [
7272
"Embedder",
73-
"ClaudeEmbedder",
7473
"Indexer",
7574
"Retriever",
7675
"QueryEngine",

cotarag/cota_engine/thought_actions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from accelerag.query_engines.query_engines import AnthropicEngine
2+
from cotarag.accelerag.query_engines.query_engines import AnthropicEngine
33

44

55
# Update ThoughtAction to include __str__ and __repr__ methods

pyproject.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,3 @@ Issues = "https://github.com/Kernel-Dirichlet/CoTARAG/issues"
5252

5353
[tool.setuptools]
5454
packages = ["cotarag", "cotarag.cota_engine", "cotarag.accelerag"]
55-
56-
[tool.setuptools.package-data]
57-
cotarag = ["py.typed"]

tests/test_cota_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
# Add the project root to Python path
88
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
99

10-
from cota_engine.cota_engines import CoTAEngine
11-
from cota_engine.thought_actions import LLMThoughtAction
10+
from cotarag.cota_engine.cota_engines import CoTAEngine
11+
from cotarag.cota_engine.thought_actions import LLMThoughtAction
1212

1313
class TestCoTAEngine(unittest.TestCase):
1414
@classmethod

tests/test_rag_image.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212

1313
# Add the project root to Python path
1414
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
15-
from accelerag.managers import RAGManager
16-
from accelerag.embedders import ImageEmbedder
17-
from accelerag.indexers import ImageIndexer
18-
from accelerag.retrievers import ImageRetriever
15+
from cotarag.accelerag.managers import RAGManager
16+
from cotarag.accelerag.embedders import ImageEmbedder
17+
from cotarag.accelerag.indexers import ImageIndexer
18+
from cotarag.accelerag.retrievers import ImageRetriever
1919

2020

2121
def generate_digits_dataset(output_dir):
@@ -78,13 +78,13 @@ def setUpClass(cls):
7878
cls.rag = RAGManager(
7979
api_key=cls.api_key,
8080
dir_to_idx=cls.digits_dir,
81-
grounding='soft',
82-
enable_cache=False, # Disable caching
83-
use_cache=False, # Disable caching
84-
logging_enabled=True,
85-
force_reindex=True,
86-
hard_grounding_prompt='prompts/hard_grounding_prompt.txt',
87-
soft_grounding_prompt='prompts/soft_grounding_prompt.txt'
81+
grounding = 'soft',
82+
enable_cache = False, # Disable caching
83+
use_cache = False, # Disable caching
84+
logging_enabled = True,
85+
force_reindex = True,
86+
hard_grounding_prompt = 'prompts/hard_grounding_prompt.txt',
87+
soft_grounding_prompt = 'prompts/soft_grounding_prompt.txt'
8888
)
8989

9090
@classmethod

tests/test_rag_text.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
# Add the project root to Python path
99
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1010

11-
from accelerag.managers import RAGManager
12-
from accelerag.query_utils import create_tag_hierarchy
13-
from accelerag.query_engines.query_engines import OpenAIEngine, AnthropicEngine
11+
from cotarag.accelerag.managers import RAGManager
12+
from cotarag.accelerag.query_utils import create_tag_hierarchy
13+
from cotarag.accelerag.query_engines.query_engines import OpenAIEngine, AnthropicEngine
1414

1515
class BaseRAGTest(unittest.TestCase):
1616
"""Base class for RAG testing with common setup and test methods."""
@@ -22,7 +22,9 @@ def setUpClass(cls):
2222

2323
# Copy test data to temp directory
2424
cls.data_dir = os.path.join(cls.test_dir, 'test_data')
25-
test_data_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'accelerag/arxiv_mini')
25+
test_data_path = os.path.join(os.path.dirname(os.path.dirname(__file__)),
26+
'cotarag/accelerag/arxiv_mini')
27+
2628
shutil.copytree(test_data_path, cls.data_dir)
2729

2830
cls.db_path = os.path.join(cls.test_dir, 'test_embeddings.db.sqlite')
@@ -32,7 +34,10 @@ def setUpClass(cls):
3234

3335
# Get project root for constructing prompt file paths
3436
cls.project_root = os.path.dirname(os.path.dirname(__file__))
35-
cls.prompts_dir = os.path.join(cls.project_root,'accelerag','prompts')
37+
cls.prompts_dir = os.path.join(cls.project_root,
38+
'cotarag',
39+
'accelerag',
40+
'prompts')
3641

3742

3843
@classmethod
@@ -103,7 +108,7 @@ def _run_rag_flow_test(self, rag):
103108
print(f"Query: {similar_query}")
104109

105110
# Get chunks for similar query
106-
similar_chunks = rag.retrieve(similar_query, top_k=5)
111+
similar_chunks = rag.retrieve(similar_query, top_k = 5)
107112
similar_response = rag.generate_response(similar_query)
108113
print("\nResponse:")
109114
print("-"*50)
@@ -162,16 +167,19 @@ def setUpClass(cls):
162167
force_reindex = True,
163168
query_engine = OpenAIEngine(api_key = cls.api_key),
164169
hard_grounding_prompt = os.path.join(cls.project_root,
170+
'cotarag',
165171
'accelerag',
166172
'prompts',
167173
'hard_grounding_prompt.txt'),
168174

169175
soft_grounding_prompt=os.path.join(cls.project_root,
176+
'cotarag',
170177
'accelerag',
171178
'prompts',
172179
'soft_grounding_prompt.txt'),
173180

174181
template_path=os.path.join(cls.project_root,
182+
'cotarag',
175183
'accelerag',
176184
'web_rag_template.txt')
177185
)
@@ -216,16 +224,19 @@ def setUpClass(cls):
216224
force_reindex = True,
217225
query_engine = AnthropicEngine(api_key=cls.api_key),
218226
hard_grounding_prompt = os.path.join(cls.project_root,
227+
'cotarag',
219228
'accelerag',
220229
'prompts',
221230
'hard_grounding_prompt.txt'),
222231

223232
soft_grounding_prompt = os.path.join(cls.project_root,
233+
'cotarag',
224234
'accelerag',
225235
'prompts',
226236
'soft_grounding_prompt.txt'),
227237

228238
template_path = os.path.join(cls.project_root,
239+
'cotarag',
229240
'accelerag',
230241
'web_rag_template.txt')
231242
)

0 commit comments

Comments
 (0)