Skip to content

Commit 62ab51c

Browse files
fix: thought_action tweak + cota_engine test cleanup
1 parent f723580 commit 62ab51c

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

tests/test_cota_engine.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import sys
44
import tempfile
55
import shutil
6-
6+
import json
77
# Add the project root to Python path
88
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
99

1010
from cotarag.cota_engine.cota_engines import CoTAEngine
11+
from cotarag.accelerag.query_engines import AnthropicEngine,OpenAIEngine
1112
from cotarag.cota_engine.thought_actions import LLMThoughtAction
1213

1314
class TestCoTAEngine(unittest.TestCase):
@@ -57,13 +58,18 @@ def test_llm_thought_action_chain(self):
5758
class EvaluateCodebaseAction(LLMThoughtAction):
5859
def action(self, thought_output):
5960
# Write the summary to codebase_summary
60-
with open('codebase_summary', 'w') as f:
61-
f.write(thought_output)
62-
return thought_output
61+
f = open('codebase_summary','w')
62+
f.write(thought_output['thought'])
63+
f.close()
64+
return {'action': 'wrote summary'}
6365

6466
# Reasoning: Run the first step to get the summary
65-
eval_action = EvaluateCodebaseAction(api_key=self.api_key)
66-
summary = eval_action.thought(code_prompt['action'])
67+
if os.environ['CLAUDE_API_KEY'] is not None:
68+
query_engine = AnthropicEngine(api_key = os.environ['CLAUDE_API_KEY'])
69+
else:
70+
query_engine = OpenAIEngine(api_key = os.environ['OPENAI_API_KEY'])
71+
eval_action = EvaluateCodebaseAction(query_engine = query_engine)
72+
summary = eval_action.thought(code_prompt)
6773
eval_action.action(summary)
6874

6975
# Reasoning: Read the summary from the file for the next step
@@ -78,11 +84,11 @@ class AnalyzeCodebaseAction(LLMThoughtAction):
7884
def action(self, thought_output):
7985
# Write the improvements to codebase_TODOs
8086
with open('codebase_TODOs', 'w') as f:
81-
f.write(thought_output)
87+
f.write(thought_output['thought'])
8288
return thought_output
8389

8490
# Reasoning: Run the second step to get the improvements
85-
analyze_action = AnalyzeCodebaseAction(api_key=self.api_key)
91+
analyze_action = AnalyzeCodebaseAction(query_engine = query_engine)
8692
improvements = analyze_action.thought(improvement_prompt)
8793
analyze_action.action(improvements)
8894

@@ -107,7 +113,7 @@ def action(self, thought_output):
107113

108114
# Reasoning: Print the ASCII chain of thought-action steps using __str__ methods
109115
print("ASCII Chain of Thought-Action (CoTA) Steps:")
110-
print(f"input -> {EvaluateCodebaseAction().__str__()} -> {AnalyzeCodebaseAction().__str__()}")
116+
print(f"input -> {eval_action.__str__()} -> {analyze_action.__str__()}")
111117

112118
if __name__ == '__main__':
113119
unittest.main()

0 commit comments

Comments
 (0)