11import os
2+ from lib .utils import CodeGenSandbox
23import typer
34
45import langroid as lr
56from langroid .utils .configuration import set_global , Settings
67from langroid .utils .logging import setup_colored_logging
78
8- from TestRunner .GenericTestRunner import GenericTestRunner , InlineTestRunner , SubProcessTestRunner
9+ from TestRunner .GenericTestRunner import GenericTestRunner , SubProcessTestRunner
910
1011app = typer .Typer ()
1112setup_colored_logging ()
1213
1314
14- def generate_first_attempt (class_skeleton : str ) -> None :
15- with open (class_skeleton , "r" ) as f :
15+ def generate_first_attempt (sandbox : CodeGenSandbox ) -> None :
16+ with open (sandbox . get_sandboxed_class_path () , "r" ) as f :
1617 class_skeleton = f .read ()
1718
1819 cfg = lr .ChatAgentConfig (
@@ -30,11 +31,11 @@ def generate_first_attempt(class_skeleton: str) -> None:
3031 f"Do not say 'here is the python code'"
3132 f"Your output MUST be valid, runnable python code and NOTHING else."
3233 f"{ class_skeleton } " )
33- with open (os . path . join ( "." , "generated" , "test_class.py" ), "w+" ) as _out :
34+ with open (sandbox . get_sandboxed_class_path ( ), "w+" ) as _out :
3435 _out .write (response .content )
3536
3637
37- def generate_next_attempt (test_results : str , test_results_insights : str ) -> None :
38+ def generate_next_attempt (sandbox : CodeGenSandbox , test_results : str , test_results_insights : str ) -> None :
3839 cfg = lr .ChatAgentConfig (
3940 llm = lr .language_models .OpenAIGPTConfig (
4041 chat_model = "ollama/llama3:latest" ,
@@ -43,7 +44,7 @@ def generate_next_attempt(test_results: str, test_results_insights: str) -> None
4344 vecdb = None
4445 )
4546 agent = lr .ChatAgent (cfg )
46- with open (os . path . join ( "." , "generated" , "test_class.py" ), "r" ) as f :
47+ with open (sandbox . get_sandboxed_class_path ( ), "r" ) as f :
4748 code_snippet = f .read ()
4849
4950 prompt = f"""
@@ -65,7 +66,7 @@ def generate_next_attempt(test_results: str, test_results_insights: str) -> None
6566 Your output MUST be valid, runnable python code and NOTHING else.
6667 """
6768 response = agent .llm_response (prompt )
68- with open (os . path . join ( "." , "generated" , "test_class.py" ), "w" ) as _out :
69+ with open (sandbox . get_sandboxed_class_path ( ), "w" ) as _out :
6970 _out .write (response .content )
7071
7172
@@ -101,8 +102,8 @@ def teardown() -> None:
101102 generated_file .truncate (0 )
102103
103104
104- def chat (class_skeleton : str , test_dir : str , test_runner : GenericTestRunner , max_epochs : int = 5 ) -> None :
105- generate_first_attempt (class_skeleton )
105+ def chat (sandbox : CodeGenSandbox , test_runner : GenericTestRunner , max_epochs : int = 5 ) -> None :
106+ generate_first_attempt (sandbox )
106107 solved = False
107108 for _ in range (max_epochs ):
108109 # test_exit_code, test_results = get_test_results()
@@ -114,22 +115,44 @@ def chat(class_skeleton: str, test_dir: str, test_runner: GenericTestRunner, max
114115 break
115116 elif test_exit_code == 1 :
116117 results_insights = interpret_test_results (test_results )
117- generate_next_attempt (test_results , results_insights )
118+ generate_next_attempt (sandbox , test_results , results_insights )
118119 else :
119120 solved = True
120121 print ("There is some problem with the test suite itself." )
121122 break
122- teardown ()
123+ # teardown()
123124 if not solved :
124125 print (f"Reached the end of epoch { max_epochs } without finding a solution :(" )
125126
127+
126128@app .command ()
127129def main (
128130 debug : bool = typer .Option (False , "--debug" , "-d" , help = "debug mode" ),
129131 no_stream : bool = typer .Option (False , "--nostream" , "-ns" , help = "no streaming" ),
130132 nocache : bool = typer .Option (False , "--nocache" , "-nc" , help = "don't use cache" ),
131- class_skeleton : str = typer .Option (None , "--class-skeleton" , "-c" , help = "You must provide a class skeleton." ),
132- test_dir : str = typer .Option (os .path .join ("." , "test" ), "--test-dir" , "-t" , help = "" ),
133+ project_dir : str = typer .Argument (
134+ default = "." ,
135+ help = "The project directory that contains your tests and class skeleton. "
136+ "This directory may also have other contents. "
137+ "The directory you give here will be cloned into a 'sandbox' for the code generator to operate in."
138+ ),
139+ class_skeleton_path : str = typer .Argument (
140+ default = os .path .join ("assets" , "test_class.py" ),
141+ help = "Path to the class skeleton file, relative to project_dir."
142+ ),
143+ test_path : str = typer .Argument (
144+ default = os .path .join ("." , "test" ),
145+ help = "Path to the test file or directory, relative to project_dir."
146+ ),
147+ sandbox_path : str = typer .Option (
148+ "./build" , "--sandbox-path" , "-s" ,
149+ help = "You may optionally specify a location for the sandbox in which the code generator operates."
150+ "Default: ./build"
151+ ),
152+ max_epochs : int = typer .Option (
153+ 5 , "--max-epochs" , "-n" , help = "The maximum number of times to let the code generator try"
154+ "before giving up."
155+ )
133156) -> None :
134157 set_global (
135158 Settings (
@@ -138,15 +161,11 @@ def main(
138161 stream = not no_stream ,
139162 )
140163 )
141- assert os .path .isfile (class_skeleton ), f"The class skeleton file provided does not exist! Got { class_skeleton } "
142- assert os .path .exists (test_dir ), f"The test-dir provided does not exist! Got { test_dir } "
143-
144- tr : GenericTestRunner = SubProcessTestRunner ("" , test_dir )
145- chat (
146- class_skeleton = class_skeleton ,
147- test_dir = test_dir ,
148- test_runner = tr
149- )
164+
165+ sandbox = CodeGenSandbox (project_dir , class_skeleton_path , test_path , sandbox_path )
166+ sandbox .init_sandbox ()
167+ tr : GenericTestRunner = SubProcessTestRunner (sandbox )
168+ chat (sandbox , tr , max_epochs = max_epochs )
150169
151170
152171if __name__ == "__main__" :
0 commit comments