From af8a7357ad15f2414ef3bf4edc239ec9c1e105c9 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 25 Mar 2025 19:21:39 -0500 Subject: [PATCH] check for GH App --- codeflash/api/cfapi.py | 4 +- codeflash/cli_cmds/cli.py | 24 ++-- codeflash/cli_cmds/cmd_init.py | 2 +- codeflash/cli_cmds/console.py | 3 +- codeflash/code_utils/code_replacer.py | 6 +- codeflash/code_utils/config_parser.py | 4 +- codeflash/context/code_context_extractor.py | 115 +++++++++++------- .../discovery/pytest_new_process_discovery.py | 3 +- codeflash/models/models.py | 26 ++-- codeflash/optimization/function_context.py | 7 +- codeflash/tracing/replay_test.py | 2 +- codeflash/verification/comparator.py | 2 +- codeflash/verification/coverage_utils.py | 9 +- codeflash/verification/parse_test_output.py | 1 - codeflash/verification/test_runner.py | 22 +++- codeflash/verification/verification_utils.py | 4 +- pyproject.toml | 2 +- 17 files changed, 152 insertions(+), 84 deletions(-) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index bc2519529..00d324db9 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Optional import requests +import sentry_sdk from pydantic.json import pydantic_encoder from codeflash.cli_cmds.console import console, logger @@ -194,7 +195,8 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]: req.raise_for_status() content: dict[str, list[str]] = req.json() except Exception as e: - logger.error(f"Error getting blocklisted functions: {e}", exc_info=True) + logger.error(f"Error getting blocklisted functions: {e}") + sentry_sdk.capture_exception(e) return {} return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()} diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 012fc86eb..6ac4db420 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -128,14 +128,22 @@ def process_pyproject_config(args: Namespace) -> Namespace: assert args.tests_root is not None, "--tests-root must be specified" assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory" - assert not (env_utils.get_pr_number() is not None and not env_utils.ensure_codeflash_api_key()), ( - "Codeflash API key not found. When running in a Github Actions Context, provide the " - "'CODEFLASH_API_KEY' environment variable as a secret.\n" - "You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n" - "Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n" - f"Here's a direct link: {get_github_secrets_page_url()}\n" - "Exiting..." - ) + if env_utils.get_pr_number() is not None: + assert env_utils.ensure_codeflash_api_key(), ( + "Codeflash API key not found. When running in a Github Actions Context, provide the " + "'CODEFLASH_API_KEY' environment variable as a secret.\n" + "You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n" + "Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n" + f"Here's a direct link: {get_github_secrets_page_url()}\n" + "Exiting..." + ) + + repo = git.Repo(search_parent_directories=True) + + owner, repo_name = get_repo_owner_and_name(repo) + + require_github_app_or_exit(owner, repo_name) + if hasattr(args, "ignore_paths") and args.ignore_paths is not None: normalized_ignore_paths = [] for path in args.ignore_paths: diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index feb274752..5df70828c 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -393,7 +393,7 @@ def check_for_toml_or_setup_file() -> str | None: return cast(str, project_name) -def install_github_actions(override_formatter_check: bool=False) -> None: +def install_github_actions(override_formatter_check: bool = False) -> None: try: config, config_file_path = parse_config_file(override_formatter_check=override_formatter_check) diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index 45959ded2..7cd09c843 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -27,7 +27,8 @@ ) logger = logging.getLogger("rich") -logging.getLogger('parso').setLevel(logging.WARNING) +logging.getLogger("parso").setLevel(logging.WARNING) + def paneled_text( text: str, panel_args: dict[str, str | bool] | None = None, text_args: dict[str, str] | None = None diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 86f9bfb02..ad37bfbd2 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -159,7 +159,7 @@ def replace_functions_in_file( source_code: str, original_function_names: list[str], optimized_code: str, - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]], + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]], ) -> str: parsed_function_names = [] for original_function_name in original_function_names: @@ -195,7 +195,7 @@ def replace_functions_and_add_imports( function_names: list[str], optimized_code: str, module_abspath: Path, - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]], + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]], project_root_path: Path, ) -> str: return add_needed_imports_from_module( @@ -211,7 +211,7 @@ def replace_function_definitions_in_module( function_names: list[str], optimized_code: str, module_abspath: Path, - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]], + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]], project_root_path: Path, ) -> bool: source_code: str = module_abspath.read_text(encoding="utf8") diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index f8af04d44..d814f12d0 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -31,7 +31,9 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path: raise ValueError(msg) -def parse_config_file(config_file_path: Path | None = None, override_formatter_check: bool=False) -> tuple[dict[str, Any], Path]: +def parse_config_file( + config_file_path: Path | None = None, override_formatter_check: bool = False +) -> tuple[dict[str, Any], Path]: config_file_path = find_pyproject_toml(config_file_path) try: with config_file_path.open("rb") as f: diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index e58b372d6..3827239ed 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -26,10 +26,15 @@ def get_code_optimization_context( - function_to_optimize: FunctionToOptimize, project_root_path: Path, optim_token_limit: int = 8000, testgen_token_limit: int = 8000 + function_to_optimize: FunctionToOptimize, + project_root_path: Path, + optim_token_limit: int = 8000, + testgen_token_limit: int = 8000, ) -> CodeOptimizationContext: # Get FunctionSource representation of helpers of FTO - helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi({function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path) + helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi( + {function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path + ) # Add function to optimize into helpers of FTO dict, as they'll be processed together fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path) @@ -37,20 +42,27 @@ def get_code_optimization_context( # Format data to search for helpers of helpers using get_function_sources_from_jedi helpers_of_fto_qualified_names_dict = { - file_path: {source.qualified_name for source in sources} - for file_path, sources in helpers_of_fto_dict.items() + file_path: {source.qualified_name for source in sources} for file_path, sources in helpers_of_fto_dict.items() } # __init__ functions are automatically considered as helpers of FTO, so we add them to the dict (regardless of whether they exist) # This helps us to search for helpers of __init__ functions of classes that contain helpers of FTO for qualified_names in helpers_of_fto_qualified_names_dict.values(): - qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if '.' in qn}) + qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if "." in qn}) # Get FunctionSource representation of helpers of helpers of FTO - helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(helpers_of_fto_qualified_names_dict, project_root_path) + helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi( + helpers_of_fto_qualified_names_dict, project_root_path + ) # Extract code context for optimization - final_read_writable_code = extract_code_string_context_from_files(helpers_of_fto_dict,{}, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.READ_WRITABLE).code + final_read_writable_code = extract_code_string_context_from_files( + helpers_of_fto_dict, + {}, + project_root_path, + remove_docstrings=False, + code_context_type=CodeContextType.READ_WRITABLE, + ).code read_only_code_markdown = extract_code_markdown_context_from_files( helpers_of_fto_dict, helpers_of_helpers_dict, @@ -80,10 +92,7 @@ def get_code_optimization_context( logger.debug("Code context has exceeded token limit, removing docstrings from read-only code") # Extract read only code without docstrings read_only_code_no_docstring_markdown = extract_code_markdown_context_from_files( - helpers_of_fto_dict, - helpers_of_helpers_dict, - project_root_path, - remove_docstrings=True, + helpers_of_fto_dict, helpers_of_helpers_dict, project_root_path, remove_docstrings=True ) read_only_context_code = read_only_code_no_docstring_markdown.markdown read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_context_code)) @@ -116,13 +125,14 @@ def get_code_optimization_context( raise ValueError("Testgen code context has exceeded token limit, cannot proceed") return CodeOptimizationContext( - testgen_context_code = testgen_context_code, + testgen_context_code=testgen_context_code, read_writable_code=final_read_writable_code, read_only_context_code=read_only_context_code, helper_functions=helpers_of_fto_list, preexisting_objects=preexisting_objects, ) + def extract_code_string_context_from_files( helpers_of_fto: dict[Path, set[FunctionSource]], helpers_of_helpers: dict[Path, set[FunctionSource]], @@ -169,9 +179,15 @@ def extract_code_string_context_from_files( continue try: qualified_function_names = {func.qualified_name for func in function_sources} - helpers_of_helpers_qualified_names = {func.qualified_name for func in helpers_of_helpers.get(file_path, set())} + helpers_of_helpers_qualified_names = { + func.qualified_name for func in helpers_of_helpers.get(file_path, set()) + } code_context = parse_code_and_prune_cst( - original_code, code_context_type, qualified_function_names, helpers_of_helpers_qualified_names, remove_docstrings + original_code, + code_context_type, + qualified_function_names, + helpers_of_helpers_qualified_names, + remove_docstrings, ) except ValueError as e: @@ -180,12 +196,12 @@ def extract_code_string_context_from_files( if code_context.strip(): final_code_string_context += f"\n{code_context}" final_code_string_context = add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=final_code_string_context, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions= list(helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())) + src_module_code=original_code, + dst_module_code=final_code_string_context, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=list(helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())), ) if code_context_type == CodeContextType.READ_WRITABLE: return CodeString(code=final_code_string_context) @@ -199,7 +215,7 @@ def extract_code_string_context_from_files( try: qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} code_context = parse_code_and_prune_cst( - original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings + original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings ) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") @@ -208,15 +224,16 @@ def extract_code_string_context_from_files( if code_context.strip(): final_code_string_context += f"\n{code_context}" final_code_string_context = add_needed_imports_from_module( - src_module_code=original_code, - dst_module_code=final_code_string_context, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())), + src_module_code=original_code, + dst_module_code=final_code_string_context, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())), ) return CodeString(code=final_code_string_context) + def extract_code_markdown_context_from_files( helpers_of_fto: dict[Path, set[FunctionSource]], helpers_of_helpers: dict[Path, set[FunctionSource]], @@ -263,9 +280,15 @@ def extract_code_markdown_context_from_files( continue try: qualified_function_names = {func.qualified_name for func in function_sources} - helpers_of_helpers_qualified_names = {func.qualified_name for func in helpers_of_helpers.get(file_path, set())} + helpers_of_helpers_qualified_names = { + func.qualified_name for func in helpers_of_helpers.get(file_path, set()) + } code_context = parse_code_and_prune_cst( - original_code, code_context_type, qualified_function_names, helpers_of_helpers_qualified_names, remove_docstrings + original_code, + code_context_type, + qualified_function_names, + helpers_of_helpers_qualified_names, + remove_docstrings, ) except ValueError as e: @@ -280,7 +303,8 @@ def extract_code_markdown_context_from_files( dst_path=file_path, project_root=project_root_path, helper_functions=list( - helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())) + helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set()) + ), ), file_path=file_path.relative_to(project_root_path), ) @@ -295,7 +319,7 @@ def extract_code_markdown_context_from_files( try: qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} code_context = parse_code_and_prune_cst( - original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings, + original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings ) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") @@ -317,8 +341,9 @@ def extract_code_markdown_context_from_files( return code_context_markdown -def get_function_to_optimize_as_function_source(function_to_optimize: FunctionToOptimize, - project_root_path: Path) -> FunctionSource: +def get_function_to_optimize_as_function_source( + function_to_optimize: FunctionToOptimize, project_root_path: Path +) -> FunctionSource: # Use jedi to find function to optimize script = jedi.Script(path=function_to_optimize.file_path, project=jedi.Project(path=project_root_path)) @@ -327,11 +352,12 @@ def get_function_to_optimize_as_function_source(function_to_optimize: FunctionTo # Find the name that matches our function for name in names: - if (name.type == "function" and - name.full_name and - name.name == function_to_optimize.function_name and - get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name): - + if ( + name.type == "function" + and name.full_name + and name.name == function_to_optimize.function_name + and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name + ): function_source = FunctionSource( file_path=function_to_optimize.file_path, qualified_name=function_to_optimize.qualified_name, @@ -343,7 +369,8 @@ def get_function_to_optimize_as_function_source(function_to_optimize: FunctionTo return function_source raise ValueError( - f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}") + f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}" + ) def get_function_sources_from_jedi( @@ -417,8 +444,13 @@ def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode return indented_block.with_changes(body=indented_block.body[1:]) return indented_block + def parse_code_and_prune_cst( - code: str, code_context_type: CodeContextType, target_functions: set[str], helpers_of_helper_functions: set[str] = set(), remove_docstrings: bool = False + code: str, + code_context_type: CodeContextType, + target_functions: set[str], + helpers_of_helper_functions: set[str] = set(), + remove_docstrings: bool = False, ) -> str: """Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables.""" module = cst.parse_module(code) @@ -441,6 +473,7 @@ def parse_code_and_prune_cst( return str(filtered_node.code) return "" + def prune_cst_for_read_writable_code( node: cst.CSTNode, target_functions: set[str], prefix: str = "" ) -> tuple[cst.CSTNode | None, bool]: @@ -520,6 +553,7 @@ def prune_cst_for_read_writable_code( return (node.with_changes(**updates) if updates else node), True + def prune_cst_for_read_only_code( node: cst.CSTNode, target_functions: set[str], @@ -624,7 +658,6 @@ def prune_cst_for_read_only_code( return None, False - def prune_cst_for_testgen_code( node: cst.CSTNode, target_functions: set[str], diff --git a/codeflash/discovery/pytest_new_process_discovery.py b/codeflash/discovery/pytest_new_process_discovery.py index 47128b38d..2d8583255 100644 --- a/codeflash/discovery/pytest_new_process_discovery.py +++ b/codeflash/discovery/pytest_new_process_discovery.py @@ -22,6 +22,7 @@ def pytest_collection_modifyitems(config, items): if "benchmark" in item.fixturenames: item.add_marker(skip_benchmark) + def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]: test_results = [] for test in pytest_tests: @@ -39,7 +40,7 @@ def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, s try: exitcode = pytest.main( - [tests_root, "-p no:logging", "--collect-only", "-m", "not skip",], plugins=[PytestCollectionPlugin()] + [tests_root, "-p no:logging", "--collect-only", "-m", "not skip"], plugins=[PytestCollectionPlugin()] ) except Exception as e: # noqa: BLE001 print(f"Failed to collect tests: {e!s}") # noqa: T201 diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 3d338abc8..1366fcc0b 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -15,7 +15,7 @@ from enum import Enum, IntEnum from pathlib import Path from re import Pattern -from typing import Annotated, Any, Optional, Union, cast +from typing import Annotated, Optional, cast from jedi.api.classes import Name from pydantic import AfterValidator, BaseModel, ConfigDict, Field @@ -58,15 +58,19 @@ class FunctionSource: def __eq__(self, other: object) -> bool: if not isinstance(other, FunctionSource): return False - return (self.file_path == other.file_path and - self.qualified_name == other.qualified_name and - self.fully_qualified_name == other.fully_qualified_name and - self.only_function_name == other.only_function_name and - self.source_code == other.source_code) + return ( + self.file_path == other.file_path + and self.qualified_name == other.qualified_name + and self.fully_qualified_name == other.fully_qualified_name + and self.only_function_name == other.only_function_name + and self.source_code == other.source_code + ) def __hash__(self) -> int: - return hash((self.file_path, self.qualified_name, self.fully_qualified_name, - self.only_function_name, self.source_code)) + return hash( + (self.file_path, self.qualified_name, self.fully_qualified_name, self.only_function_name, self.source_code) + ) + class BestOptimization(BaseModel): candidate: OptimizedCandidate @@ -100,7 +104,8 @@ class CodeOptimizationContext(BaseModel): read_writable_code: str = Field(min_length=1) read_only_context_code: str = "" helper_functions: list[FunctionSource] - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] + class CodeContextType(str, Enum): READ_WRITABLE = "READ_WRITABLE" @@ -293,6 +298,7 @@ def create_empty(cls, file_path: Path, function_name: str, code_context: CodeOpt status=CoverageStatus.NOT_FOUND, ) + @dataclass class FunctionCoverage: """Represents the coverage data for a specific function in a source file.""" @@ -540,4 +546,4 @@ def __eq__(self, other: object) -> bool: sys.setrecursionlimit(original_recursion_limit) return False sys.setrecursionlimit(original_recursion_limit) - return True \ No newline at end of file + return True diff --git a/codeflash/optimization/function_context.py b/codeflash/optimization/function_context.py index 4f1c892bc..3c28a92db 100644 --- a/codeflash/optimization/function_context.py +++ b/codeflash/optimization/function_context.py @@ -1,10 +1,9 @@ from __future__ import annotations from jedi.api.classes import Name -from codeflash.code_utils.code_utils import ( - get_qualified_name, -) +from codeflash.code_utils.code_utils import get_qualified_name + def belongs_to_method(name: Name, class_name: str, method_name: str) -> bool: """Check if the given name belongs to the specified method.""" @@ -40,4 +39,4 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b return get_qualified_name(name.module_name, name.full_name) == qualified_function_name return False except ValueError: - return False \ No newline at end of file + return False diff --git a/codeflash/tracing/replay_test.py b/codeflash/tracing/replay_test.py index 62d9dbbe6..eca1e50ef 100644 --- a/codeflash/tracing/replay_test.py +++ b/codeflash/tracing/replay_test.py @@ -3,7 +3,7 @@ import sqlite3 import textwrap from collections.abc import Generator -from typing import Any, List, Optional +from typing import Any, Optional from codeflash.discovery.functions_to_optimize import FunctionProperties, inspect_top_level_functions_or_methods from codeflash.tracing.tracing_utils import FunctionModules diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index ef4f8fadd..60372fcb4 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -50,6 +50,7 @@ except ImportError: HAS_TORCH = False + def comparator(orig: Any, new: Any, superset_obj=False) -> bool: """Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent.""" try: @@ -181,7 +182,6 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: return False return torch.allclose(orig, new, equal_nan=True) - if HAS_PYRSISTENT and isinstance( orig, ( diff --git a/codeflash/verification/coverage_utils.py b/codeflash/verification/coverage_utils.py index 2ef03cf5f..c9044a44d 100644 --- a/codeflash/verification/coverage_utils.py +++ b/codeflash/verification/coverage_utils.py @@ -26,13 +26,17 @@ class CoverageUtils: @staticmethod def load_from_sqlite_database( - database_path: Path, config_path: Path, function_name: str, code_context: CodeOptimizationContext, source_code_path: Path + database_path: Path, + config_path: Path, + function_name: str, + code_context: CodeOptimizationContext, + source_code_path: Path, ) -> CoverageData: """Load coverage data from an SQLite database, mimicking the behavior of load_from_coverage_file.""" from coverage import Coverage from coverage.jsonreport import JsonReporter - cov = Coverage(data_file=database_path,config_file=config_path, data_suffix=True, auto_data=True, branch=True) + cov = Coverage(data_file=database_path, config_file=config_path, data_suffix=True, auto_data=True, branch=True) if not database_path.stat().st_size or not database_path.exists(): logger.debug(f"Coverage database {database_path} is empty or does not exist") @@ -226,4 +230,3 @@ def grab_dependent_function_from_coverage_data( executed_branches=[], unexecuted_branches=[], ) - diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index a0afd7254..924e2876a 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -40,7 +40,6 @@ def parse_func(file_path: Path) -> XMLParser: cleaner_re = re.compile(r"!######.*?######!|-+\s*Captured\s+(Log|Out)\s*-+\n?") - def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults: test_results = TestResults() if not file_location.exists(): diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index 949b6569e..d4b3f15b4 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -18,6 +18,7 @@ BEHAVIORAL_BLOCKLISTED_PLUGINS = ["benchmark"] BENCHMARKING_BLOCKLISTED_PLUGINS = ["codspeed", "cov", "benchmark", "profiling"] + def execute_test_subprocess( cmd_list: list[str], cwd: Path, env: dict[str, str] | None, timeout: int = 600 ) -> subprocess.CompletedProcess: @@ -81,7 +82,14 @@ def run_behavioral_tests( ) # this cleanup is necessary to avoid coverage data from previous runs, if there are any, # then the current run will be appended to the previous data, which skews the results logger.debug(cov_erase) - coverage_cmd = [SAFE_SYS_EXECUTABLE, "-m", "coverage", "run", f"--rcfile={coverage_config_file.as_posix()}", "-m"] + coverage_cmd = [ + SAFE_SYS_EXECUTABLE, + "-m", + "coverage", + "run", + f"--rcfile={coverage_config_file.as_posix()}", + "-m", + ] if pytest_cmd == "pytest": coverage_cmd.extend(["pytest"]) @@ -90,7 +98,10 @@ def run_behavioral_tests( blocklist_args = [f"-p no:{plugin}" for plugin in BEHAVIORAL_BLOCKLISTED_PLUGINS if plugin != "cov"] results = execute_test_subprocess( - coverage_cmd + common_pytest_args + blocklist_args + result_args + test_files, cwd=cwd, env=pytest_test_env, timeout=600 + coverage_cmd + common_pytest_args + blocklist_args + result_args + test_files, + cwd=cwd, + env=pytest_test_env, + timeout=600, ) logger.debug( f"Result return code: {results.returncode}, " @@ -123,7 +134,12 @@ def run_behavioral_tests( msg = f"Unsupported test framework: {test_framework}" raise ValueError(msg) - return result_file_path, results, coverage_database_file if enable_coverage else None, coverage_config_file if enable_coverage else None + return ( + result_file_path, + results, + coverage_database_file if enable_coverage else None, + coverage_config_file if enable_coverage else None, + ) def run_benchmarking_tests( diff --git a/codeflash/verification/verification_utils.py b/codeflash/verification/verification_utils.py index c3a7e0718..79f1b9656 100644 --- a/codeflash/verification/verification_utils.py +++ b/codeflash/verification/verification_utils.py @@ -6,8 +6,6 @@ from pydantic.dataclasses import dataclass -from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE - def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit") -> Path: assert test_type in ["unit", "inspired", "replay", "perf"] @@ -76,4 +74,4 @@ class TestConfig: # tests_project_rootdir corresponds to pytest rootdir, # or for unittest - project_root_from_module_root(args.tests_root, pyproject_file_path) concolic_test_root_dir: Optional[Path] = None - pytest_cmd: str = "pytest" \ No newline at end of file + pytest_cmd: str = "pytest" diff --git a/pyproject.toml b/pyproject.toml index 2e71f2a0a..e8aa01d75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -151,7 +151,7 @@ warn_required_dynamic_aliases = true line-length = 120 fix = true show-fixes = true -exclude = ["code_to_optimize/", "pie_test_set/"] +exclude = ["code_to_optimize/", "pie_test_set/", "tests/"] [tool.ruff.lint] select = ["ALL"]