Skip to content

Commit 10730a5

Browse files
committed
Reduce to two agents, one for both cases of codegen
1 parent 3f9d868 commit 10730a5

File tree

2 files changed

+49
-46
lines changed

2 files changed

+49
-46
lines changed

lib/agents.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,15 @@ class GenericAgent(metaclass=ABCMeta):
1010
def respond(self, prompt: str, *args, **kwargs) -> str: pass
1111

1212

13-
class FirstAttemptAgent(GenericAgent):
13+
class CodeGenAgent(GenericAgent):
14+
15+
sandbox: CodeGenSandbox
16+
config: ChatAgentConfig
17+
agent: ChatAgent
18+
class_skeleton: str
19+
previous_code_attempt: str
20+
latest_test_result: str
21+
latest_test_result_interpretation: str
1422

1523
def __init__(self, sandbox: CodeGenSandbox, config: ChatAgentConfig):
1624
self.sandbox = sandbox
@@ -21,8 +29,17 @@ def __init__(self, sandbox: CodeGenSandbox, config: ChatAgentConfig):
2129
self.agent = ChatAgent(config)
2230

2331
def respond(self, prompt: str, *args, **kwargs) -> str:
24-
response = self.agent.llm_response()
32+
response = self.agent.llm_response(prompt)
2533
with open(self.sandbox.get_sandboxed_class_path(), "w+") as _out:
2634
_out.write(response.content)
2735

2836
return response.content
37+
38+
def set_previous_code_attempt(self, attempt: str) -> None:
39+
self.previous_code_attempt = attempt
40+
41+
def set_latest_test_result(self, tr: str) -> None:
42+
self.latest_test_result = tr
43+
44+
def set_latest_test_result_interpretation(self, interpretation: str) -> None:
45+
self.latest_test_result_interpretation = interpretation

main.py

Lines changed: 30 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from langroid import ChatAgentConfig
44

55
from lib.utils import CodeGenSandbox
6-
from lib.agents import FirstAttemptAgent
6+
from lib.agents import CodeGenAgent, GenericAgent
77
import typer
88

99
import langroid as lr
@@ -16,43 +16,6 @@
1616
setup_colored_logging()
1717

1818

19-
def generate_next_attempt(sandbox: CodeGenSandbox, test_results: str, test_results_insights: str) -> str:
20-
cfg = lr.ChatAgentConfig(
21-
llm=lr.language_models.OpenAIGPTConfig(
22-
chat_model="ollama/llama3.1:latest",
23-
),
24-
vecdb=None
25-
)
26-
agent = lr.ChatAgent(cfg)
27-
with open(sandbox.get_sandboxed_class_path(), "r") as f:
28-
code_snippet = f.read()
29-
30-
prompt = f"""
31-
You are an expert at writing Python code.
32-
Consider the following code, and test results.
33-
Here is the code:
34-
{code_snippet}
35-
Here are the test results:
36-
{test_results}
37-
In addition, you may consider these insights about the test results when coming up with your solution:
38-
{test_results_insights}
39-
Update the code so that the tests will pass.
40-
Your output MUST contain all the same classes and methods as the input code.
41-
Do NOT add any other methods or commentary.
42-
Your response should be ONLY the python code.
43-
Do not say 'here is the python code'
44-
Do not surround your response with quotes or backticks.
45-
DO NOT EVER USE ``` in your output.
46-
Your response should NEVER start or end with ```
47-
Your output MUST be valid, runnable python code and NOTHING else.
48-
"""
49-
response = agent.llm_response(prompt)
50-
with open(sandbox.get_sandboxed_class_path(), "w") as _out:
51-
_out.write(response.content)
52-
53-
return response.content
54-
55-
5619
def interpret_test_results(results: str, code: str) -> str:
5720
cfg = lr.ChatAgentConfig(
5821
llm=lr.language_models.OpenAIGPTConfig(
@@ -88,11 +51,11 @@ def teardown() -> None:
8851

8952
def chat(
9053
sandbox: CodeGenSandbox,
91-
first_attempt_agent: FirstAttemptAgent,
54+
code_gen_agent: CodeGenAgent,
9255
test_runner: GenericTestRunner,
9356
max_epochs: int = 5
9457
) -> None:
95-
code_attempt = first_attempt_agent.respond(
58+
code_attempt = code_gen_agent.respond(
9659
prompt=f"""
9760
You are an expert at writing Python code.
9861
Fill in the following class skeleton.
@@ -102,7 +65,7 @@ def chat(
10265
Do not surround your response with quotes or backticks.
10366
DO NOT EVER USE ``` in your output.
10467
Your output MUST be valid, runnable python code and NOTHING else.
105-
{first_attempt_agent.class_skeleton}
68+
{code_gen_agent.class_skeleton}
10669
"""
10770
)
10871
solved = False
@@ -116,7 +79,30 @@ def chat(
11679
break
11780
else:
11881
results_insights = interpret_test_results(test_results, code_attempt)
119-
code_attempt = generate_next_attempt(sandbox, test_results, results_insights)
82+
code_gen_agent.set_previous_code_attempt(code_attempt)
83+
code_gen_agent.set_latest_test_result(test_results)
84+
code_gen_agent.set_latest_test_result_interpretation(results_insights)
85+
code_attempt = code_gen_agent.respond(
86+
prompt=f"""
87+
You are an expert at writing Python code.
88+
Consider the following code, and test results.
89+
Here is the code:
90+
{code_gen_agent.previous_code_attempt}
91+
Here are the test results:
92+
{code_gen_agent.latest_test_result}
93+
In addition, you may consider these insights about the test results when coming up with your solution:
94+
{code_gen_agent.latest_test_result_interpretation}
95+
Update the code so that the tests will pass.
96+
Your output MUST contain all the same classes and methods as the input code.
97+
Do NOT add any other methods or commentary.
98+
Your response should be ONLY the python code.
99+
Do not say 'here is the python code'
100+
Do not surround your response with quotes or backticks.
101+
DO NOT EVER USE ``` in your output.
102+
Your response should NEVER start or end with ```
103+
Your output MUST be valid, runnable python code and NOTHING else.
104+
"""
105+
)
120106
# else:
121107
# solved = True
122108
# print("There is some problem with the test suite itself.")
@@ -172,9 +158,9 @@ def main(
172158

173159
sandbox = CodeGenSandbox(project_dir, class_skeleton_path, test_path, sandbox_path)
174160
sandbox.init_sandbox()
175-
fa: FirstAttemptAgent = FirstAttemptAgent(sandbox, llama3)
161+
code_generator: GenericAgent = CodeGenAgent(sandbox, llama3)
176162
tr: GenericTestRunner = SubProcessTestRunner(sandbox)
177-
chat(sandbox, fa, tr, max_epochs=max_epochs)
163+
chat(sandbox, code_generator, tr, max_epochs=max_epochs)
178164

179165

180166
if __name__ == "__main__":

0 commit comments

Comments
 (0)