Skip to content

Commit 557ab89

Browse files
authored
Merge branch 'main' into formatter-output-fix
2 parents 9ab37a7 + e1d8fe0 commit 557ab89

File tree

5 files changed

+160
-11
lines changed

5 files changed

+160
-11
lines changed

codeflash/code_utils/config_consts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
N_TESTS_TO_GENERATE = 2
99
TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget
1010
COVERAGE_THRESHOLD = 60.0
11+
MIN_TESTCASE_PASSED_THRESHOLD = 6

codeflash/discovery/functions_to_optimize.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,6 @@ def visit_FunctionDef(self, node: FunctionDef) -> None:
9494
self.functions.append(
9595
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
9696
)
97-
# Continue visiting the body of the function to find nested functions
98-
self.generic_visit(node)
9997

10098
def generic_visit(self, node: ast.AST) -> None:
10199
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):

codeflash/result/critic.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55
from codeflash.cli_cmds.console import logger
66
from codeflash.code_utils import env_utils
7-
from codeflash.code_utils.config_consts import COVERAGE_THRESHOLD, MIN_IMPROVEMENT_THRESHOLD
7+
from codeflash.code_utils.config_consts import (
8+
COVERAGE_THRESHOLD,
9+
MIN_IMPROVEMENT_THRESHOLD,
10+
MIN_TESTCASE_PASSED_THRESHOLD,
11+
)
812
from codeflash.models.models import TestType
913

1014
if TYPE_CHECKING:
@@ -50,7 +54,7 @@ def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult) -> bool
5054
for test_type in report:
5155
pass_count += report[test_type]["passed"]
5256

53-
if pass_count >= 4:
57+
if pass_count >= MIN_TESTCASE_PASSED_THRESHOLD:
5458
return True
5559
# If only one test passed, check if it's a REPLAY_TEST
5660
return bool(pass_count == 1 and report[TestType.REPLAY_TEST]["passed"] == 1)

tests/test_critic.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,7 @@ def test_generated_test_critic() -> None:
195195
timed_out=False,
196196
loop_index=1,
197197
)
198-
199-
test_results = [test_1, test_2, test_3, test_7]
198+
test_results = [test_1, test_2, test_3, test_4, test_5, test_6, test_7, test_1]
200199

201200
candidate_result = OptimizedCandidateResult(
202201
max_loop_count=5,
@@ -209,7 +208,7 @@ def test_generated_test_critic() -> None:
209208

210209
assert quantity_of_tests_critic(candidate_result)
211210

212-
test_results = [test_1, test_2, test_3, test_6, test_7]
211+
test_results = [test_1, test_2, test_3, test_6, test_7, test_1, test_4, test_1]
213212

214213
candidate_result = OptimizedCandidateResult(
215214
max_loop_count=5,
@@ -222,7 +221,7 @@ def test_generated_test_critic() -> None:
222221

223222
assert quantity_of_tests_critic(candidate_result)
224223

225-
test_results = [test_1, test_3, test_4, test_2, test_7]
224+
test_results = [test_1, test_3, test_4, test_2, test_7, test_1, test_6, test_1]
226225

227226
candidate_result = OptimizedCandidateResult(
228227
max_loop_count=5,
@@ -248,7 +247,7 @@ def test_generated_test_critic() -> None:
248247

249248
assert not quantity_of_tests_critic(candidate_result)
250249

251-
test_results = [test_1, test_2, test_3, test_4, test_5]
250+
test_results = [test_1, test_2, test_3, test_4, test_5, test_1, test_1, test_1]
252251

253252
candidate_result = OptimizedCandidateResult(
254253
max_loop_count=5,
@@ -287,7 +286,7 @@ def test_generated_test_critic() -> None:
287286

288287
assert quantity_of_tests_critic(candidate_result)
289288

290-
test_results = [test_1, test_2, test_3, test_4, test_5]
289+
test_results = [test_1, test_2, test_3, test_4, test_5, test_1, test_1, test_1]
291290

292291
candidate_result = OptimizedCandidateResult(
293292
max_loop_count=5,
@@ -328,7 +327,7 @@ def test_generated_test_critic() -> None:
328327

329328
assert not quantity_of_tests_critic(candidate_result)
330329

331-
test_results = [test_1, test_2, test_3, test_5]
330+
test_results = [test_1, test_2, test_3, test_5, test_1, test_1, test_1, test_1]
332331

333332
candidate_result = OptimizedCandidateResult(
334333
max_loop_count=5,

tests/test_function_discovery.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ def test_function_eligible_for_optimization() -> None:
3535
assert len(functions_found[Path(f.name)]) == 0
3636

3737

38+
# we want to trigger an error in the function discovery
39+
function = """def test_invalid_code():"""
40+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
41+
f.write(function)
42+
f.flush()
43+
functions_found = find_all_functions_in_file(Path(f.name))
44+
assert functions_found == {}
45+
46+
47+
48+
3849
def test_find_top_level_function_or_method():
3950
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
4051
f.write(
@@ -82,6 +93,15 @@ def non_classmethod_function(cls, name):
8293
).is_top_level
8394
# needed because this will be traced with a class_name being passed
8495

96+
# we want to write invalid code to ensure that the function discovery does not crash
97+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
98+
f.write(
99+
"""def functionA():
100+
"""
101+
)
102+
f.flush()
103+
path_obj_name = Path(f.name)
104+
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionA")
85105

86106
def test_class_method_discovery():
87107
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
@@ -152,6 +172,133 @@ def functionA():
152172
assert functions[file][0].function_name == "functionA"
153173

154174

175+
def test_nested_function():
176+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
177+
f.write(
178+
"""
179+
import copy
180+
181+
def propagate_attributes(
182+
nodes: dict[str, dict], edges: list[dict], source_node_id: str, attribute: str
183+
) -> dict[str, dict]:
184+
modified_nodes = copy.deepcopy(nodes)
185+
186+
# Build an adjacency list for faster traversal
187+
adjacency = {}
188+
for edge in edges:
189+
src = edge["source"]
190+
tgt = edge["target"]
191+
if src not in adjacency:
192+
adjacency[src] = []
193+
adjacency[src].append(tgt)
194+
195+
# Track visited nodes to avoid cycles
196+
visited = set()
197+
198+
def traverse(node_id):
199+
if node_id in visited:
200+
return
201+
visited.add(node_id)
202+
203+
# Propagate attribute from source node
204+
if (
205+
node_id != source_node_id
206+
and source_node_id in modified_nodes
207+
and attribute in modified_nodes[source_node_id]
208+
):
209+
if node_id in modified_nodes:
210+
modified_nodes[node_id][attribute] = modified_nodes[source_node_id][
211+
attribute
212+
]
213+
214+
# Continue propagation to neighbors
215+
for neighbor in adjacency.get(node_id, []):
216+
traverse(neighbor)
217+
218+
traverse(source_node_id)
219+
return modified_nodes
220+
"""
221+
)
222+
f.flush()
223+
test_config = TestConfig(
224+
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
225+
)
226+
path_obj_name = Path(f.name)
227+
functions, functions_count = get_functions_to_optimize(
228+
optimize_all=None,
229+
replay_test=None,
230+
file=path_obj_name,
231+
test_cfg=test_config,
232+
only_get_this_function=None,
233+
ignore_paths=[Path("/bruh/")],
234+
project_root=path_obj_name.parent,
235+
module_root=path_obj_name.parent,
236+
)
237+
238+
assert len(functions) == 1
239+
assert functions_count == 1
240+
241+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
242+
f.write(
243+
"""
244+
def outer_function():
245+
def inner_function():
246+
pass
247+
248+
return inner_function
249+
"""
250+
)
251+
f.flush()
252+
test_config = TestConfig(
253+
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
254+
)
255+
path_obj_name = Path(f.name)
256+
functions, functions_count = get_functions_to_optimize(
257+
optimize_all=None,
258+
replay_test=None,
259+
file=path_obj_name,
260+
test_cfg=test_config,
261+
only_get_this_function=None,
262+
ignore_paths=[Path("/bruh/")],
263+
project_root=path_obj_name.parent,
264+
module_root=path_obj_name.parent,
265+
)
266+
267+
assert len(functions) == 1
268+
assert functions_count == 1
269+
270+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
271+
f.write(
272+
"""
273+
def outer_function():
274+
def inner_function():
275+
pass
276+
277+
def another_inner_function():
278+
pass
279+
return inner_function, another_inner_function
280+
"""
281+
)
282+
f.flush()
283+
test_config = TestConfig(
284+
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
285+
)
286+
path_obj_name = Path(f.name)
287+
functions, functions_count = get_functions_to_optimize(
288+
optimize_all=None,
289+
replay_test=None,
290+
file=path_obj_name,
291+
test_cfg=test_config,
292+
only_get_this_function=None,
293+
ignore_paths=[Path("/bruh/")],
294+
project_root=path_obj_name.parent,
295+
module_root=path_obj_name.parent,
296+
)
297+
298+
assert len(functions) == 1
299+
assert functions_count == 1
300+
301+
155302
def test_filter_files_optimized():
156303
tests_root = Path("tests").resolve()
157304
module_root = Path().resolve()

0 commit comments

Comments
 (0)