Skip to content

Commit 9f246ec

Browse files
committed
dict directly by using a set in discover_unit_tests
1 parent 3de7373 commit 9f246ec

File tree

6 files changed

+44
-41
lines changed

6 files changed

+44
-41
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def discover_unit_tests(
305305
cfg: TestConfig,
306306
discover_only_these_tests: list[Path] | None = None,
307307
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None = None,
308-
) -> tuple[dict[str, list[FunctionCalledInTest]], int]:
308+
) -> tuple[dict[str, set[FunctionCalledInTest]], int]:
309309
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
310310
strategy = framework_strategies.get(cfg.test_framework, None)
311311
if not strategy:
@@ -326,7 +326,7 @@ def discover_tests_pytest(
326326
cfg: TestConfig,
327327
discover_only_these_tests: list[Path] | None = None,
328328
functions_to_optimize: list[FunctionToOptimize] | None = None,
329-
) -> dict[str, list[FunctionCalledInTest]]:
329+
) -> dict[str, set[FunctionCalledInTest]]:
330330
tests_root = cfg.tests_root
331331
project_root = cfg.project_root_path
332332

@@ -402,7 +402,7 @@ def discover_tests_unittest(
402402
cfg: TestConfig,
403403
discover_only_these_tests: list[str] | None = None,
404404
functions_to_optimize: list[FunctionToOptimize] | None = None,
405-
) -> dict[str, list[FunctionCalledInTest]]:
405+
) -> dict[str, set[FunctionCalledInTest]]:
406406
tests_root: Path = cfg.tests_root
407407
loader: unittest.TestLoader = unittest.TestLoader()
408408
tests: unittest.TestSuite = loader.discover(str(tests_root))
@@ -469,7 +469,7 @@ def process_test_files(
469469
file_to_test_map: dict[Path, list[TestsInFile]],
470470
cfg: TestConfig,
471471
functions_to_optimize: list[FunctionToOptimize] | None = None,
472-
) -> dict[str, list[FunctionCalledInTest]]:
472+
) -> dict[str, set[FunctionCalledInTest]]:
473473
import jedi
474474

475475
project_root_path = cfg.project_root_path
@@ -637,4 +637,4 @@ def process_test_files(
637637

638638
progress.advance(task_id)
639639

640-
return {function: list(tests) for function, tests in function_to_test_map.items()}
640+
return dict(function_to_test_map)

codeflash/optimization/function_optimizer.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
from codeflash.models.models import (
5858
BestOptimization,
5959
CodeOptimizationContext,
60-
FunctionCalledInTest,
6160
GeneratedTests,
6261
GeneratedTestsList,
6362
OptimizationSet,
@@ -87,7 +86,13 @@
8786

8887
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
8988
from codeflash.either import Result
90-
from codeflash.models.models import BenchmarkKey, CoverageData, FunctionSource, OptimizedCandidate
89+
from codeflash.models.models import (
90+
BenchmarkKey,
91+
CoverageData,
92+
FunctionCalledInTest,
93+
FunctionSource,
94+
OptimizedCandidate,
95+
)
9196
from codeflash.verification.verification_utils import TestConfig
9297

9398

@@ -97,7 +102,7 @@ def __init__(
97102
function_to_optimize: FunctionToOptimize,
98103
test_cfg: TestConfig,
99104
function_to_optimize_source_code: str = "",
100-
function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
105+
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
101106
function_to_optimize_ast: ast.FunctionDef | None = None,
102107
aiservice_client: AiServiceClient | None = None,
103108
function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
@@ -213,7 +218,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
213218

214219
function_to_optimize_qualified_name = self.function_to_optimize.qualified_name
215220
function_to_all_tests = {
216-
key: self.function_to_tests.get(key, []) + function_to_concolic_tests.get(key, [])
221+
key: self.function_to_tests.get(key, set()) | function_to_concolic_tests.get(key, set())
217222
for key in set(self.function_to_tests) | set(function_to_concolic_tests)
218223
}
219224
instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests)
@@ -690,7 +695,7 @@ def cleanup_leftover_test_return_values() -> None:
690695
get_run_tmp_file(Path("test_return_values_0.bin")).unlink(missing_ok=True)
691696
get_run_tmp_file(Path("test_return_values_0.sqlite")).unlink(missing_ok=True)
692697

693-
def instrument_existing_tests(self, function_to_all_tests: dict[str, list[FunctionCalledInTest]]) -> set[Path]:
698+
def instrument_existing_tests(self, function_to_all_tests: dict[str, set[FunctionCalledInTest]]) -> set[Path]:
694699
existing_test_files_count = 0
695700
replay_test_files_count = 0
696701
concolic_coverage_test_files_count = 0
@@ -701,7 +706,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi
701706
logger.info(f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.")
702707
console.rule()
703708
else:
704-
test_file_invocation_positions = defaultdict(list[FunctionCalledInTest])
709+
test_file_invocation_positions = defaultdict(list)
705710
for tests_in_file in function_to_all_tests.get(func_qualname):
706711
test_file_invocation_positions[
707712
(tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type)
@@ -787,7 +792,7 @@ def generate_tests_and_optimizations(
787792
generated_test_paths: list[Path],
788793
generated_perf_test_paths: list[Path],
789794
run_experiment: bool = False, # noqa: FBT001, FBT002
790-
) -> Result[tuple[GeneratedTestsList, dict[str, list[FunctionCalledInTest]], OptimizationSet], str]:
795+
) -> Result[tuple[GeneratedTestsList, dict[str, set[FunctionCalledInTest]], OptimizationSet], str]:
791796
assert len(generated_test_paths) == N_TESTS_TO_GENERATE
792797
max_workers = N_TESTS_TO_GENERATE + 2 if not run_experiment else N_TESTS_TO_GENERATE + 3
793798
console.rule()

codeflash/optimization/optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def create_function_optimizer(
4848
self,
4949
function_to_optimize: FunctionToOptimize,
5050
function_to_optimize_ast: ast.FunctionDef | None = None,
51-
function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
51+
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
5252
function_to_optimize_source_code: str | None = "",
5353
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
5454
total_benchmark_timings: dict[BenchmarkKey, float] | None = None,

codeflash/result/create_pr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
def existing_tests_source_for(
2727
function_qualified_name_with_modules_from_root: str,
28-
function_to_tests: dict[str, list[FunctionCalledInTest]],
28+
function_to_tests: dict[str, set[FunctionCalledInTest]],
2929
tests_root: Path,
3030
) -> str:
3131
test_files = function_to_tests.get(function_qualified_name_with_modules_from_root)

codeflash/verification/concolic_testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
def generate_concolic_tests(
2626
test_cfg: TestConfig, args: Namespace, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.AST
27-
) -> tuple[dict[str, list[FunctionCalledInTest]], str]:
27+
) -> tuple[dict[str, set[FunctionCalledInTest]], str]:
2828
start_time = time.perf_counter()
2929
function_to_concolic_tests = {}
3030
concolic_test_suite_code = ""

tests/test_unit_test_discovery.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,9 @@ def test_discover_tests_pytest_with_temp_dir_root():
129129
# Check if the dummy test file is discovered
130130
assert len(discovered_tests) == 1
131131
assert len(discovered_tests["dummy_code.dummy_function"]) == 2
132-
assert discovered_tests["dummy_code.dummy_function"][0].tests_in_file.test_file == test_file_path
133-
assert discovered_tests["dummy_code.dummy_function"][1].tests_in_file.test_file == test_file_path
134-
assert {
135-
discovered_tests["dummy_code.dummy_function"][0].tests_in_file.test_function,
136-
discovered_tests["dummy_code.dummy_function"][1].tests_in_file.test_function,
137-
} == {"test_dummy_parametrized_function[True]", "test_dummy_function"}
132+
dummy_tests = discovered_tests["dummy_code.dummy_function"]
133+
assert all(test.tests_in_file.test_file == test_file_path for test in dummy_tests)
134+
assert {test.tests_in_file.test_function for test in dummy_tests} == {"test_dummy_parametrized_function[True]", "test_dummy_function"}
138135

139136

140137
def test_discover_tests_pytest_with_multi_level_dirs():
@@ -201,13 +198,13 @@ def test_discover_tests_pytest_with_multi_level_dirs():
201198

202199
# Check if the test files at all levels are discovered
203200
assert len(discovered_tests) == 3
204-
assert discovered_tests["root_code.root_function"][0].tests_in_file.test_file == root_test_file_path
201+
assert next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file == root_test_file_path
205202
assert (
206-
discovered_tests["level1.level1_code.level1_function"][0].tests_in_file.test_file == level1_test_file_path
203+
next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file == level1_test_file_path
207204
)
208205

209206
assert (
210-
discovered_tests["level1.level2.level2_code.level2_function"][0].tests_in_file.test_file
207+
next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file
211208
== level2_test_file_path
212209
)
213210

@@ -291,17 +288,17 @@ def test_discover_tests_pytest_dirs():
291288

292289
# Check if the test files at all levels are discovered
293290
assert len(discovered_tests) == 4
294-
assert discovered_tests["root_code.root_function"][0].tests_in_file.test_file == root_test_file_path
291+
assert next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file == root_test_file_path
295292
assert (
296-
discovered_tests["level1.level1_code.level1_function"][0].tests_in_file.test_file == level1_test_file_path
293+
next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file == level1_test_file_path
297294
)
298295
assert (
299-
discovered_tests["level1.level2.level2_code.level2_function"][0].tests_in_file.test_file
296+
next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file
300297
== level2_test_file_path
301298
)
302299

303300
assert (
304-
discovered_tests["level1.level3.level3_code.level3_function"][0].tests_in_file.test_file
301+
next(iter(discovered_tests["level1.level3.level3_code.level3_function"])).tests_in_file.test_file
305302
== level3_test_file_path
306303
)
307304

@@ -337,7 +334,7 @@ def test_discover_tests_pytest_with_class():
337334

338335
# Check if the test class and method are discovered
339336
assert len(discovered_tests) == 1
340-
assert discovered_tests["some_class_code.SomeClass.some_method"][0].tests_in_file.test_file == test_file_path
337+
assert next(iter(discovered_tests["some_class_code.SomeClass.some_method"])).tests_in_file.test_file == test_file_path
341338

342339

343340
def test_discover_tests_pytest_with_double_nested_directories():
@@ -376,9 +373,7 @@ def test_discover_tests_pytest_with_double_nested_directories():
376373
# Check if the test class and method are discovered
377374
assert len(discovered_tests) == 1
378375
assert (
379-
discovered_tests["nested.more_nested.nested_class_code.NestedClass.nested_method"][
380-
0
381-
].tests_in_file.test_file
376+
next(iter(discovered_tests["nested.more_nested.nested_class_code.NestedClass.nested_method"])).tests_in_file.test_file
382377
== test_file_path
383378
)
384379

@@ -425,7 +420,7 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir():
425420

426421
# Check if the test file is discovered and associated with the code file
427422
assert len(discovered_tests) == 1
428-
assert discovered_tests["code.some_code.some_function"][0].tests_in_file.test_file == test_file_path
423+
assert next(iter(discovered_tests["code.some_code.some_function"])).tests_in_file.test_file == test_file_path
429424

430425

431426
def test_discover_tests_pytest_with_nested_class():
@@ -465,7 +460,7 @@ def test_discover_tests_pytest_with_nested_class():
465460
# Check if the test for the nested class method is discovered
466461
assert len(discovered_tests) == 1
467462
assert (
468-
discovered_tests["nested_class_code.OuterClass.InnerClass.inner_method"][0].tests_in_file.test_file
463+
next(iter(discovered_tests["nested_class_code.OuterClass.InnerClass.inner_method"])).tests_in_file.test_file
469464
== test_file_path
470465
)
471466

@@ -504,7 +499,7 @@ def test_discover_tests_pytest_separate_moduledir():
504499

505500
# Check if the test for the nested class method is discovered
506501
assert len(discovered_tests) == 1
507-
assert discovered_tests["mypackage.code.find_common_tags"][0].tests_in_file.test_file == test_file_path
502+
assert next(iter(discovered_tests["mypackage.code.find_common_tags"])).tests_in_file.test_file == test_file_path
508503

509504

510505
def test_unittest_discovery_with_pytest():
@@ -548,8 +543,9 @@ def test_add(self):
548543
assert len(discovered_tests) == 1
549544
assert "calculator.Calculator.add" in discovered_tests
550545
assert len(discovered_tests["calculator.Calculator.add"]) == 1
551-
assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_file == test_file_path
552-
assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_function == "test_add"
546+
calculator_test = next(iter(discovered_tests["calculator.Calculator.add"]))
547+
assert calculator_test.tests_in_file.test_file == test_file_path
548+
assert calculator_test.tests_in_file.test_function == "test_add"
553549

554550

555551
def test_unittest_discovery_with_pytest_parent_class():
@@ -615,8 +611,9 @@ def test_add(self):
615611
assert len(discovered_tests) == 2
616612
assert "calculator.Calculator.add" in discovered_tests
617613
assert len(discovered_tests["calculator.Calculator.add"]) == 1
618-
assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_file == test_file_path
619-
assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_function == "test_add"
614+
calculator_test = next(iter(discovered_tests["calculator.Calculator.add"]))
615+
assert calculator_test.tests_in_file.test_file == test_file_path
616+
assert calculator_test.tests_in_file.test_function == "test_add"
620617

621618

622619
def test_unittest_discovery_with_pytest_private():
@@ -712,9 +709,10 @@ def test_add_with_parameters(self):
712709
assert len(discovered_tests) == 1
713710
assert "calculator.Calculator.add" in discovered_tests
714711
assert len(discovered_tests["calculator.Calculator.add"]) == 1
715-
assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_file == test_file_path
712+
calculator_test = next(iter(discovered_tests["calculator.Calculator.add"]))
713+
assert calculator_test.tests_in_file.test_file == test_file_path
716714
assert (
717-
discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_function == "test_add_with_parameters"
715+
calculator_test.tests_in_file.test_function == "test_add_with_parameters"
718716
)
719717

720718

0 commit comments

Comments
 (0)