Skip to content

Commit 620d0cd

Browse files
committed
go
1 parent b4ab00b commit 620d0cd

File tree

5 files changed

+29
-34
lines changed

5 files changed

+29
-34
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def close(self) -> None:
139139

140140
def discover_unit_tests(
141141
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
142-
) -> dict[str, list[FunctionCalledInTest]]:
142+
) -> tuple[dict[str, list[FunctionCalledInTest]], int]:
143143
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
144144
strategy = framework_strategies.get(cfg.test_framework, None)
145145
if not strategy:
@@ -151,7 +151,7 @@ def discover_unit_tests(
151151

152152
def discover_tests_pytest(
153153
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
154-
) -> dict[Path, list[FunctionCalledInTest]]:
154+
) -> tuple[dict[str, list[FunctionCalledInTest]], int]:
155155
tests_root = cfg.tests_root
156156
project_root = cfg.project_root_path
157157

@@ -225,7 +225,7 @@ def discover_tests_pytest(
225225

226226
def discover_tests_unittest(
227227
cfg: TestConfig, discover_only_these_tests: list[str] | None = None
228-
) -> dict[Path, list[FunctionCalledInTest]]:
228+
) -> tuple[dict[str, list[FunctionCalledInTest]], int]:
229229
tests_root: Path = cfg.tests_root
230230
loader: unittest.TestLoader = unittest.TestLoader()
231231
tests: unittest.TestSuite = loader.discover(str(tests_root))
@@ -290,15 +290,14 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
290290

291291
def process_test_files(
292292
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig
293-
) -> dict[str, list[FunctionCalledInTest]]:
293+
) -> tuple[dict[str, list[FunctionCalledInTest]], int]:
294294
import jedi
295295

296296
project_root_path = cfg.project_root_path
297297
test_framework = cfg.test_framework
298298

299299
function_to_test_map = defaultdict(set)
300300
jedi_project = jedi.Project(path=project_root_path)
301-
goto_cache = {}
302301
tests_cache = TestsCache()
303302

304303
with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as (
@@ -412,13 +411,8 @@ def process_test_files(
412411
if scope not in test_functions_by_name:
413412
continue
414413

415-
cache_key = (name.full_name, name.module_name)
416414
try:
417-
if cache_key in goto_cache:
418-
definition = goto_cache[cache_key]
419-
else:
420-
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
421-
goto_cache[cache_key] = definition
415+
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
422416
except Exception as e:
423417
logger.debug(str(e))
424418
continue
@@ -476,4 +470,6 @@ def process_test_files(
476470
progress.advance(task_id)
477471

478472
tests_cache.close()
479-
return {function: list(tests) for function, tests in function_to_test_map.items()}
473+
function_to_tests_dict = {function: list(tests) for function, tests in function_to_test_map.items()}
474+
num_discovered_tests = sum(len(tests) for tests in function_to_tests_dict.values())
475+
return function_to_tests_dict, num_discovered_tests

codeflash/discovery/functions_to_optimize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
268268
def get_all_replay_test_functions(
269269
replay_test: Path, test_cfg: TestConfig, project_root_path: Path
270270
) -> dict[Path, list[FunctionToOptimize]]:
271-
function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
271+
function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
272272
# Get the absolute file paths for each function, excluding class name if present
273273
filtered_valid_functions = defaultdict(list)
274274
file_to_functions_map = defaultdict(list)

codeflash/optimization/optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ def run(self) -> None:
162162

163163
console.rule()
164164
start_time = time.time()
165-
function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
166-
num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
165+
function_to_tests: dict[str, list[FunctionCalledInTest]]
166+
function_to_tests, num_discovered_tests = discover_unit_tests(self.test_cfg)
167167
console.rule()
168168
logger.info(
169169
f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"

codeflash/verification/concolic_testing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ def generate_concolic_tests(
7878
test_framework=args.test_framework,
7979
pytest_cmd=args.pytest_cmd,
8080
)
81-
function_to_concolic_tests = discover_unit_tests(concolic_test_cfg)
82-
num_discovered_concolic_tests: int = sum([len(value) for value in function_to_concolic_tests.values()])
81+
function_to_concolic_tests, num_discovered_concolic_tests = discover_unit_tests(concolic_test_cfg)
8382
logger.info(
8483
f"Created {num_discovered_concolic_tests} "
8584
f"concolic unit test case{'s' if num_discovered_concolic_tests != 1 else ''} "

tests/test_unit_test_discovery.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_unit_test_discovery_pytest():
1515
test_framework="pytest",
1616
tests_project_rootdir=tests_path.parent,
1717
)
18-
tests = discover_unit_tests(test_config)
18+
tests, _ = discover_unit_tests(test_config)
1919
assert len(tests) > 0
2020

2121

@@ -28,7 +28,7 @@ def test_benchmark_test_discovery_pytest():
2828
test_framework="pytest",
2929
tests_project_rootdir=tests_path.parent,
3030
)
31-
tests = discover_unit_tests(test_config)
31+
tests, _ = discover_unit_tests(test_config)
3232
assert len(tests) == 1 # Should not discover benchmark tests
3333

3434

@@ -42,7 +42,7 @@ def test_unit_test_discovery_unittest():
4242
tests_project_rootdir=project_path.parent,
4343
)
4444
os.chdir(project_path)
45-
tests = discover_unit_tests(test_config)
45+
tests, _ = discover_unit_tests(test_config)
4646
# assert len(tests) > 0
4747
# Unittest discovery within a pytest environment does not work
4848

@@ -80,7 +80,7 @@ def sorter(arr):
8080
)
8181

8282
# Discover tests
83-
tests = discover_unit_tests(test_config)
83+
tests, _ = discover_unit_tests(test_config)
8484
assert len(tests) == 1
8585
assert 'bubble_sort.sorter' in tests
8686
assert len(tests['bubble_sort.sorter']) == 2
@@ -119,7 +119,7 @@ def test_discover_tests_pytest_with_temp_dir_root():
119119
)
120120

121121
# Discover tests
122-
discovered_tests = discover_unit_tests(test_config)
122+
discovered_tests, _ = discover_unit_tests(test_config)
123123

124124
# Check if the dummy test file is discovered
125125
assert len(discovered_tests) == 1
@@ -192,7 +192,7 @@ def test_discover_tests_pytest_with_multi_level_dirs():
192192
)
193193

194194
# Discover tests
195-
discovered_tests = discover_unit_tests(test_config)
195+
discovered_tests, _ = discover_unit_tests(test_config)
196196

197197
# Check if the test files at all levels are discovered
198198
assert len(discovered_tests) == 3
@@ -282,7 +282,7 @@ def test_discover_tests_pytest_dirs():
282282
)
283283

284284
# Discover tests
285-
discovered_tests = discover_unit_tests(test_config)
285+
discovered_tests, _ = discover_unit_tests(test_config)
286286

287287
# Check if the test files at all levels are discovered
288288
assert len(discovered_tests) == 4
@@ -328,7 +328,7 @@ def test_discover_tests_pytest_with_class():
328328
)
329329

330330
# Discover tests
331-
discovered_tests = discover_unit_tests(test_config)
331+
discovered_tests, _ = discover_unit_tests(test_config)
332332

333333
# Check if the test class and method are discovered
334334
assert len(discovered_tests) == 1
@@ -366,7 +366,7 @@ def test_discover_tests_pytest_with_double_nested_directories():
366366
)
367367

368368
# Discover tests
369-
discovered_tests = discover_unit_tests(test_config)
369+
discovered_tests, _ = discover_unit_tests(test_config)
370370

371371
# Check if the test class and method are discovered
372372
assert len(discovered_tests) == 1
@@ -416,7 +416,7 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir():
416416
)
417417

418418
# Discover tests
419-
discovered_tests = discover_unit_tests(test_config)
419+
discovered_tests, _ = discover_unit_tests(test_config)
420420

421421
# Check if the test file is discovered and associated with the code file
422422
assert len(discovered_tests) == 1
@@ -455,7 +455,7 @@ def test_discover_tests_pytest_with_nested_class():
455455
)
456456

457457
# Discover tests
458-
discovered_tests = discover_unit_tests(test_config)
458+
discovered_tests, _ = discover_unit_tests(test_config)
459459

460460
# Check if the test for the nested class method is discovered
461461
assert len(discovered_tests) == 1
@@ -495,7 +495,7 @@ def test_discover_tests_pytest_separate_moduledir():
495495
)
496496

497497
# Discover tests
498-
discovered_tests = discover_unit_tests(test_config)
498+
discovered_tests, _ = discover_unit_tests(test_config)
499499

500500
# Check if the test for the nested class method is discovered
501501
assert len(discovered_tests) == 1
@@ -537,7 +537,7 @@ def test_add(self):
537537
)
538538

539539
# Discover tests
540-
discovered_tests = discover_unit_tests(test_config)
540+
discovered_tests, _ = discover_unit_tests(test_config)
541541

542542
# Verify the unittest was discovered
543543
assert len(discovered_tests) == 1
@@ -604,7 +604,7 @@ def test_add(self):
604604
)
605605

606606
# Discover tests
607-
discovered_tests = discover_unit_tests(test_config)
607+
discovered_tests, _ = discover_unit_tests(test_config)
608608

609609
# Verify the unittest was discovered
610610
assert len(discovered_tests) == 2
@@ -649,7 +649,7 @@ def _test_add(self): # Private test method should not be discovered
649649
)
650650

651651
# Discover tests
652-
discovered_tests = discover_unit_tests(test_config)
652+
discovered_tests, _ = discover_unit_tests(test_config)
653653

654654
# Verify no tests were discovered
655655
assert len(discovered_tests) == 0
@@ -701,7 +701,7 @@ def test_add_with_parameters(self):
701701
)
702702

703703
# Discover tests
704-
discovered_tests = discover_unit_tests(test_config)
704+
discovered_tests, _ = discover_unit_tests(test_config)
705705

706706
# Verify the unittest was discovered
707707
assert len(discovered_tests) == 1
@@ -783,7 +783,7 @@ def test_add_mixed(self, name, a, b, expected):
783783
)
784784

785785
# Discover tests
786-
discovered_tests = discover_unit_tests(test_config)
786+
discovered_tests, _ = discover_unit_tests(test_config)
787787

788788
# Verify the basic structure
789789
assert len(discovered_tests) == 2 # Should have tests for both add and multiply

0 commit comments

Comments
 (0)