Skip to content

Commit b75b359

Browse files
authored
Merge pull request #3 from jvm123/feature/improved-whitespace-handling
More resilient regex in utils.code_utils.extract_diffs and removed re…
2 parents 4866d2a + 8c6ec66 commit b75b359

File tree

2 files changed

+23
-16
lines changed

2 files changed

+23
-16
lines changed

openevolve/utils/code_utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ def apply_diff(original_code: str, diff_text: str) -> str:
5353
result_lines = original_lines.copy()
5454

5555
# Extract diff blocks
56-
diff_pattern = r"<<<<<<< SEARCH\n(.*?)\n=======\n(.*?)\n>>>>>>> REPLACE"
57-
diff_blocks = re.findall(diff_pattern, diff_text, re.DOTALL)
56+
diff_blocks = extract_diffs(diff_text)
5857

5958
# Apply each diff block
6059
for search_text, replace_text in diff_blocks:
@@ -81,9 +80,9 @@ def extract_diffs(diff_text: str) -> List[Tuple[str, str]]:
8180
Returns:
8281
List of tuples (search_text, replace_text)
8382
"""
84-
diff_pattern = r"<<<<<<< SEARCH\n(.*?)\n=======\n(.*?)\n>>>>>>> REPLACE"
83+
diff_pattern = r"<<<<<<< SEARCH\n(.*?)=======\n(.*?)>>>>>>> REPLACE"
8584
diff_blocks = re.findall(diff_pattern, diff_text, re.DOTALL)
86-
return diff_blocks
85+
return [(match[0].rstrip(), match[1].rstrip()) for match in diff_blocks]
8786

8887

8988
def parse_full_rewrite(llm_response: str, language: str = "python") -> Optional[str]:

tests/test_basic.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,17 @@ def test_extract_diffs(self):
2323
"""Test extracting diffs from a response"""
2424
diff_text = """
2525
Let's improve this code:
26-
26+
2727
<<<<<<< SEARCH
2828
def hello():
2929
print("Hello")
3030
=======
3131
def hello():
3232
print("Hello, World!")
3333
>>>>>>> REPLACE
34-
34+
3535
Another change:
36-
36+
3737
<<<<<<< SEARCH
3838
x = 1
3939
=======
@@ -43,17 +43,25 @@ def hello():
4343

4444
diffs = extract_diffs(diff_text)
4545
self.assertEqual(len(diffs), 2)
46-
self.assertEqual(diffs[0][0].strip(), 'def hello():\n print("Hello")')
47-
self.assertEqual(diffs[0][1].strip(), 'def hello():\n print("Hello, World!")')
48-
self.assertEqual(diffs[1][0].strip(), "x = 1")
49-
self.assertEqual(diffs[1][1].strip(), "x = 2")
46+
self.assertEqual(
47+
diffs[0][0],
48+
""" def hello():
49+
print("Hello")""",
50+
)
51+
self.assertEqual(
52+
diffs[0][1],
53+
""" def hello():
54+
print("Hello, World!")""",
55+
)
56+
self.assertEqual(diffs[1][0], " x = 1")
57+
self.assertEqual(diffs[1][1], " x = 2")
5058

5159
def test_apply_diff(self):
5260
"""Test applying diffs to code"""
5361
original_code = """
5462
def hello():
5563
print("Hello")
56-
64+
5765
x = 1
5866
y = 2
5967
"""
@@ -66,7 +74,7 @@ def hello():
6674
def hello():
6775
print("Hello, World!")
6876
>>>>>>> REPLACE
69-
77+
7078
<<<<<<< SEARCH
7179
x = 1
7280
=======
@@ -77,7 +85,7 @@ def hello():
7785
expected_code = """
7886
def hello():
7987
print("Hello, World!")
80-
88+
8189
x = 2
8290
y = 2
8391
"""
@@ -86,8 +94,8 @@ def hello():
8694

8795
# Normalize whitespace for comparison
8896
self.assertEqual(
89-
result.replace(" ", "").replace("\n", ""),
90-
expected_code.replace(" ", "").replace("\n", ""),
97+
result,
98+
expected_code,
9199
)
92100

93101

0 commit comments

Comments
 (0)