diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py
index 032116da7..e691c2ed2 100644
--- a/codeflash/code_utils/instrument_existing_tests.py
+++ b/codeflash/code_utils/instrument_existing_tests.py
@@ -9,8 +9,7 @@
 from codeflash.cli_cmds.console import logger
 from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
-from codeflash.models.models import FunctionParent, TestingMode
-from codeflash.verification.test_results import VerificationType
+from codeflash.models.models import FunctionParent, TestingMode, VerificationType
 
 if TYPE_CHECKING:
     from collections.abc import Iterable
diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py
index 02ae2e4c1..3b05c8d49 100644
--- a/codeflash/discovery/discover_unit_tests.py
+++ b/codeflash/discovery/discover_unit_tests.py
@@ -16,8 +16,7 @@
 from codeflash.cli_cmds.console import console, logger
 from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
 from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
-from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile
-from codeflash.verification.test_results import TestType
+from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
 
 if TYPE_CHECKING:
     from codeflash.verification.verification_utils import TestConfig
diff --git a/codeflash/github/PrComment.py b/codeflash/github/PrComment.py
index f266a039d..d7c12d962 100644
--- a/codeflash/github/PrComment.py
+++ b/codeflash/github/PrComment.py
@@ -4,7 +4,7 @@
 from pydantic.dataclasses import dataclass
 
 from codeflash.code_utils.time_utils import humanize_runtime
-from codeflash.verification.test_results import TestResults
+from codeflash.models.models import TestResults
 
 
 @dataclass(frozen=True, config={"arbitrary_types_allowed": True})
diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index bd7fd3e05..3d338abc8 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -1,29 +1,30 @@
 from __future__ import annotations
+from typing import TYPE_CHECKING
+
+from rich.tree import Tree
+
+from codeflash.cli_cmds.console import DEBUG_MODE
+
+if TYPE_CHECKING:
+    from collections.abc import Iterator
 
 import enum
-import json
 import re
+import sys
 from collections.abc import Collection, Iterator
 from enum import Enum, IntEnum
 from pathlib import Path
 from re import Pattern
-from typing import Annotated, Any, Optional, Union
+from typing import Annotated, Any, Optional, Union, cast
 
-import sentry_sdk
-from coverage.exceptions import NoDataError
 from jedi.api.classes import Name
 from pydantic import AfterValidator, BaseModel, ConfigDict, Field
 from pydantic.dataclasses import dataclass
 
 from codeflash.cli_cmds.console import console, logger
 from codeflash.code_utils.code_utils import validate_python_code
-from codeflash.code_utils.coverage_utils import (
-    build_fully_qualified_name,
-    extract_dependent_function,
-    generate_candidates,
-)
 from codeflash.code_utils.env_utils import is_end_to_end
-from codeflash.verification.test_results import TestResults, TestType
+from codeflash.verification.comparator import comparator
 
 # If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
 # qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its 
name is spam. The full name @@ -241,209 +242,6 @@ class CoverageData: blank_re: Pattern[str] = re.compile(r"\s*(#|$)") else_re: Pattern[str] = re.compile(r"\s*else\s*:\s*(#|$)") - @staticmethod - def load_from_sqlite_database( - database_path: Path, config_path: Path, function_name: str, code_context: CodeOptimizationContext, source_code_path: Path - ) -> CoverageData: - """Load coverage data from an SQLite database, mimicking the behavior of load_from_coverage_file.""" - from coverage import Coverage - from coverage.jsonreport import JsonReporter - - cov = Coverage(data_file=database_path,config_file=config_path, data_suffix=True, auto_data=True, branch=True) - - if not database_path.stat().st_size or not database_path.exists(): - logger.debug(f"Coverage database {database_path} is empty or does not exist") - sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist") - return CoverageData.create_empty(source_code_path, function_name, code_context) - cov.load() - - reporter = JsonReporter(cov) - temp_json_file = database_path.with_suffix(".report.json") - with temp_json_file.open("w") as f: - try: - reporter.report(morfs=[source_code_path.as_posix()], outfile=f) - except NoDataError: - sentry_sdk.capture_message(f"No coverage data found for {function_name} in {source_code_path}") - return CoverageData.create_empty(source_code_path, function_name, code_context) - with temp_json_file.open() as f: - original_coverage_data = json.load(f) - - coverage_data, status = CoverageData._parse_coverage_file(temp_json_file, source_code_path) - - main_func_coverage, dependent_func_coverage = CoverageData._fetch_function_coverages( - function_name, code_context, coverage_data, original_cov_data=original_coverage_data - ) - - total_executed_lines, total_unexecuted_lines = CoverageData._aggregate_coverage( - main_func_coverage, dependent_func_coverage - ) - - total_lines = total_executed_lines | total_unexecuted_lines - coverage = len(total_executed_lines) / len(total_lines) * 100 if total_lines else 0.0 - # coverage = (lines covered of the original function + its 1 level deep helpers) / (lines spanned by original function + its 1 level deep helpers), if no helpers then just the original function coverage - - functions_being_tested = [main_func_coverage.name] - if dependent_func_coverage: - functions_being_tested.append(dependent_func_coverage.name) - - graph = CoverageData._build_graph(main_func_coverage, dependent_func_coverage) - temp_json_file.unlink() - - return CoverageData( - file_path=source_code_path, - coverage=coverage, - function_name=function_name, - functions_being_tested=functions_being_tested, - graph=graph, - code_context=code_context, - main_func_coverage=main_func_coverage, - dependent_func_coverage=dependent_func_coverage, - status=status, - ) - - @staticmethod - def _parse_coverage_file( - coverage_file_path: Path, source_code_path: Path - ) -> tuple[dict[str, dict[str, Any]], CoverageStatus]: - with coverage_file_path.open() as f: - coverage_data = json.load(f) - - candidates = generate_candidates(source_code_path) - - logger.debug(f"Looking for coverage data in {' -> '.join(candidates)}") - for candidate in candidates: - try: - cov: dict[str, dict[str, Any]] = coverage_data["files"][candidate]["functions"] - logger.debug(f"Coverage data found for {source_code_path} in {candidate}") - status = CoverageStatus.PARSED_SUCCESSFULLY - break - except KeyError: - continue - else: - logger.debug(f"No coverage data found for {source_code_path} in {candidates}") 
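# --- Illustrative sketch (not part of the patch): the coverage aggregation that
# --- load_from_sqlite_database performs, per the comment above, i.e.
# --- coverage = (lines covered by the function + its 1-level-deep helper) /
# ---            (lines spanned by the function + its helper).
# --- The line numbers below are made-up example data.
def aggregate_coverage_percent(main_executed, main_missing, helper_executed=(), helper_missing=()):
    # union of executed lines across the main function and its helper
    executed = set(main_executed) | set(helper_executed)
    # every line spanned by either function, executed or not
    total = executed | set(main_missing) | set(helper_missing)
    return len(executed) / len(total) * 100 if total else 0.0

# 3 of the 4 spanned lines executed -> 75.0
assert aggregate_coverage_percent([10, 11, 12], [13]) == 75.0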
- cov = {} - status = CoverageStatus.NOT_FOUND - return cov, status - - @staticmethod - def _fetch_function_coverages( - function_name: str, - code_context: CodeOptimizationContext, - coverage_data: dict[str, dict[str, Any]], - original_cov_data: dict[str, dict[str, Any]], - ) -> tuple[FunctionCoverage, Union[FunctionCoverage, None]]: - resolved_name = build_fully_qualified_name(function_name, code_context) - try: - main_function_coverage = FunctionCoverage( - name=resolved_name, - coverage=coverage_data[resolved_name]["summary"]["percent_covered"], - executed_lines=coverage_data[resolved_name]["executed_lines"], - unexecuted_lines=coverage_data[resolved_name]["missing_lines"], - executed_branches=coverage_data[resolved_name]["executed_branches"], - unexecuted_branches=coverage_data[resolved_name]["missing_branches"], - ) - except KeyError: - main_function_coverage = FunctionCoverage( - name=resolved_name, - coverage=0, - executed_lines=[], - unexecuted_lines=[], - executed_branches=[], - unexecuted_branches=[], - ) - - dependent_function = extract_dependent_function(function_name, code_context) - dependent_func_coverage = ( - CoverageData.grab_dependent_function_from_coverage_data( - dependent_function, coverage_data, original_cov_data - ) - if dependent_function - else None - ) - - return main_function_coverage, dependent_func_coverage - - @staticmethod - def _aggregate_coverage( - main_func_coverage: FunctionCoverage, dependent_func_coverage: Union[FunctionCoverage, None] - ) -> tuple[set[int], set[int]]: - total_executed_lines = set(main_func_coverage.executed_lines) - total_unexecuted_lines = set(main_func_coverage.unexecuted_lines) - - if dependent_func_coverage: - total_executed_lines.update(dependent_func_coverage.executed_lines) - total_unexecuted_lines.update(dependent_func_coverage.unexecuted_lines) - - return total_executed_lines, total_unexecuted_lines - - @staticmethod - def _build_graph( - main_func_coverage: FunctionCoverage, dependent_func_coverage: Union[FunctionCoverage, None] - ) -> dict[str, dict[str, Collection[object]]]: - graph = { - main_func_coverage.name: { - "executed_lines": set(main_func_coverage.executed_lines), - "unexecuted_lines": set(main_func_coverage.unexecuted_lines), - "executed_branches": main_func_coverage.executed_branches, - "unexecuted_branches": main_func_coverage.unexecuted_branches, - } - } - - if dependent_func_coverage: - graph[dependent_func_coverage.name] = { - "executed_lines": set(dependent_func_coverage.executed_lines), - "unexecuted_lines": set(dependent_func_coverage.unexecuted_lines), - "executed_branches": dependent_func_coverage.executed_branches, - "unexecuted_branches": dependent_func_coverage.unexecuted_branches, - } - - return graph - - @staticmethod - def grab_dependent_function_from_coverage_data( - dependent_function_name: str, - coverage_data: dict[str, dict[str, Any]], - original_cov_data: dict[str, dict[str, Any]], - ) -> FunctionCoverage: - """Grab the dependent function from the coverage data.""" - try: - return FunctionCoverage( - name=dependent_function_name, - coverage=coverage_data[dependent_function_name]["summary"]["percent_covered"], - executed_lines=coverage_data[dependent_function_name]["executed_lines"], - unexecuted_lines=coverage_data[dependent_function_name]["missing_lines"], - executed_branches=coverage_data[dependent_function_name]["executed_branches"], - unexecuted_branches=coverage_data[dependent_function_name]["missing_branches"], - ) - except KeyError: - msg = f"Coverage data not found for dependent 
function {dependent_function_name} in the coverage data" - try: - files = original_cov_data["files"] - for file in files: - functions = files[file]["functions"] - for function in functions: - if dependent_function_name in function: - return FunctionCoverage( - name=dependent_function_name, - coverage=functions[function]["summary"]["percent_covered"], - executed_lines=functions[function]["executed_lines"], - unexecuted_lines=functions[function]["missing_lines"], - executed_branches=functions[function]["executed_branches"], - unexecuted_branches=functions[function]["missing_branches"], - ) - msg = f"Coverage data not found for dependent function {dependent_function_name} in the original coverage data" - except KeyError: - raise ValueError(msg) from None - - return FunctionCoverage( - name=dependent_function_name, - coverage=0, - executed_lines=[], - unexecuted_lines=[], - executed_branches=[], - unexecuted_branches=[], - ) - def build_message(self) -> str: if self.status == CoverageStatus.NOT_FOUND: return f"No coverage data found for {self.function_name}" @@ -495,7 +293,6 @@ def create_empty(cls, file_path: Path, function_name: str, code_context: CodeOpt status=CoverageStatus.NOT_FOUND, ) - @dataclass class FunctionCoverage: """Represents the coverage data for a specific function in a source file.""" @@ -511,3 +308,236 @@ class FunctionCoverage: class TestingMode(enum.Enum): BEHAVIOR = "behavior" PERFORMANCE = "performance" + + +class VerificationType(str, Enum): + FUNCTION_CALL = ( + "function_call" # Correctness verification for a test function, checks input values and output values) + ) + INIT_STATE_FTO = "init_state_fto" # Correctness verification for fto class instance attributes after init + INIT_STATE_HELPER = "init_state_helper" # Correctness verification for helper class instance attributes after init + + +class TestType(Enum): + EXISTING_UNIT_TEST = 1 + INSPIRED_REGRESSION = 2 + GENERATED_REGRESSION = 3 + REPLAY_TEST = 4 + CONCOLIC_COVERAGE_TEST = 5 + INIT_STATE_TEST = 6 + + def to_name(self) -> str: + if self is TestType.INIT_STATE_TEST: + return "" + names = { + TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests", + TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests", + TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests", + TestType.REPLAY_TEST: "⏪ Replay Tests", + TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests", + } + return names[self] + + +@dataclass(frozen=True) +class InvocationId: + test_module_path: str # The fully qualified name of the test module + test_class_name: Optional[str] # The name of the class where the test is defined + test_function_name: Optional[str] # The name of the test_function. Does not include the components of the file_name + function_getting_tested: str + iteration_id: Optional[str] + + # test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id + def id(self) -> str: + class_prefix = f"{self.test_class_name}." 
if self.test_class_name else "" + return ( + f"{self.test_module_path}:{class_prefix}{self.test_function_name}:" + f"{self.function_getting_tested}:{self.iteration_id}" + ) + + @staticmethod + def from_str_id(string_id: str, iteration_id: Optional[str] = None) -> InvocationId: + components = string_id.split(":") + assert len(components) == 4 + second_components = components[1].split(".") + if len(second_components) == 1: + test_class_name = None + test_function_name = second_components[0] + else: + test_class_name = second_components[0] + test_function_name = second_components[1] + return InvocationId( + test_module_path=components[0], + test_class_name=test_class_name, + test_function_name=test_function_name, + function_getting_tested=components[2], + iteration_id=iteration_id if iteration_id else components[3], + ) + + +@dataclass(frozen=True) +class FunctionTestInvocation: + loop_index: int # The loop index of the function invocation, starts at 1 + id: InvocationId # The fully qualified name of the function invocation (id) + file_name: Path # The file where the test is defined + did_pass: bool # Whether the test this function invocation was part of, passed or failed + runtime: Optional[int] # Time in nanoseconds + test_framework: str # unittest or pytest + test_type: TestType + return_value: Optional[object] # The return value of the function invocation + timed_out: Optional[bool] + verification_type: Optional[str] = VerificationType.FUNCTION_CALL + stdout: Optional[str] = None + + @property + def unique_invocation_loop_id(self) -> str: + return f"{self.loop_index}:{self.id.id()}" + + +class TestResults(BaseModel): + # don't modify these directly, use the add method + # also we don't support deletion of test results elements - caution is advised + test_results: list[FunctionTestInvocation] = [] + test_result_idx: dict[str, int] = {} + + def add(self, function_test_invocation: FunctionTestInvocation) -> None: + unique_id = function_test_invocation.unique_invocation_loop_id + if unique_id in self.test_result_idx: + if DEBUG_MODE: + logger.warning(f"Test result with id {unique_id} already exists. SKIPPING") + return + self.test_result_idx[unique_id] = len(self.test_results) + self.test_results.append(function_test_invocation) + + def merge(self, other: TestResults) -> None: + original_len = len(self.test_results) + self.test_results.extend(other.test_results) + for k, v in other.test_result_idx.items(): + if k in self.test_result_idx: + msg = f"Test result with id {k} already exists." 
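# --- Illustrative sketch (not part of the patch): the round trip for the
# --- "test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id"
# --- format documented above, assuming the post-patch import location. The module,
# --- class and function names are hypothetical.
from codeflash.models.models import InvocationId

inv = InvocationId(
    test_module_path="tests.test_bubble_sort",
    test_class_name="TestSorter",
    test_function_name="test_sort",
    function_getting_tested="sorter",
    iteration_id="0",
)
assert inv.id() == "tests.test_bubble_sort:TestSorter.test_sort:sorter:0"
# from_str_id() splits the id back into its four components
assert InvocationId.from_str_id(inv.id()) == inv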
+ raise ValueError(msg) + self.test_result_idx[k] = v + original_len + + def get_by_unique_invocation_loop_id(self, unique_invocation_loop_id: str) -> FunctionTestInvocation | None: + try: + return self.test_results[self.test_result_idx[unique_invocation_loop_id]] + except (IndexError, KeyError): + return None + + def get_all_ids(self) -> set[InvocationId]: + return {test_result.id for test_result in self.test_results} + + def get_all_unique_invocation_loop_ids(self) -> set[str]: + return {test_result.unique_invocation_loop_id for test_result in self.test_results} + + def number_of_loops(self) -> int: + if not self.test_results: + return 0 + return max(test_result.loop_index for test_result in self.test_results) + + def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]: + report = {} + for test_type in TestType: + report[test_type] = {"passed": 0, "failed": 0} + for test_result in self.test_results: + if test_result.loop_index == 1: + if test_result.did_pass: + report[test_result.test_type]["passed"] += 1 + else: + report[test_result.test_type]["failed"] += 1 + return report + + @staticmethod + def report_to_string(report: dict[TestType, dict[str, int]]) -> str: + return " ".join( + [ + f"{test_type.to_name()}- (Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']})" + for test_type in TestType + ] + ) + + @staticmethod + def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree: + tree = Tree(title) + for test_type in TestType: + if test_type is TestType.INIT_STATE_TEST: + continue + tree.add( + f"{test_type.to_name()} - Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']}" + ) + return tree + + def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]: + for result in self.test_results: + if result.did_pass and not result.runtime: + msg = ( + f"Ignoring test case that passed but had no runtime -> {result.id}, " + f"Loop # {result.loop_index}, Test Type: {result.test_type}, " + f"Verification Type: {result.verification_type}" + ) + logger.debug(msg) + + usable_runtimes = [ + (result.id, result.runtime) for result in self.test_results if result.did_pass and result.runtime + ] + return { + usable_id: [runtime[1] for runtime in usable_runtimes if runtime[0] == usable_id] + for usable_id in {runtime[0] for runtime in usable_runtimes} + } + + def total_passed_runtime(self) -> int: + """Calculate the sum of runtimes of all test cases that passed. + + A testcase runtime is the minimum value of all looped execution runtimes. + + :return: The runtime in nanoseconds. 
+ """ + return sum( + [min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()] + ) + + def __iter__(self) -> Iterator[FunctionTestInvocation]: + return iter(self.test_results) + + def __len__(self) -> int: + return len(self.test_results) + + def __getitem__(self, index: int) -> FunctionTestInvocation: + return self.test_results[index] + + def __setitem__(self, index: int, value: FunctionTestInvocation) -> None: + self.test_results[index] = value + + def __contains__(self, value: FunctionTestInvocation) -> bool: + return value in self.test_results + + def __bool__(self) -> bool: + return bool(self.test_results) + + def __eq__(self, other: object) -> bool: + # Unordered comparison + if type(self) is not type(other): + return False + if len(self) != len(other): + return False + original_recursion_limit = sys.getrecursionlimit() + cast(TestResults, other) + for test_result in self: + other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id) + if other_test_result is None: + return False + + if original_recursion_limit < 5000: + sys.setrecursionlimit(5000) + if ( + test_result.file_name != other_test_result.file_name + or test_result.did_pass != other_test_result.did_pass + or test_result.runtime != other_test_result.runtime + or test_result.test_framework != other_test_result.test_framework + or test_result.test_type != other_test_result.test_type + or not comparator(test_result.return_value, other_test_result.return_value) + ): + sys.setrecursionlimit(original_recursion_limit) + return False + sys.setrecursionlimit(original_recursion_limit) + return True \ No newline at end of file diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index d5c2651b7..36d6c6f76 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -26,8 +26,8 @@ cleanup_paths, file_name_from_test_module_name, get_run_tmp_file, - module_name_from_file_path, has_any_async_functions, + module_name_from_file_path, ) from codeflash.code_utils.config_consts import ( INDIVIDUAL_TESTCASE_TIMEOUT, @@ -56,6 +56,8 @@ TestFile, TestFiles, TestingMode, + TestResults, + TestType, ) from codeflash.result.create_pr import check_create_pr, existing_tests_source_for from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic @@ -65,7 +67,6 @@ from codeflash.verification.equivalence import compare_test_results from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture from codeflash.verification.parse_test_output import parse_test_results -from codeflash.verification.test_results import TestResults, TestType from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests from codeflash.verification.verification_utils import get_test_file_path from codeflash.verification.verifier import generate_tests diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 7e6848010..9adf3723f 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -16,10 +16,9 @@ from codeflash.discovery.discover_unit_tests import discover_unit_tests from codeflash.discovery.functions_to_optimize import get_functions_to_optimize from codeflash.either import is_successful -from codeflash.models.models import ValidCode +from codeflash.models.models import TestType, ValidCode from 
codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.telemetry.posthog_cf import ph -from codeflash.verification.test_results import TestType from codeflash.verification.verification_utils import TestConfig if TYPE_CHECKING: diff --git a/codeflash/result/critic.py b/codeflash/result/critic.py index 5d106ab98..9f5d1ea65 100644 --- a/codeflash/result/critic.py +++ b/codeflash/result/critic.py @@ -1,10 +1,14 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from codeflash.cli_cmds.console import logger from codeflash.code_utils import env_utils from codeflash.code_utils.config_consts import COVERAGE_THRESHOLD, MIN_IMPROVEMENT_THRESHOLD -from codeflash.models.models import CoverageData, OptimizedCandidateResult -from codeflash.verification.test_results import TestType +from codeflash.models.models import TestType + +if TYPE_CHECKING: + from codeflash.models.models import CoverageData, OptimizedCandidateResult def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) -> float: diff --git a/codeflash/result/explanation.py b/codeflash/result/explanation.py index 1dd53ceb5..8a2f8f81d 100644 --- a/codeflash/result/explanation.py +++ b/codeflash/result/explanation.py @@ -3,7 +3,7 @@ from pydantic.dataclasses import dataclass from codeflash.code_utils.time_utils import humanize_runtime -from codeflash.verification.test_results import TestResults +from codeflash.models.models import TestResults @dataclass(frozen=True, config={"arbitrary_types_allowed": True}) diff --git a/codeflash/verification/codeflash_capture.py b/codeflash/verification/codeflash_capture.py index e92cacf07..0f2e285aa 100644 --- a/codeflash/verification/codeflash_capture.py +++ b/codeflash/verification/codeflash_capture.py @@ -10,7 +10,7 @@ import dill as pickle -from codeflash.verification.test_results import VerificationType +from codeflash.models.models import VerificationType def get_test_info_from_stack(tests_root: str) -> tuple[str, str | None, str, str]: diff --git a/codeflash/verification/coverage_utils.py b/codeflash/verification/coverage_utils.py new file mode 100644 index 000000000..2ef03cf5f --- /dev/null +++ b/codeflash/verification/coverage_utils.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any, Union + +import sentry_sdk +from coverage.exceptions import NoDataError + +from codeflash.cli_cmds.console import logger +from codeflash.code_utils.coverage_utils import ( + build_fully_qualified_name, + extract_dependent_function, + generate_candidates, +) +from codeflash.models.models import CoverageData, CoverageStatus, FunctionCoverage + +if TYPE_CHECKING: + from collections.abc import Collection + from pathlib import Path + + from codeflash.models.models import CodeOptimizationContext + + +class CoverageUtils: + """Coverage utils class for interfacing with Coverage.""" + + @staticmethod + def load_from_sqlite_database( + database_path: Path, config_path: Path, function_name: str, code_context: CodeOptimizationContext, source_code_path: Path + ) -> CoverageData: + """Load coverage data from an SQLite database, mimicking the behavior of load_from_coverage_file.""" + from coverage import Coverage + from coverage.jsonreport import JsonReporter + + cov = Coverage(data_file=database_path,config_file=config_path, data_suffix=True, auto_data=True, branch=True) + + if not database_path.stat().st_size or not database_path.exists(): + logger.debug(f"Coverage database {database_path} is empty or does 
not exist") + sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist") + return CoverageUtils.create_empty(source_code_path, function_name, code_context) + cov.load() + + reporter = JsonReporter(cov) + temp_json_file = database_path.with_suffix(".report.json") + with temp_json_file.open("w") as f: + try: + reporter.report(morfs=[source_code_path.as_posix()], outfile=f) + except NoDataError: + sentry_sdk.capture_message(f"No coverage data found for {function_name} in {source_code_path}") + return CoverageUtils.create_empty(source_code_path, function_name, code_context) + with temp_json_file.open() as f: + original_coverage_data = json.load(f) + + coverage_data, status = CoverageUtils._parse_coverage_file(temp_json_file, source_code_path) + + main_func_coverage, dependent_func_coverage = CoverageUtils._fetch_function_coverages( + function_name, code_context, coverage_data, original_cov_data=original_coverage_data + ) + + total_executed_lines, total_unexecuted_lines = CoverageUtils._aggregate_coverage( + main_func_coverage, dependent_func_coverage + ) + + total_lines = total_executed_lines | total_unexecuted_lines + coverage = len(total_executed_lines) / len(total_lines) * 100 if total_lines else 0.0 + # coverage = (lines covered of the original function + its 1 level deep helpers) / (lines spanned by original function + its 1 level deep helpers), if no helpers then just the original function coverage + + functions_being_tested = [main_func_coverage.name] + if dependent_func_coverage: + functions_being_tested.append(dependent_func_coverage.name) + + graph = CoverageUtils._build_graph(main_func_coverage, dependent_func_coverage) + temp_json_file.unlink() + + return CoverageData( + file_path=source_code_path, + coverage=coverage, + function_name=function_name, + functions_being_tested=functions_being_tested, + graph=graph, + code_context=code_context, + main_func_coverage=main_func_coverage, + dependent_func_coverage=dependent_func_coverage, + status=status, + ) + + @staticmethod + def _parse_coverage_file( + coverage_file_path: Path, source_code_path: Path + ) -> tuple[dict[str, dict[str, Any]], CoverageStatus]: + with coverage_file_path.open() as f: + coverage_data = json.load(f) + + candidates = generate_candidates(source_code_path) + + logger.debug(f"Looking for coverage data in {' -> '.join(candidates)}") + for candidate in candidates: + try: + cov: dict[str, dict[str, Any]] = coverage_data["files"][candidate]["functions"] + logger.debug(f"Coverage data found for {source_code_path} in {candidate}") + status = CoverageStatus.PARSED_SUCCESSFULLY + break + except KeyError: + continue + else: + logger.debug(f"No coverage data found for {source_code_path} in {candidates}") + cov = {} + status = CoverageStatus.NOT_FOUND + return cov, status + + @staticmethod + def _fetch_function_coverages( + function_name: str, + code_context: CodeOptimizationContext, + coverage_data: dict[str, dict[str, Any]], + original_cov_data: dict[str, dict[str, Any]], + ) -> tuple[FunctionCoverage, Union[FunctionCoverage, None]]: + resolved_name = build_fully_qualified_name(function_name, code_context) + try: + main_function_coverage = FunctionCoverage( + name=resolved_name, + coverage=coverage_data[resolved_name]["summary"]["percent_covered"], + executed_lines=coverage_data[resolved_name]["executed_lines"], + unexecuted_lines=coverage_data[resolved_name]["missing_lines"], + executed_branches=coverage_data[resolved_name]["executed_branches"], + 
unexecuted_branches=coverage_data[resolved_name]["missing_branches"], + ) + except KeyError: + main_function_coverage = FunctionCoverage( + name=resolved_name, + coverage=0, + executed_lines=[], + unexecuted_lines=[], + executed_branches=[], + unexecuted_branches=[], + ) + + dependent_function = extract_dependent_function(function_name, code_context) + dependent_func_coverage = ( + CoverageUtils.grab_dependent_function_from_coverage_data( + dependent_function, coverage_data, original_cov_data + ) + if dependent_function + else None + ) + + return main_function_coverage, dependent_func_coverage + + @staticmethod + def _aggregate_coverage( + main_func_coverage: FunctionCoverage, dependent_func_coverage: Union[FunctionCoverage, None] + ) -> tuple[set[int], set[int]]: + total_executed_lines = set(main_func_coverage.executed_lines) + total_unexecuted_lines = set(main_func_coverage.unexecuted_lines) + + if dependent_func_coverage: + total_executed_lines.update(dependent_func_coverage.executed_lines) + total_unexecuted_lines.update(dependent_func_coverage.unexecuted_lines) + + return total_executed_lines, total_unexecuted_lines + + @staticmethod + def _build_graph( + main_func_coverage: FunctionCoverage, dependent_func_coverage: Union[FunctionCoverage, None] + ) -> dict[str, dict[str, Collection[object]]]: + graph = { + main_func_coverage.name: { + "executed_lines": set(main_func_coverage.executed_lines), + "unexecuted_lines": set(main_func_coverage.unexecuted_lines), + "executed_branches": main_func_coverage.executed_branches, + "unexecuted_branches": main_func_coverage.unexecuted_branches, + } + } + + if dependent_func_coverage: + graph[dependent_func_coverage.name] = { + "executed_lines": set(dependent_func_coverage.executed_lines), + "unexecuted_lines": set(dependent_func_coverage.unexecuted_lines), + "executed_branches": dependent_func_coverage.executed_branches, + "unexecuted_branches": dependent_func_coverage.unexecuted_branches, + } + + return graph + + @staticmethod + def grab_dependent_function_from_coverage_data( + dependent_function_name: str, + coverage_data: dict[str, dict[str, Any]], + original_cov_data: dict[str, dict[str, Any]], + ) -> FunctionCoverage: + """Grab the dependent function from the coverage data.""" + try: + return FunctionCoverage( + name=dependent_function_name, + coverage=coverage_data[dependent_function_name]["summary"]["percent_covered"], + executed_lines=coverage_data[dependent_function_name]["executed_lines"], + unexecuted_lines=coverage_data[dependent_function_name]["missing_lines"], + executed_branches=coverage_data[dependent_function_name]["executed_branches"], + unexecuted_branches=coverage_data[dependent_function_name]["missing_branches"], + ) + except KeyError: + msg = f"Coverage data not found for dependent function {dependent_function_name} in the coverage data" + try: + files = original_cov_data["files"] + for file in files: + functions = files[file]["functions"] + for function in functions: + if dependent_function_name in function: + return FunctionCoverage( + name=dependent_function_name, + coverage=functions[function]["summary"]["percent_covered"], + executed_lines=functions[function]["executed_lines"], + unexecuted_lines=functions[function]["missing_lines"], + executed_branches=functions[function]["executed_branches"], + unexecuted_branches=functions[function]["missing_branches"], + ) + msg = f"Coverage data not found for dependent function {dependent_function_name} in the original coverage data" + except KeyError: + raise ValueError(msg) from 
None + + return FunctionCoverage( + name=dependent_function_name, + coverage=0, + executed_lines=[], + unexecuted_lines=[], + executed_branches=[], + unexecuted_branches=[], + ) + diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index c3f19df02..853b2d418 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -1,9 +1,8 @@ -import difflib import sys -from codeflash.cli_cmds.console import console, logger +from codeflash.cli_cmds.console import logger +from codeflash.models.models import TestResults, TestType, VerificationType from codeflash.verification.comparator import comparator -from codeflash.verification.test_results import TestResults, TestType, VerificationType INCREASED_RECURSION_LIMIT = 5000 diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index b97d383d6..a0afd7254 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -20,19 +20,13 @@ module_name_from_file_path, ) from codeflash.discovery.discover_unit_tests import discover_parameters_unittest -from codeflash.models.models import CoverageData, TestFiles -from codeflash.verification.test_results import ( - FunctionTestInvocation, - InvocationId, - TestResults, - TestType, - VerificationType, -) +from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType +from codeflash.verification.coverage_utils import CoverageUtils if TYPE_CHECKING: import subprocess - from codeflash.models.models import CodeOptimizationContext + from codeflash.models.models import CodeOptimizationContext, CoverageData, TestFiles from codeflash.verification.verification_utils import TestConfig @@ -522,7 +516,7 @@ def parse_test_results( all_args = False if coverage_database_file and source_file and code_context and function_name: all_args = True - coverage = CoverageData.load_from_sqlite_database( + coverage = CoverageUtils.load_from_sqlite_database( database_path=coverage_database_file, config_path=coverage_config_file, source_code_path=source_file, diff --git a/codeflash/verification/test_results.py b/codeflash/verification/test_results.py deleted file mode 100644 index 874c38072..000000000 --- a/codeflash/verification/test_results.py +++ /dev/null @@ -1,249 +0,0 @@ -from __future__ import annotations - -import sys -from enum import Enum -from pathlib import Path -from typing import TYPE_CHECKING, Optional, cast - -from pydantic import BaseModel -from pydantic.dataclasses import dataclass -from rich.tree import Tree - -from codeflash.cli_cmds.console import DEBUG_MODE, logger -from codeflash.verification.comparator import comparator - -if TYPE_CHECKING: - from collections.abc import Iterator - - -class VerificationType(str, Enum): - FUNCTION_CALL = ( - "function_call" # Correctness verification for a test function, checks input values and output values) - ) - INIT_STATE_FTO = "init_state_fto" # Correctness verification for fto class instance attributes after init - INIT_STATE_HELPER = "init_state_helper" # Correctness verification for helper class instance attributes after init - - -class TestType(Enum): - EXISTING_UNIT_TEST = 1 - INSPIRED_REGRESSION = 2 - GENERATED_REGRESSION = 3 - REPLAY_TEST = 4 - CONCOLIC_COVERAGE_TEST = 5 - INIT_STATE_TEST = 6 - - def to_name(self) -> str: - if self is TestType.INIT_STATE_TEST: - return "" - names = { - TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests", - 
TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests", - TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests", - TestType.REPLAY_TEST: "⏪ Replay Tests", - TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests", - } - return names[self] - - -@dataclass(frozen=True) -class InvocationId: - test_module_path: str # The fully qualified name of the test module - test_class_name: Optional[str] # The name of the class where the test is defined - test_function_name: Optional[str] # The name of the test_function. Does not include the components of the file_name - function_getting_tested: str - iteration_id: Optional[str] - - # test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id - def id(self) -> str: - class_prefix = f"{self.test_class_name}." if self.test_class_name else "" - return ( - f"{self.test_module_path}:{class_prefix}{self.test_function_name}:" - f"{self.function_getting_tested}:{self.iteration_id}" - ) - - @staticmethod - def from_str_id(string_id: str, iteration_id: Optional[str] = None) -> InvocationId: - components = string_id.split(":") - assert len(components) == 4 - second_components = components[1].split(".") - if len(second_components) == 1: - test_class_name = None - test_function_name = second_components[0] - else: - test_class_name = second_components[0] - test_function_name = second_components[1] - return InvocationId( - test_module_path=components[0], - test_class_name=test_class_name, - test_function_name=test_function_name, - function_getting_tested=components[2], - iteration_id=iteration_id if iteration_id else components[3], - ) - - -@dataclass(frozen=True) -class FunctionTestInvocation: - loop_index: int # The loop index of the function invocation, starts at 1 - id: InvocationId # The fully qualified name of the function invocation (id) - file_name: Path # The file where the test is defined - did_pass: bool # Whether the test this function invocation was part of, passed or failed - runtime: Optional[int] # Time in nanoseconds - test_framework: str # unittest or pytest - test_type: TestType - return_value: Optional[object] # The return value of the function invocation - timed_out: Optional[bool] - verification_type: Optional[str] = VerificationType.FUNCTION_CALL - stdout: Optional[str] = None - - @property - def unique_invocation_loop_id(self) -> str: - return f"{self.loop_index}:{self.id.id()}" - - -class TestResults(BaseModel): - # don't modify these directly, use the add method - # also we don't support deletion of test results elements - caution is advised - test_results: list[FunctionTestInvocation] = [] - test_result_idx: dict[str, int] = {} - - def add(self, function_test_invocation: FunctionTestInvocation) -> None: - unique_id = function_test_invocation.unique_invocation_loop_id - if unique_id in self.test_result_idx: - if DEBUG_MODE: - logger.warning(f"Test result with id {unique_id} already exists. SKIPPING") - return - self.test_result_idx[unique_id] = len(self.test_results) - self.test_results.append(function_test_invocation) - - def merge(self, other: TestResults) -> None: - original_len = len(self.test_results) - self.test_results.extend(other.test_results) - for k, v in other.test_result_idx.items(): - if k in self.test_result_idx: - msg = f"Test result with id {k} already exists." 
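# --- Illustrative sketch (not part of the patch): the runtime aggregation performed by
# --- usable_runtime_data_by_test_case() and total_passed_runtime() defined further down
# --- in this class - each passing test case contributes the minimum runtime (ns) across
# --- its loops, and those minima are summed. The ids and timings are made-up example data.
runtimes_by_case = {
    "tests.test_a:test_one:some_function:0": [900, 850, 870],   # ns per loop
    "tests.test_a:test_two:some_function:1": [1200, 1100],
}
total_passed_runtime = sum(min(loop_runtimes) for loop_runtimes in runtimes_by_case.values())
assert total_passed_runtime == 850 + 1100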
- raise ValueError(msg) - self.test_result_idx[k] = v + original_len - - def get_by_unique_invocation_loop_id(self, unique_invocation_loop_id: str) -> FunctionTestInvocation | None: - try: - return self.test_results[self.test_result_idx[unique_invocation_loop_id]] - except (IndexError, KeyError): - return None - - def get_all_ids(self) -> set[InvocationId]: - return {test_result.id for test_result in self.test_results} - - def get_all_unique_invocation_loop_ids(self) -> set[str]: - return {test_result.unique_invocation_loop_id for test_result in self.test_results} - - def number_of_loops(self) -> int: - if not self.test_results: - return 0 - return max(test_result.loop_index for test_result in self.test_results) - - def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]: - report = {} - for test_type in TestType: - report[test_type] = {"passed": 0, "failed": 0} - for test_result in self.test_results: - if test_result.loop_index == 1: - if test_result.did_pass: - report[test_result.test_type]["passed"] += 1 - else: - report[test_result.test_type]["failed"] += 1 - return report - - @staticmethod - def report_to_string(report: dict[TestType, dict[str, int]]) -> str: - return " ".join( - [ - f"{test_type.to_name()}- (Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']})" - for test_type in TestType - ] - ) - - @staticmethod - def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree: - tree = Tree(title) - for test_type in TestType: - if test_type is TestType.INIT_STATE_TEST: - continue - tree.add( - f"{test_type.to_name()} - Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']}" - ) - return tree - - def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]: - for result in self.test_results: - if result.did_pass and not result.runtime: - msg = ( - f"Ignoring test case that passed but had no runtime -> {result.id}, " - f"Loop # {result.loop_index}, Test Type: {result.test_type}, " - f"Verification Type: {result.verification_type}" - ) - logger.debug(msg) - - usable_runtimes = [ - (result.id, result.runtime) for result in self.test_results if result.did_pass and result.runtime - ] - return { - usable_id: [runtime[1] for runtime in usable_runtimes if runtime[0] == usable_id] - for usable_id in {runtime[0] for runtime in usable_runtimes} - } - - def total_passed_runtime(self) -> int: - """Calculate the sum of runtimes of all test cases that passed. - - A testcase runtime is the minimum value of all looped execution runtimes. - - :return: The runtime in nanoseconds. 
- """ - return sum( - [min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()] - ) - - def __iter__(self) -> Iterator[FunctionTestInvocation]: - return iter(self.test_results) - - def __len__(self) -> int: - return len(self.test_results) - - def __getitem__(self, index: int) -> FunctionTestInvocation: - return self.test_results[index] - - def __setitem__(self, index: int, value: FunctionTestInvocation) -> None: - self.test_results[index] = value - - def __contains__(self, value: FunctionTestInvocation) -> bool: - return value in self.test_results - - def __bool__(self) -> bool: - return bool(self.test_results) - - def __eq__(self, other: object) -> bool: - # Unordered comparison - if type(self) is not type(other): - return False - if len(self) != len(other): - return False - original_recursion_limit = sys.getrecursionlimit() - cast(TestResults, other) - for test_result in self: - other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id) - if other_test_result is None: - return False - - if original_recursion_limit < 5000: - sys.setrecursionlimit(5000) - if ( - test_result.file_name != other_test_result.file_name - or test_result.did_pass != other_test_result.did_pass - or test_result.runtime != other_test_result.runtime - or test_result.test_framework != other_test_result.test_framework - or test_result.test_type != other_test_result.test_type - or not comparator(test_result.return_value, other_test_result.return_value) - ): - sys.setrecursionlimit(original_recursion_limit) - return False - sys.setrecursionlimit(original_recursion_limit) - return True diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index 600c72042..949b6569e 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -10,8 +10,7 @@ from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME from codeflash.code_utils.coverage_utils import prepare_coverage_files -from codeflash.models.models import TestFiles -from codeflash.verification.test_results import TestType +from codeflash.models.models import TestFiles, TestType if TYPE_CHECKING: from codeflash.models.models import TestFiles diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index 8e5e237cf..469d1be6a 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -7,11 +7,10 @@ from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestingMode +from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestingMode, TestType, VerificationType from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.equivalence import compare_test_results from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture -from codeflash.verification.test_results import TestType, VerificationType from codeflash.verification.test_runner import execute_test_subprocess from codeflash.verification.verification_utils import TestConfig diff --git a/tests/test_comparator.py b/tests/test_comparator.py index de0d753a0..0f8ace054 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -12,9 +12,9 @@ 
import pytest from codeflash.either import Failure, Success +from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType from codeflash.verification.comparator import comparator from codeflash.verification.equivalence import compare_test_results -from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType def test_basic_python_objects() -> None: diff --git a/tests/test_critic.py b/tests/test_critic.py index 907d3f048..e60047125 100644 --- a/tests/test_critic.py +++ b/tests/test_critic.py @@ -8,10 +8,13 @@ CoverageData, CoverageStatus, FunctionCoverage, + FunctionTestInvocation, + InvocationId, OptimizedCandidateResult, + TestResults, + TestType, ) from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic -from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType def test_performance_gain() -> None: diff --git a/tests/test_instrument_all_and_run.py b/tests/test_instrument_all_and_run.py index ce06c855a..5bc942fdd 100644 --- a/tests/test_instrument_all_and_run.py +++ b/tests/test_instrument_all_and_run.py @@ -9,11 +9,10 @@ from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodePosition, FunctionParent, TestFile, TestFiles, TestingMode +from codeflash.models.models import CodePosition, FunctionParent, TestFile, TestFiles, TestingMode, TestType from codeflash.optimization.optimizer import Optimizer from codeflash.verification.equivalence import compare_test_results from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture -from codeflash.verification.test_results import TestType # Used by cli instrumentation codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs): diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index e64fea1cf..16be1966e 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -13,9 +13,16 @@ inject_profiling_into_existing_test, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodePosition, FunctionParent, TestFile, TestFiles, TestingMode, TestsInFile +from codeflash.models.models import ( + CodePosition, + FunctionParent, + TestFile, + TestFiles, + TestingMode, + TestsInFile, + TestType, +) from codeflash.optimization.function_optimizer import FunctionOptimizer -from codeflash.verification.test_results import TestType from codeflash.verification.verification_utils import TestConfig codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs): diff --git a/tests/test_instrumentation_run_results_aiservice.py b/tests/test_instrumentation_run_results_aiservice.py index 888c629ef..c1a759681 100644 --- a/tests/test_instrumentation_run_results_aiservice.py +++ b/tests/test_instrumentation_run_results_aiservice.py @@ -10,11 +10,10 @@ from code_to_optimize.bubble_sort_method import BubbleSorter from codeflash.code_utils.code_utils import get_run_tmp_file from 
codeflash.discovery.functions_to_optimize import FunctionToOptimize
-from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestingMode
+from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestingMode, TestType, VerificationType
 from codeflash.optimization.optimizer import Optimizer
 from codeflash.verification.equivalence import compare_test_results
 from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
-from codeflash.verification.test_results import TestType, VerificationType
 
 # Used by aiservice instrumentation
 behavior_logging_code = """from __future__ import annotations
diff --git a/tests/test_merge_test_results.py b/tests/test_merge_test_results.py
index 203d82dd5..f5eb4f3f9 100644
--- a/tests/test_merge_test_results.py
+++ b/tests/test_merge_test_results.py
@@ -1,5 +1,5 @@
+from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType
 from codeflash.verification.parse_test_output import merge_test_results
-from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType
 
 
 def test_merge_test_results_1():
diff --git a/tests/test_test_runner.py b/tests/test_test_runner.py
index fa574a826..2c67a644c 100644
--- a/tests/test_test_runner.py
+++ b/tests/test_test_runner.py
@@ -2,9 +2,8 @@
 import tempfile
 from pathlib import Path
 
-from codeflash.models.models import TestFile, TestFiles
+from codeflash.models.models import TestFile, TestFiles, TestType
 from codeflash.verification.parse_test_output import parse_test_xml
-from codeflash.verification.test_results import TestType
 from codeflash.verification.test_runner import run_behavioral_tests
 from codeflash.verification.verification_utils import TestConfig
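# --- Illustrative sketch (not part of the patch): the import moves this diff makes across
# --- the codebase, assuming the post-patch module layout. codeflash/verification/test_results.py
# --- is deleted, its classes now live in codeflash.models.models, and the coverage-loading
# --- helpers move to the new CoverageUtils class in codeflash/verification/coverage_utils.py.

# Before:
# from codeflash.verification.test_results import TestResults, TestType, VerificationType
# coverage = CoverageData.load_from_sqlite_database(...)

# After:
from codeflash.models.models import TestResults, TestType, VerificationType
from codeflash.verification.coverage_utils import CoverageUtils

# Call sites such as parse_test_results() now go through the relocated loader, e.g.:
# coverage = CoverageUtils.load_from_sqlite_database(
#     database_path=coverage_database_file,
#     config_path=coverage_config_file,
#     source_code_path=source_file,
#     function_name=function_name,
#     code_context=code_context,
# )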