Skip to content

Commit 25638a5

Browse files
authored
Merge pull request #57 from jvm123/feat-unit-tests
Restructured unit tests into separate files, improved TestConfigValid…
2 parents 03f9547 + a81ced6 commit 25638a5

File tree

6 files changed

+273
-221
lines changed

6 files changed

+273
-221
lines changed

tests/test_basic.py

Lines changed: 0 additions & 219 deletions
This file was deleted.

tests/test_code_utils.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""
2+
Tests for code utilities in openevolve.utils.code_utils
3+
"""
4+
5+
import unittest
6+
from openevolve.utils.code_utils import apply_diff, extract_diffs
7+
8+
9+
class TestCodeUtils(unittest.TestCase):
10+
"""Tests for code utilities"""
11+
12+
def test_extract_diffs(self):
13+
"""Test extracting diffs from a response"""
14+
diff_text = """
15+
Let's improve this code:
16+
17+
<<<<<<< SEARCH
18+
def hello():
19+
print("Hello")
20+
=======
21+
def hello():
22+
print("Hello, World!")
23+
>>>>>>> REPLACE
24+
25+
Another change:
26+
27+
<<<<<<< SEARCH
28+
x = 1
29+
=======
30+
x = 2
31+
>>>>>>> REPLACE
32+
"""
33+
34+
diffs = extract_diffs(diff_text)
35+
self.assertEqual(len(diffs), 2)
36+
self.assertEqual(
37+
diffs[0][0],
38+
""" def hello():
39+
print(\"Hello\")""",
40+
)
41+
self.assertEqual(
42+
diffs[0][1],
43+
""" def hello():
44+
print(\"Hello, World!\")""",
45+
)
46+
self.assertEqual(diffs[1][0], " x = 1")
47+
self.assertEqual(diffs[1][1], " x = 2")
48+
49+
def test_apply_diff(self):
50+
"""Test applying diffs to code"""
51+
original_code = """
52+
def hello():
53+
print("Hello")
54+
55+
x = 1
56+
y = 2
57+
"""
58+
59+
diff_text = """
60+
<<<<<<< SEARCH
61+
def hello():
62+
print("Hello")
63+
=======
64+
def hello():
65+
print("Hello, World!")
66+
>>>>>>> REPLACE
67+
68+
<<<<<<< SEARCH
69+
x = 1
70+
=======
71+
x = 2
72+
>>>>>>> REPLACE
73+
"""
74+
75+
expected_code = """
76+
def hello():
77+
print("Hello, World!")
78+
79+
x = 2
80+
y = 2
81+
"""
82+
83+
result = apply_diff(original_code, diff_text)
84+
85+
# Normalize whitespace for comparison
86+
self.assertEqual(
87+
result,
88+
expected_code,
89+
)
90+
91+
92+
if __name__ == "__main__":
93+
unittest.main()

tests/test_database.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""
2+
Tests for ProgramDatabase in openevolve.database
3+
"""
4+
5+
import unittest
6+
from openevolve.config import Config
7+
from openevolve.database import Program, ProgramDatabase
8+
9+
10+
class TestProgramDatabase(unittest.TestCase):
11+
"""Tests for program database"""
12+
13+
def setUp(self):
14+
"""Set up test database"""
15+
config = Config()
16+
config.database.in_memory = True
17+
self.db = ProgramDatabase(config.database)
18+
19+
def test_add_and_get(self):
20+
"""Test adding and retrieving a program"""
21+
program = Program(
22+
id="test1",
23+
code="def test(): pass",
24+
language="python",
25+
metrics={"score": 0.5},
26+
)
27+
28+
self.db.add(program)
29+
30+
retrieved = self.db.get("test1")
31+
self.assertIsNotNone(retrieved)
32+
self.assertEqual(retrieved.id, "test1")
33+
self.assertEqual(retrieved.code, "def test(): pass")
34+
self.assertEqual(retrieved.metrics["score"], 0.5)
35+
36+
def test_get_best_program(self):
37+
"""Test getting the best program"""
38+
program1 = Program(
39+
id="test1",
40+
code="def test1(): pass",
41+
language="python",
42+
metrics={"score": 0.5},
43+
)
44+
45+
program2 = Program(
46+
id="test2",
47+
code="def test2(): pass",
48+
language="python",
49+
metrics={"score": 0.7},
50+
)
51+
52+
self.db.add(program1)
53+
self.db.add(program2)
54+
55+
best = self.db.get_best_program()
56+
self.assertIsNotNone(best)
57+
self.assertEqual(best.id, "test2")
58+
59+
def test_sample(self):
60+
"""Test sampling from the database"""
61+
program1 = Program(
62+
id="test1",
63+
code="def test1(): pass",
64+
language="python",
65+
metrics={"score": 0.5},
66+
)
67+
68+
program2 = Program(
69+
id="test2",
70+
code="def test2(): pass",
71+
language="python",
72+
metrics={"score": 0.7},
73+
)
74+
75+
self.db.add(program1)
76+
self.db.add(program2)
77+
78+
parent, inspirations = self.db.sample()
79+
80+
self.assertIsNotNone(parent)
81+
self.assertIn(parent.id, ["test1", "test2"])
82+
83+
84+
if __name__ == "__main__":
85+
unittest.main()

0 commit comments

Comments
 (0)