33import sys
44import tempfile
55import shutil
6-
6+ import json
77# Add the project root to Python path
88sys .path .append (os .path .dirname (os .path .dirname (os .path .abspath (__file__ ))))
99
1010from cotarag .cota_engine .cota_engines import CoTAEngine
11+ from cotarag .accelerag .query_engines import AnthropicEngine ,OpenAIEngine
1112from cotarag .cota_engine .thought_actions import LLMThoughtAction
1213
1314class TestCoTAEngine (unittest .TestCase ):
@@ -57,13 +58,18 @@ def test_llm_thought_action_chain(self):
5758 class EvaluateCodebaseAction (LLMThoughtAction ):
5859 def action (self , thought_output ):
5960 # Write the summary to codebase_summary
60- with open ('codebase_summary' , 'w' ) as f :
61- f .write (thought_output )
62- return thought_output
61+ f = open ('codebase_summary' ,'w' )
62+ f .write (thought_output ['thought' ])
63+ f .close ()
64+ return {'action' : 'wrote summary' }
6365
6466 # Reasoning: Run the first step to get the summary
65- eval_action = EvaluateCodebaseAction (api_key = self .api_key )
66- summary = eval_action .thought (code_prompt ['action' ])
67+ if os .environ ['CLAUDE_API_KEY' ] is not None :
68+ query_engine = AnthropicEngine (api_key = os .environ ['CLAUDE_API_KEY' ])
69+ else :
70+ query_engine = OpenAIEngine (api_key = os .environ ['OPENAI_API_KEY' ])
71+ eval_action = EvaluateCodebaseAction (query_engine = query_engine )
72+ summary = eval_action .thought (code_prompt )
6773 eval_action .action (summary )
6874
6975 # Reasoning: Read the summary from the file for the next step
@@ -78,11 +84,11 @@ class AnalyzeCodebaseAction(LLMThoughtAction):
7884 def action (self , thought_output ):
7985 # Write the improvements to codebase_TODOs
8086 with open ('codebase_TODOs' , 'w' ) as f :
81- f .write (thought_output )
87+ f .write (thought_output [ 'thought' ] )
8288 return thought_output
8389
8490 # Reasoning: Run the second step to get the improvements
85- analyze_action = AnalyzeCodebaseAction (api_key = self . api_key )
91+ analyze_action = AnalyzeCodebaseAction (query_engine = query_engine )
8692 improvements = analyze_action .thought (improvement_prompt )
8793 analyze_action .action (improvements )
8894
@@ -107,7 +113,7 @@ def action(self, thought_output):
107113
108114 # Reasoning: Print the ASCII chain of thought-action steps using __str__ methods
109115 print ("ASCII Chain of Thought-Action (CoTA) Steps:" )
110- print (f"input -> { EvaluateCodebaseAction () .__str__ ()} -> { AnalyzeCodebaseAction () .__str__ ()} " )
116+ print (f"input -> { eval_action .__str__ ()} -> { analyze_action .__str__ ()} " )
111117
112118if __name__ == '__main__' :
113119 unittest .main ()
0 commit comments