Skip to content

Commit 3de7373

Browse files
committed
moved the num_discovered_tests calculation inside the discover_unit_tests
1 parent 889ca52 commit 3de7373

File tree

5 files changed

+30
-28
lines changed

5 files changed

+30
-28
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def discover_unit_tests(
305305
cfg: TestConfig,
306306
discover_only_these_tests: list[Path] | None = None,
307307
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None = None,
308-
) -> dict[str, list[FunctionCalledInTest]]:
308+
) -> tuple[dict[str, list[FunctionCalledInTest]], int]:
309309
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
310310
strategy = framework_strategies.get(cfg.test_framework, None)
311311
if not strategy:
@@ -317,7 +317,9 @@ def discover_unit_tests(
317317
if file_to_funcs_to_optimize:
318318
functions_to_optimize = [func for funcs_list in file_to_funcs_to_optimize.values() for func in funcs_list]
319319

320-
return strategy(cfg, discover_only_these_tests, functions_to_optimize)
320+
function_to_tests = strategy(cfg, discover_only_these_tests, functions_to_optimize)
321+
num_discovered_tests = sum([len(value) for value in function_to_tests.values()])
322+
return function_to_tests, num_discovered_tests
321323

322324

323325
def discover_tests_pytest(

codeflash/discovery/functions_to_optimize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
268268
def get_all_replay_test_functions(
269269
replay_test: Path, test_cfg: TestConfig, project_root_path: Path
270270
) -> dict[Path, list[FunctionToOptimize]]:
271-
function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
271+
function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
272272
# Get the absolute file paths for each function, excluding class name if present
273273
filtered_valid_functions = defaultdict(list)
274274
file_to_functions_map = defaultdict(list)

codeflash/optimization/optimizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,9 @@ def run(self) -> None:
162162

163163
console.rule()
164164
start_time = time.time()
165-
function_to_tests = discover_unit_tests(self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize)
166-
num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
165+
function_to_tests, num_discovered_tests = discover_unit_tests(
166+
self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize
167+
)
167168
console.rule()
168169
logger.info(
169170
f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"

codeflash/verification/concolic_testing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ def generate_concolic_tests(
7878
test_framework=args.test_framework,
7979
pytest_cmd=args.pytest_cmd,
8080
)
81-
function_to_concolic_tests = discover_unit_tests(concolic_test_cfg)
82-
num_discovered_concolic_tests: int = sum([len(value) for value in function_to_concolic_tests.values()])
81+
function_to_concolic_tests, num_discovered_concolic_tests = discover_unit_tests(concolic_test_cfg)
8382
logger.info(
8483
f"Created {num_discovered_concolic_tests} "
8584
f"concolic unit test case{'s' if num_discovered_concolic_tests != 1 else ''} "

tests/test_unit_test_discovery.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_unit_test_discovery_pytest():
2020
test_framework="pytest",
2121
tests_project_rootdir=tests_path.parent,
2222
)
23-
tests = discover_unit_tests(test_config)
23+
tests, _ = discover_unit_tests(test_config)
2424
assert len(tests) > 0
2525

2626

@@ -33,7 +33,7 @@ def test_benchmark_test_discovery_pytest():
3333
test_framework="pytest",
3434
tests_project_rootdir=tests_path.parent,
3535
)
36-
tests = discover_unit_tests(test_config)
36+
tests, _ = discover_unit_tests(test_config)
3737
assert len(tests) == 1 # Should not discover benchmark tests
3838

3939

@@ -47,7 +47,7 @@ def test_unit_test_discovery_unittest():
4747
tests_project_rootdir=project_path.parent,
4848
)
4949
os.chdir(project_path)
50-
tests = discover_unit_tests(test_config)
50+
tests, _ = discover_unit_tests(test_config)
5151
# assert len(tests) > 0
5252
# Unittest discovery within a pytest environment does not work
5353

@@ -85,7 +85,7 @@ def sorter(arr):
8585
)
8686

8787
# Discover tests
88-
tests = discover_unit_tests(test_config)
88+
tests, _ = discover_unit_tests(test_config)
8989
assert len(tests) == 1
9090
assert 'bubble_sort.sorter' in tests
9191
assert len(tests['bubble_sort.sorter']) == 2
@@ -124,7 +124,7 @@ def test_discover_tests_pytest_with_temp_dir_root():
124124
)
125125

126126
# Discover tests
127-
discovered_tests = discover_unit_tests(test_config)
127+
discovered_tests, _ = discover_unit_tests(test_config)
128128

129129
# Check if the dummy test file is discovered
130130
assert len(discovered_tests) == 1
@@ -197,7 +197,7 @@ def test_discover_tests_pytest_with_multi_level_dirs():
197197
)
198198

199199
# Discover tests
200-
discovered_tests = discover_unit_tests(test_config)
200+
discovered_tests, _ = discover_unit_tests(test_config)
201201

202202
# Check if the test files at all levels are discovered
203203
assert len(discovered_tests) == 3
@@ -287,7 +287,7 @@ def test_discover_tests_pytest_dirs():
287287
)
288288

289289
# Discover tests
290-
discovered_tests = discover_unit_tests(test_config)
290+
discovered_tests, _ = discover_unit_tests(test_config)
291291

292292
# Check if the test files at all levels are discovered
293293
assert len(discovered_tests) == 4
@@ -333,7 +333,7 @@ def test_discover_tests_pytest_with_class():
333333
)
334334

335335
# Discover tests
336-
discovered_tests = discover_unit_tests(test_config)
336+
discovered_tests, _ = discover_unit_tests(test_config)
337337

338338
# Check if the test class and method are discovered
339339
assert len(discovered_tests) == 1
@@ -371,7 +371,7 @@ def test_discover_tests_pytest_with_double_nested_directories():
371371
)
372372

373373
# Discover tests
374-
discovered_tests = discover_unit_tests(test_config)
374+
discovered_tests, _ = discover_unit_tests(test_config)
375375

376376
# Check if the test class and method are discovered
377377
assert len(discovered_tests) == 1
@@ -421,7 +421,7 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir():
421421
)
422422

423423
# Discover tests
424-
discovered_tests = discover_unit_tests(test_config)
424+
discovered_tests, _ = discover_unit_tests(test_config)
425425

426426
# Check if the test file is discovered and associated with the code file
427427
assert len(discovered_tests) == 1
@@ -460,7 +460,7 @@ def test_discover_tests_pytest_with_nested_class():
460460
)
461461

462462
# Discover tests
463-
discovered_tests = discover_unit_tests(test_config)
463+
discovered_tests, _ = discover_unit_tests(test_config)
464464

465465
# Check if the test for the nested class method is discovered
466466
assert len(discovered_tests) == 1
@@ -500,7 +500,7 @@ def test_discover_tests_pytest_separate_moduledir():
500500
)
501501

502502
# Discover tests
503-
discovered_tests = discover_unit_tests(test_config)
503+
discovered_tests, _ = discover_unit_tests(test_config)
504504

505505
# Check if the test for the nested class method is discovered
506506
assert len(discovered_tests) == 1
@@ -542,7 +542,7 @@ def test_add(self):
542542
)
543543

544544
# Discover tests
545-
discovered_tests = discover_unit_tests(test_config)
545+
discovered_tests, _ = discover_unit_tests(test_config)
546546

547547
# Verify the unittest was discovered
548548
assert len(discovered_tests) == 1
@@ -609,7 +609,7 @@ def test_add(self):
609609
)
610610

611611
# Discover tests
612-
discovered_tests = discover_unit_tests(test_config)
612+
discovered_tests, _ = discover_unit_tests(test_config)
613613

614614
# Verify the unittest was discovered
615615
assert len(discovered_tests) == 2
@@ -654,7 +654,7 @@ def _test_add(self): # Private test method should not be discovered
654654
)
655655

656656
# Discover tests
657-
discovered_tests = discover_unit_tests(test_config)
657+
discovered_tests, _ = discover_unit_tests(test_config)
658658

659659
# Verify no tests were discovered
660660
assert len(discovered_tests) == 0
@@ -706,7 +706,7 @@ def test_add_with_parameters(self):
706706
)
707707

708708
# Discover tests
709-
discovered_tests = discover_unit_tests(test_config)
709+
discovered_tests, _ = discover_unit_tests(test_config)
710710

711711
# Verify the unittest was discovered
712712
assert len(discovered_tests) == 1
@@ -788,7 +788,7 @@ def test_add_mixed(self, name, a, b, expected):
788788
)
789789

790790
# Discover tests
791-
discovered_tests = discover_unit_tests(test_config)
791+
discovered_tests, _ = discover_unit_tests(test_config)
792792

793793
# Verify the basic structure
794794
assert len(discovered_tests) == 2 # Should have tests for both add and multiply
@@ -1069,7 +1069,7 @@ def test_other():
10691069
)
10701070

10711071
# Test without filtering
1072-
all_tests = discover_unit_tests(test_config)
1072+
all_tests, _ = discover_unit_tests(test_config)
10731073
assert len(all_tests) == 2 # Should find both functions
10741074

10751075
# Test with filtering - create mock FunctionToOptimize objects
@@ -1079,7 +1079,7 @@ def test_other():
10791079
mock_function.function_name = "target_function"
10801080
mock_function.parents = [] # No parent classes
10811081

1082-
filtered_tests = discover_unit_tests(test_config, file_to_funcs_to_optimize={code_file: [mock_function]})
1082+
filtered_tests, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={code_file: [mock_function]})
10831083
assert len(filtered_tests) >= 1
10841084
assert "mycode.target_function" in filtered_tests
10851085

@@ -1212,7 +1212,7 @@ def test_unrelated():
12121212
)
12131213

12141214
# Test without filtering
1215-
all_tests = discover_unit_tests(test_config)
1215+
all_tests, _ = discover_unit_tests(test_config)
12161216
assert len(all_tests) == 2 # Should find both functions
12171217

12181218
# Test with filtering - create mock FunctionToOptimize objects
@@ -1222,7 +1222,7 @@ def test_unrelated():
12221222
mock_function.function_name = "target_function"
12231223
mock_function.parents = [] # No parent classes
12241224

1225-
filtered_tests = discover_unit_tests(test_config, file_to_funcs_to_optimize={target_file: [mock_function]})
1225+
filtered_tests, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={target_file: [mock_function]})
12261226
# Should filter out the unrelated test since it imports from a different module
12271227
assert len(filtered_tests) == 1
12281228
assert "target_module.target_function" in filtered_tests

0 commit comments

Comments
 (0)