diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index ddaccd16e..aabf377df 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -561,6 +561,7 @@ def total_passed_runtime(self) -> int:
 
         :return: The runtime in nanoseconds.
         """
+        #TODO this doesn't look at the intersection of tests of baseline and original
         return sum(
             [min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()]
         )
diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py
index 67b9de439..b7ce6978a 100644
--- a/codeflash/verification/equivalence.py
+++ b/codeflash/verification/equivalence.py
@@ -70,7 +70,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
            are_equal = False
            break
 
-        if original_test_result.test_type in {TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST} and (
+        if original_test_result.test_type in {TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST, TestType.GENERATED_REGRESSION, TestType.REPLAY_TEST} and (
            cdd_test_result.did_pass != original_test_result.did_pass
        ):
            are_equal = False
diff --git a/tests/test_comparator.py b/tests/test_comparator.py
index 0f8ace054..4a4d9f2b1 100644
--- a/tests/test_comparator.py
+++ b/tests/test_comparator.py
@@ -1009,6 +1009,94 @@ def test_compare_results_fn():
 
     assert not compare_test_results(original_results, new_results_4)
 
+    new_results_5_baseline = TestResults()
+    new_results_5_baseline.add(
+        FunctionTestInvocation(
+            id=InvocationId(
+                test_module_path="test_module_path",
+                test_class_name="test_class_name",
+                test_function_name="test_function_name",
+                function_getting_tested="function_getting_tested",
+                iteration_id="0",
+            ),
+            file_name=Path("file_name"),
+            did_pass=True,
+            runtime=5,
+            test_framework="unittest",
+            test_type=TestType.GENERATED_REGRESSION,
+            return_value=5,
+            timed_out=False,
+            loop_index=1,
+        )
+    )
+
+    new_results_5_opt = TestResults()
+    new_results_5_opt.add(
+        FunctionTestInvocation(
+            id=InvocationId(
+                test_module_path="test_module_path",
+                test_class_name="test_class_name",
+                test_function_name="test_function_name",
+                function_getting_tested="function_getting_tested",
+                iteration_id="0",
+            ),
+            file_name=Path("file_name"),
+            did_pass=False,
+            runtime=5,
+            test_framework="unittest",
+            test_type=TestType.GENERATED_REGRESSION,
+            return_value=5,
+            timed_out=False,
+            loop_index=1,
+        )
+    )
+
+    assert not compare_test_results(new_results_5_baseline, new_results_5_opt)
+
+    new_results_6_baseline = TestResults()
+    new_results_6_baseline.add(
+        FunctionTestInvocation(
+            id=InvocationId(
+                test_module_path="test_module_path",
+                test_class_name="test_class_name",
+                test_function_name="test_function_name",
+                function_getting_tested="function_getting_tested",
+                iteration_id="0",
+            ),
+            file_name=Path("file_name"),
+            did_pass=True,
+            runtime=5,
+            test_framework="unittest",
+            test_type=TestType.REPLAY_TEST,
+            return_value=5,
+            timed_out=False,
+            loop_index=1,
+        )
+    )
+
+    new_results_6_opt = TestResults()
+    new_results_6_opt.add(
+        FunctionTestInvocation(
+            id=InvocationId(
+                test_module_path="test_module_path",
+                test_class_name="test_class_name",
+                test_function_name="test_function_name",
+                function_getting_tested="function_getting_tested",
+                iteration_id="0",
+            ),
+            file_name=Path("file_name"),
+            did_pass=False,
+            runtime=5,
+            test_framework="unittest",
+            test_type=TestType.REPLAY_TEST,
+            return_value=5,
+            timed_out=False,
+            loop_index=1,
+        )
+    )
+
+    assert not compare_test_results(new_results_6_baseline, new_results_6_opt)
+
     assert not compare_test_results(TestResults(), TestResults())