From de6b7aae1c7c1128056ffc985f0fb28e76ccb116 Mon Sep 17 00:00:00 2001 From: Saga4 Date: Wed, 19 Feb 2025 05:56:26 +0530 Subject: [PATCH 01/20] codeflash error handle when remote is empty --- codeflash/cli_cmds/cmd_init.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 2a9b977e2..48f967e88 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -220,19 +220,26 @@ def collect_setup_info() -> SetupInfo: carousel=True, ) + git_remote = "" try: repo = Repo(str(module_root), search_parent_directories=True) git_remotes = get_git_remotes(repo) - if len(git_remotes) > 1: - git_remote = inquirer_wrapper( - inquirer.list_input, - message="What git remote do you want Codeflash to use for new Pull Requests? ", - choices=git_remotes, - default="origin", - carousel=True, - ) + if git_remotes: # Only proceed if there are remotes + if len(git_remotes) > 1: + git_remote = inquirer_wrapper( + inquirer.list_input, + message="What git remote do you want Codeflash to use for new Pull Requests? ", + choices=git_remotes, + default="origin", + carousel=True, + ) + else: + git_remote = git_remotes[0] else: - git_remote = git_remotes[0] + click.echo( + "No git remotes found. You can still use Codeflash locally, but you'll need to set up a remote " + "repository to use GitHub features." + ) except InvalidGitRepositoryError: git_remote = "" From fa9911dafa3c7a402ba38a053c8481a8e723e852 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 18 Feb 2025 21:40:09 -0500 Subject: [PATCH 02/20] check formatter during init --- codeflash/cli_cmds/cmd_init.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 2a9b977e2..6c2623dd1 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -587,6 +587,16 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: ) elif formatter == "don't use a formatter": formatter_cmds.append("disabled") + if formatter in ["black", "ruff"]: + try: + result = subprocess.run([formatter], capture_output=True, check=False) + click.echo(f"✅ Formatter exists on system") + click.echo() + except FileNotFoundError as e: + click.echo(f"⚠️ Formatter not found: {formatter}") + click.echo() + # Not throwing an exception, letting the program proceed even though the formatter was not found, putting it on the user to install it later + # raise e from None codeflash_section["formatter-cmds"] = formatter_cmds # Add the 'codeflash' section, ensuring 'tool' section exists tool_section = pyproject_data.get("tool", tomlkit.table()) From fe71b4f4ed71a3effab8fb88b22b66f9cffdf67a Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 18 Feb 2025 19:20:56 -0800 Subject: [PATCH 03/20] Update cmd_init.py --- codeflash/cli_cmds/cmd_init.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 6c2623dd1..ad59f4b7a 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -590,13 +590,8 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: if formatter in ["black", "ruff"]: try: result = subprocess.run([formatter], capture_output=True, check=False) - click.echo(f"✅ Formatter exists on system") - click.echo() except FileNotFoundError as e: click.echo(f"⚠️ Formatter not found: {formatter}") - click.echo() - # Not throwing an exception, letting the program proceed even though the formatter was not found, putting it on the user to install it later - # raise e from None codeflash_section["formatter-cmds"] = formatter_cmds # Add the 'codeflash' section, ensuring 'tool' section exists tool_section = pyproject_data.get("tool", tomlkit.table()) From f044c50f3d53b509049e0edcf7503773957b637e Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 18 Feb 2025 22:31:50 -0500 Subject: [PATCH 04/20] suggest installing via pip --- codeflash/cli_cmds/cmd_init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index ad59f4b7a..ecf07f2c9 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -591,7 +591,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: try: result = subprocess.run([formatter], capture_output=True, check=False) except FileNotFoundError as e: - click.echo(f"⚠️ Formatter not found: {formatter}") + click.echo(f"⚠️ Formatter not found: {formatter}, please install via \'pip install {formatter}\'") codeflash_section["formatter-cmds"] = formatter_cmds # Add the 'codeflash' section, ensuring 'tool' section exists tool_section = pyproject_data.get("tool", tomlkit.table()) From 895c4477791e96b276dc592b648efb58751658f3 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 18 Feb 2025 19:51:29 -0800 Subject: [PATCH 05/20] Update cmd_init.py --- codeflash/cli_cmds/cmd_init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index ecf07f2c9..63546c0d1 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -591,7 +591,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: try: result = subprocess.run([formatter], capture_output=True, check=False) except FileNotFoundError as e: - click.echo(f"⚠️ Formatter not found: {formatter}, please install via \'pip install {formatter}\'") + click.echo(f"⚠️ Formatter not found: {formatter}, please ensure it is installed") codeflash_section["formatter-cmds"] = formatter_cmds # Add the 'codeflash' section, ensuring 'tool' section exists tool_section = pyproject_data.get("tool", tomlkit.table()) From 35103122caf8be14e70edf510542fa72eb014ac5 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 19 Feb 2025 18:19:00 -0500 Subject: [PATCH 06/20] add auditwall.py --- codeflash/verification/codeflash_auditwall.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 codeflash/verification/codeflash_auditwall.py diff --git a/codeflash/verification/codeflash_auditwall.py b/codeflash/verification/codeflash_auditwall.py new file mode 100644 index 000000000..73c27e188 --- /dev/null +++ b/codeflash/verification/codeflash_auditwall.py @@ -0,0 +1,26 @@ +import ast + + +class AuditWallTransformer(ast.NodeTransformer): + def visit_Module(self, node): + last_import_index = -1 + for i, body_node in enumerate(node.body): + if isinstance(body_node, (ast.Import, ast.ImportFrom)): + last_import_index = i + + new_import = ast.ImportFrom(module="crosshair.auditwall", names=[ast.alias(name="engage_auditwall")], level=0) + function_call = ast.Expr( + value=ast.Call(func=ast.Name(id="engage_auditwall", ctx=ast.Load()), args=[], keywords=[]) + ) + + node.body.insert(last_import_index + 1, new_import) + node.body.insert(last_import_index + 2, function_call) + + return node + + +def transform_code(source_code: str) -> str: + tree = ast.parse(source_code) + transformer = AuditWallTransformer() + new_tree = transformer.visit(tree) + return ast.unparse(new_tree) From 3f524c2be6bf2ce5aa49617874810d8ac316e008 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 19 Feb 2025 18:59:14 -0500 Subject: [PATCH 07/20] first pass --- codeflash/optimization/function_optimizer.py | 25 ++-- codeflash/verification/test_runner.py | 150 +++++++++++-------- 2 files changed, 103 insertions(+), 72 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 7b067a094..d8a63d6a3 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -3,8 +3,11 @@ import ast import concurrent.futures import os +import re +import shlex import shutil import subprocess +import tempfile import time import uuid from collections import defaultdict @@ -13,6 +16,7 @@ import isort import libcst as cst +from crosshair.auditwall import SideEffectDetected from rich.console import Group from rich.panel import Panel from rich.syntax import Syntax @@ -29,12 +33,14 @@ get_run_tmp_file, module_name_from_file_path, ) +from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE from codeflash.code_utils.config_consts import ( INDIVIDUAL_TESTCASE_TIMEOUT, N_CANDIDATES, N_TESTS_TO_GENERATE, TOTAL_LOOPING_TIME, ) +from codeflash.code_utils.coverage_utils import prepare_coverage_files from codeflash.code_utils.formatter import format_code, sort_imports from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests @@ -63,12 +69,13 @@ from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic from codeflash.result.explanation import Explanation from codeflash.telemetry.posthog_cf import ph +from codeflash.verification.codeflash_auditwall import transform_code from codeflash.verification.concolic_testing import generate_concolic_tests from codeflash.verification.equivalence import compare_test_results from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture from codeflash.verification.parse_test_output import parse_test_results from codeflash.verification.test_results import TestResults, TestType -from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests +from codeflash.verification.test_runner import execute_test_subprocess, run_behavioral_tests, run_benchmarking_tests from codeflash.verification.verification_utils import get_test_file_path from codeflash.verification.verifier import generate_tests @@ -149,12 +156,14 @@ def optimize_function(self) -> Result[BestOptimization, str]: self.args.project_root, ) + generated_test_paths = [ get_test_file_path( self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit" ) for test_index in range(N_TESTS_TO_GENERATE) ] + generated_perf_test_paths = [ get_test_file_path( self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="perf" @@ -844,6 +853,8 @@ def establish_original_code_baseline( enable_coverage=test_framework == "pytest", code_context=code_context, ) + except SideEffectDetected as e: + return Failure(f"Side effect detected in original code: {e}") finally: # Remove codeflash capture self.write_code_and_helpers( @@ -855,9 +866,7 @@ def establish_original_code_baseline( ) console.rule() return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.") - if not coverage_critic( - coverage_results, self.args.test_framework - ): + if not coverage_critic(coverage_results, self.args.test_framework): return Failure("The threshold for test coverage was not met.") if test_framework == "pytest": benchmarking_results, _ = self.run_and_parse_tests( @@ -898,7 +907,6 @@ def establish_original_code_baseline( ) console.rule() - total_timing = benchmarking_results.total_passed_runtime() # caution: doesn't handle the loop index functions_to_remove = [ result.id.test_function_name @@ -1097,13 +1105,13 @@ def run_and_parse_tests( raise ValueError(f"Unexpected testing type: {testing_type}") except subprocess.TimeoutExpired: logger.exception( - f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error' + f"Error running tests in {', '.join(str(f) for f in test_files.test_files)}.\nTimeout Error" ) return TestResults(), None if run_result.returncode != 0 and testing_type == TestingMode.BEHAVIOR: logger.debug( - f'Nonzero return code {run_result.returncode} when running tests in ' - f'{", ".join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n' + f"Nonzero return code {run_result.returncode} when running tests in " + f"{', '.join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n" f"stdout: {run_result.stdout}\n" f"stderr: {run_result.stderr}\n" ) @@ -1149,4 +1157,3 @@ def generate_and_instrument_tests( zip(generated_test_paths, generated_perf_test_paths) ) ] - diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index 46203e65a..132f42bac 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -1,16 +1,21 @@ from __future__ import annotations +import re import shlex import subprocess +import tempfile from pathlib import Path from typing import TYPE_CHECKING +from crosshair.auditwall import SideEffectDetected + from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME from codeflash.code_utils.coverage_utils import prepare_coverage_files from codeflash.models.models import TestFiles +from codeflash.verification.codeflash_auditwall import transform_code from codeflash.verification.test_results import TestType if TYPE_CHECKING: @@ -36,78 +41,97 @@ def run_behavioral_tests( pytest_target_runtime_seconds: int = TOTAL_LOOPING_TIME, enable_coverage: bool = False, ) -> tuple[Path, subprocess.CompletedProcess, Path | None]: - if test_framework == "pytest": - test_files: list[str] = [] - for file in test_paths.test_files: - if file.test_type == TestType.REPLAY_TEST: - # TODO: Does this work for unittest framework? - test_files.extend( - [ - str(file.instrumented_behavior_file_path) + "::" + test.test_function - for test in file.tests_in_file - ] - ) - else: - test_files.append(str(file.instrumented_behavior_file_path)) - test_files = list(set(test_files)) # remove multiple calls in the same test function - pytest_cmd_list = shlex.split(pytest_cmd, posix=IS_POSIX) - - common_pytest_args = [ - "--capture=tee-sys", - f"--timeout={pytest_timeout}", - "-q", - "--codeflash_loops_scope=session", - "--codeflash_min_loops=1", - "--codeflash_max_loops=1", - f"--codeflash_seconds={pytest_target_runtime_seconds}", # TODO :This is unnecessary, update the plugin to not ask for this - ] - - result_file_path = get_run_tmp_file(Path("pytest_results.xml")) - result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"] - - pytest_test_env = test_env.copy() - pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin" + if test_framework not in ["pytest", "unittest"]: + raise ValueError(f"Unsupported test framework: {test_framework}") - if enable_coverage: - coverage_database_file, coveragercfile = prepare_coverage_files() - - cov_erase = execute_test_subprocess( - shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env - ) # this cleanup is necessary to avoid coverage data from previous runs, if there are any, - # then the current run will be appended to the previous data, which skews the results - logger.debug(cov_erase) - - results = execute_test_subprocess( - shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage run --rcfile={coveragercfile.as_posix()} -m") - + pytest_cmd_list - + common_pytest_args - + result_args - + test_files, - cwd=cwd, - env=pytest_test_env, - timeout=600, + test_files: list[str] = [] + for file in test_paths.test_files: + if file.test_type == TestType.REPLAY_TEST: + # TODO: Does this work for unittest framework? + test_files.extend( + [str(file.instrumented_behavior_file_path) + "::" + test.test_function for test in file.tests_in_file] ) - logger.debug( - f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ''}""") else: - results = execute_test_subprocess( - pytest_cmd_list + common_pytest_args + result_args + test_files, - cwd=cwd, - env=pytest_test_env, - timeout=600, # TODO: Make this dynamic + test_files.append(str(file.instrumented_behavior_file_path)) + + source_code = next((file.original_source for file in test_paths.test_files if file.original_source), None) + if not source_code: + raise ValueError("No source code found for auditing") + + audit_code = transform_code(source_code) + pytest_cmd_list = shlex.split(pytest_cmd, posix=IS_POSIX) + common_pytest_args = [ + "--capture=tee-sys", + f"--timeout={pytest_timeout}", + "-q", + "--codeflash_loops_scope=session", + "--codeflash_min_loops=1", + "--codeflash_max_loops=1", + f"--codeflash_seconds={pytest_target_runtime_seconds}", + "-p", + "no:cacheprovider", + ] + + result_file_path = get_run_tmp_file(Path("pytest_results.xml")) + result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"] + + pytest_test_env = test_env.copy() + pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin" + + with tempfile.TemporaryDirectory( + dir=Path(test_paths.test_files[0].instrumented_behavior_file_path).parent + ) as temp_dir: + audited_file_path = Path(temp_dir) / "audited_code.py" + audited_file_path.write_text(audit_code, encoding="utf8") + + auditing_res = execute_test_subprocess( + pytest_cmd_list + common_pytest_args + [audited_file_path.as_posix()], + cwd=cwd, + env=pytest_test_env, + timeout=600, + ) + + if auditing_res.returncode != 0: + line_co = next( + ( + line + for line in auditing_res.stderr.splitlines() + auditing_res.stdout.splitlines() + if "crosshair.auditwall.SideEffectDetected" in line + ), + None, ) - logger.debug( - f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ''}""") - elif test_framework == "unittest": + + if line_co: + match = re.search(r"crosshair\.auditwall\.SideEffectDetected: A(.*) operation was detected\.", line_co) + if match: + msg = match.group(1) + raise SideEffectDetected(msg) + logger.debug(auditing_res.stderr) + logger.debug(auditing_res.stdout) + + if test_framework == "pytest": + coverage_database_file, coveragercfile = prepare_coverage_files() + cov_erase = execute_test_subprocess( + shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env + ) + logger.debug(cov_erase) + + results = execute_test_subprocess( + shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage run --rcfile={coveragercfile.as_posix()} -m") + + pytest_cmd_list + + common_pytest_args + + result_args + + list(set(test_files)), # remove duplicates + cwd=cwd, + env=pytest_test_env, + timeout=600, + ) + else: # unittest if enable_coverage: raise ValueError("Coverage is not supported yet for unittest framework") test_env["CODEFLASH_LOOP_INDEX"] = "1" test_files = [file.instrumented_behavior_file_path for file in test_paths.test_files] result_file_path, results = run_unittest_tests(verbose, test_files, test_env, cwd) - logger.debug( - f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ''}""") - else: - raise ValueError(f"Unsupported test framework: {test_framework}") return result_file_path, results, coverage_database_file if enable_coverage else None From 0cab5116b452a0b19b0e05c05269a837021f705f Mon Sep 17 00:00:00 2001 From: John Lyu Date: Fri, 21 Feb 2025 19:24:05 +0800 Subject: [PATCH 08/20] Use python >= 3.9 instead of ^3.9 See details at https://github.com/strawberry-graphql/strawberry/pull/3789 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 62587fa9d..27ebf6c1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ exclude = [ # Versions here the minimum required versions for the project. These should be as loose as possible. [tool.poetry.dependencies] -python = "^3.9" +python = ">=3.9" unidiff = ">=0.7.4" pytest = ">=7.0.0" gitpython = ">=3.1.31" From 414ac44c6837b6a3ab1d69b050b61b9e0ace18fe Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Fri, 21 Feb 2025 16:11:14 -0800 Subject: [PATCH 09/20] second pass --- codeflash/optimization/function_optimizer.py | 14 +- codeflash/verification/_auditwall.py | 184 ++++++++++++++++++ codeflash/verification/codeflash_auditwall.py | 4 +- codeflash/verification/test_runner.py | 11 +- tests/test_codeflash_capture.py | 5 + 5 files changed, 200 insertions(+), 18 deletions(-) create mode 100644 codeflash/verification/_auditwall.py diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index d8a63d6a3..27b00ae0d 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -3,11 +3,8 @@ import ast import concurrent.futures import os -import re -import shlex import shutil import subprocess -import tempfile import time import uuid from collections import defaultdict @@ -16,7 +13,6 @@ import isort import libcst as cst -from crosshair.auditwall import SideEffectDetected from rich.console import Group from rich.panel import Panel from rich.syntax import Syntax @@ -33,14 +29,12 @@ get_run_tmp_file, module_name_from_file_path, ) -from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE from codeflash.code_utils.config_consts import ( INDIVIDUAL_TESTCASE_TIMEOUT, N_CANDIDATES, N_TESTS_TO_GENERATE, TOTAL_LOOPING_TIME, ) -from codeflash.code_utils.coverage_utils import prepare_coverage_files from codeflash.code_utils.formatter import format_code, sort_imports from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests @@ -69,13 +63,13 @@ from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic from codeflash.result.explanation import Explanation from codeflash.telemetry.posthog_cf import ph -from codeflash.verification.codeflash_auditwall import transform_code +from codeflash.verification._auditwall import SideEffectDetectedError from codeflash.verification.concolic_testing import generate_concolic_tests from codeflash.verification.equivalence import compare_test_results from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture from codeflash.verification.parse_test_output import parse_test_results from codeflash.verification.test_results import TestResults, TestType -from codeflash.verification.test_runner import execute_test_subprocess, run_behavioral_tests, run_benchmarking_tests +from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests from codeflash.verification.verification_utils import get_test_file_path from codeflash.verification.verifier import generate_tests @@ -853,8 +847,8 @@ def establish_original_code_baseline( enable_coverage=test_framework == "pytest", code_context=code_context, ) - except SideEffectDetected as e: - return Failure(f"Side effect detected in original code: {e}") + except SideEffectDetectedError as e: + return Failure(f"Side effect detected in original code: {e}, skipping optimization.") finally: # Remove codeflash capture self.write_code_and_helpers( diff --git a/codeflash/verification/_auditwall.py b/codeflash/verification/_auditwall.py new file mode 100644 index 000000000..2ea24d09f --- /dev/null +++ b/codeflash/verification/_auditwall.py @@ -0,0 +1,184 @@ +import importlib +import os +import sys +import traceback +from collections.abc import Generator, Iterable +from contextlib import contextmanager, suppress +from types import ModuleType +from typing import Callable, Optional + + +class SideEffectDetectedError(Exception): + pass + + +_BLOCKED_OPEN_FLAGS = os.O_WRONLY | os.O_RDWR | os.O_APPEND | os.O_CREAT | os.O_EXCL | os.O_TRUNC + + +def accept(event: str, args: tuple) -> None: + pass + + +def reject(event: str, args: tuple) -> None: + msg = f'codeflash has detected: {event}{args}".' + raise SideEffectDetectedError(msg) + + +def inside_module(modules: Iterable[ModuleType]) -> bool: + files = {m.__file__ for m in modules} + return any(frame.f_code.co_filename in files for frame, lineno in traceback.walk_stack(None)) + + +def check_open(event: str, args: tuple) -> None: + (filename_or_descriptor, mode, flags) = args + if filename_or_descriptor in ("/dev/null", "nul"): + # (no-op writes on unix/windows) + return + if flags & _BLOCKED_OPEN_FLAGS: + msg = f"codeflash has detected: {event}({', '.join(map(repr, args))})." + raise SideEffectDetectedError(msg) + + +def check_msvcrt_open(event: str, args: tuple) -> None: + print(args) + (handle, flags) = args + if flags & _BLOCKED_OPEN_FLAGS: + msg = f"codeflash has detected: {event}({', '.join(map(repr, args))})." + raise SideEffectDetectedError(msg) + + +_MODULES_THAT_CAN_POPEN: Optional[set[ModuleType]] = None + + +def modules_with_allowed_popen(): + global _MODULES_THAT_CAN_POPEN + if _MODULES_THAT_CAN_POPEN is None: + allowed_module_names = ("_aix_support", "ctypes", "platform", "uuid") + _MODULES_THAT_CAN_POPEN = set() + for module_name in allowed_module_names: + with suppress(ImportError): + _MODULES_THAT_CAN_POPEN.add(importlib.import_module(module_name)) + return _MODULES_THAT_CAN_POPEN + + +def check_subprocess(event: str, args: tuple) -> None: + if not inside_module(modules_with_allowed_popen()): + reject(event, args) + + +def check_sqlite_connect(event: str, args: tuple) -> None: + if "codeflash_" in args[0]: + accept(event, args) + else: + reject(event, args) + + +_SPECIAL_HANDLERS = { + "open": check_open, + "subprocess.Popen": check_subprocess, + "msvcrt.open_osfhandle": check_msvcrt_open, + "sqlite3.connect": check_sqlite_connect, +} + + +def make_handler(event: str) -> Callable[[str, tuple], None]: + special_handler = _SPECIAL_HANDLERS.get(event) + if special_handler: + return special_handler + # Block certain events + if event in ( + "winreg.CreateKey", + "winreg.DeleteKey", + "winreg.DeleteValue", + "winreg.SaveKey", + "winreg.SetValue", + "winreg.DisableReflectionKey", + "winreg.EnableReflectionKey", + ): + return reject + # Allow certain events. + if event in ( + # These seem not terribly dangerous to allow: + "os.putenv", + "os.unsetenv", + "msvcrt.heapmin", + "msvcrt.kbhit", + # These involve I/O, but are hopefully non-destructive: + "glob.glob", + "msvcrt.get_osfhandle", + "msvcrt.setmode", + "os.listdir", # (important for Python's importer) + "os.scandir", # (important for Python's importer) + "os.chdir", + "os.fwalk", + "os.getxattr", + "os.listxattr", + "os.walk", + "pathlib.Path.glob", + "socket.gethostbyname", # (FastAPI TestClient uses this) + "socket.__new__", # (FastAPI TestClient uses this) + "socket.bind", # pygls's asyncio needs this on windows + "socket.connect", # pygls's asyncio needs this on windows + ): + return accept + # Block groups of events. + event_prefix = event.split(".", 1)[0] + if event_prefix in ( + "os", + "fcntl", + "ftplib", + "glob", + "imaplib", + "msvcrt", + "nntplib", + "os", + "pathlib", + "poplib", + "shutil", + "smtplib", + "socket", + "sqlite3", + "subprocess", + "telnetlib", + "urllib", + "webbrowser", + ): + return reject + # Allow other events. + return accept + + +_HANDLERS: dict[str, Callable[[str, tuple], None]] = {} +_ENABLED = True + + +def audithook(event: str, args: tuple) -> None: + if not _ENABLED: + return + handler = _HANDLERS.get(event) + if handler is None: + handler = make_handler(event) + _HANDLERS[event] = handler + handler(event, args) + + +@contextmanager +def opened_auditwall() -> Generator: + global _ENABLED + assert _ENABLED + _ENABLED = False + try: + yield + finally: + _ENABLED = True + + +def engage_auditwall() -> None: + sys.dont_write_bytecode = True # disable .pyc file writing + sys.addaudithook(audithook) + + +def disable_auditwall() -> None: + global _ENABLED + assert _ENABLED + _ENABLED = False diff --git a/codeflash/verification/codeflash_auditwall.py b/codeflash/verification/codeflash_auditwall.py index 73c27e188..b99321a04 100644 --- a/codeflash/verification/codeflash_auditwall.py +++ b/codeflash/verification/codeflash_auditwall.py @@ -8,7 +8,9 @@ def visit_Module(self, node): if isinstance(body_node, (ast.Import, ast.ImportFrom)): last_import_index = i - new_import = ast.ImportFrom(module="crosshair.auditwall", names=[ast.alias(name="engage_auditwall")], level=0) + new_import = ast.ImportFrom( + module="codeflash.verification._auditwall", names=[ast.alias(name="engage_auditwall")], level=0 + ) function_call = ast.Expr( value=ast.Call(func=ast.Name(id="engage_auditwall", ctx=ast.Load()), args=[], keywords=[]) ) diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index 132f42bac..229f0be64 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -7,14 +7,13 @@ from pathlib import Path from typing import TYPE_CHECKING -from crosshair.auditwall import SideEffectDetected - from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME from codeflash.code_utils.coverage_utils import prepare_coverage_files from codeflash.models.models import TestFiles +from codeflash.verification._auditwall import SideEffectDetectedError from codeflash.verification.codeflash_auditwall import transform_code from codeflash.verification.test_results import TestType @@ -90,22 +89,20 @@ def run_behavioral_tests( env=pytest_test_env, timeout=600, ) - if auditing_res.returncode != 0: line_co = next( ( line for line in auditing_res.stderr.splitlines() + auditing_res.stdout.splitlines() - if "crosshair.auditwall.SideEffectDetected" in line + if "codeflash.verification._auditwall.SideEffectDetectedError" in line ), None, ) - if line_co: - match = re.search(r"crosshair\.auditwall\.SideEffectDetected: A(.*) operation was detected\.", line_co) + match = re.search(r"codeflash has detected: (.+).", line_co) if match: msg = match.group(1) - raise SideEffectDetected(msg) + raise SideEffectDetectedError(msg) logger.debug(auditing_res.stderr) logger.debug(auditing_res.stdout) diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index 83b1efd2b..7cfd7f782 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -457,6 +457,7 @@ def __init__(self, x=2): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=test_code, ) ] ) @@ -568,6 +569,7 @@ def __init__(self, *args, **kwargs): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=test_code, ) ] ) @@ -681,6 +683,7 @@ def __init__(self, x=2): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=test_code, ) ] ) @@ -831,6 +834,7 @@ def another_helper(self): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=test_code, ) ] ) @@ -967,6 +971,7 @@ def another_helper(self): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=test_code, ) ] ) From 9139c43f641f6a12085f4b2bc5d6ed45890ea365 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Fri, 21 Feb 2025 17:12:50 -0800 Subject: [PATCH 10/20] fix tests --- codeflash/cli_cmds/console.py | 4 +++- codeflash/verification/_auditwall.py | 3 ++- tests/test_instrument_all_and_run.py | 3 +++ tests/test_instrument_tests.py | 8 ++++++++ ...test_instrumentation_run_results_aiservice.py | 2 ++ tests/test_test_runner.py | 16 ++++++++++++++-- 6 files changed, 32 insertions(+), 4 deletions(-) diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index 9661e9509..3ef0f2eca 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -3,7 +3,7 @@ import logging from contextlib import contextmanager from itertools import cycle -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING from rich.console import Console from rich.logging import RichHandler @@ -13,6 +13,8 @@ from codeflash.cli_cmds.logging_config import BARE_LOGGING_FORMAT if TYPE_CHECKING: + from collections.abc import Generator + from rich.progress import TaskID DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG diff --git a/codeflash/verification/_auditwall.py b/codeflash/verification/_auditwall.py index 2ea24d09f..d834ba05a 100644 --- a/codeflash/verification/_auditwall.py +++ b/codeflash/verification/_auditwall.py @@ -67,7 +67,7 @@ def check_subprocess(event: str, args: tuple) -> None: def check_sqlite_connect(event: str, args: tuple) -> None: - if "codeflash_" in args[0]: + if any("codeflash_" in arg for arg in args): accept(event, args) else: reject(event, args) @@ -78,6 +78,7 @@ def check_sqlite_connect(event: str, args: tuple) -> None: "subprocess.Popen": check_subprocess, "msvcrt.open_osfhandle": check_msvcrt_open, "sqlite3.connect": check_sqlite_connect, + "sqlite3.connect/handle": check_sqlite_connect, } diff --git a/tests/test_instrument_all_and_run.py b/tests/test_instrument_all_and_run.py index 643d4bde7..8cfc06190 100644 --- a/tests/test_instrument_all_and_run.py +++ b/tests/test_instrument_all_and_run.py @@ -156,6 +156,7 @@ def test_sort(): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=code, ) ] ) @@ -328,6 +329,7 @@ def test_sort(): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=code, ) ] ) @@ -423,6 +425,7 @@ def sorter(self, arr): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=code, ) ] ) diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index bf7373522..311774845 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -413,6 +413,7 @@ def test_sort(): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=code, ) ] ) @@ -608,6 +609,7 @@ def test_sort_parametrized(input, expected_output): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=code, ) ] ) @@ -847,6 +849,7 @@ def test_sort_parametrized_loop(input, expected_output): test_type=TestType.EXISTING_UNIT_TEST, ) ], + original_source=code, ) ] ) @@ -1156,6 +1159,7 @@ def test_sort(): test_type=TestType.EXISTING_UNIT_TEST, ) ], + original_source=code, ) ] ) @@ -1425,6 +1429,7 @@ def test_sort(self): test_type=TestType.EXISTING_UNIT_TEST, ) ], + original_source=code, ) ] ) @@ -1672,6 +1677,7 @@ def test_sort(self, input, expected_output): test_type=TestType.EXISTING_UNIT_TEST, ) ], + original_source=code, ) ] ) @@ -1923,6 +1929,7 @@ def test_sort(self): test_type=TestType.EXISTING_UNIT_TEST, ) ], + original_source=code, ) ] ) @@ -2172,6 +2179,7 @@ def test_sort(self, input, expected_output): test_type=TestType.EXISTING_UNIT_TEST, ) ], + original_source=code, ) ] ) diff --git a/tests/test_instrumentation_run_results_aiservice.py b/tests/test_instrumentation_run_results_aiservice.py index ee237cfca..62226b519 100644 --- a/tests/test_instrumentation_run_results_aiservice.py +++ b/tests/test_instrumentation_run_results_aiservice.py @@ -162,6 +162,7 @@ def test_single_element_list(): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=fto_path.read_text("utf-8"), ) ] ) @@ -298,6 +299,7 @@ def test_single_element_list(): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=fto_path.read_text("utf-8"), ) ] ) diff --git a/tests/test_test_runner.py b/tests/test_test_runner.py index 60a4be70f..09a7ffb32 100644 --- a/tests/test_test_runner.py +++ b/tests/test_test_runner.py @@ -36,7 +36,13 @@ def test_sort(self): with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp: test_files = TestFiles( - test_files=[TestFile(instrumented_behavior_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)] + test_files=[ + TestFile( + instrumented_behavior_file_path=Path(fp.name), + test_type=TestType.EXISTING_UNIT_TEST, + original_source=code, + ) + ] ) fp.write(code.encode("utf-8")) fp.flush() @@ -80,7 +86,13 @@ def test_sort(): with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp: test_files = TestFiles( - test_files=[TestFile(instrumented_behavior_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)] + test_files=[ + TestFile( + instrumented_behavior_file_path=Path(fp.name), + test_type=TestType.EXISTING_UNIT_TEST, + original_source=code, + ) + ] ) fp.write(code.encode("utf-8")) fp.flush() From a6da525b9cac4f1f071a9aa7818a4de21f2cd441 Mon Sep 17 00:00:00 2001 From: davidgirdwood1 <162387981+davidgirdwood1@users.noreply.github.com> Date: Mon, 24 Feb 2025 15:25:54 -0800 Subject: [PATCH 11/20] Don't display result if INIT_STATE_TEST --- codeflash/github/PrComment.py | 12 ++++++++---- codeflash/verification/test_results.py | 3 +++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/codeflash/github/PrComment.py b/codeflash/github/PrComment.py index 76a19f5e7..a9b11b5c2 100644 --- a/codeflash/github/PrComment.py +++ b/codeflash/github/PrComment.py @@ -20,6 +20,13 @@ class PrComment: winning_benchmarking_test_results: TestResults def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]: + + report_table = { + test_type.to_name(): result + for test_type, result in self.winning_behavioral_test_results.get_test_pass_fail_report_by_type().items() + if test_type.should_display_result() + } + return { "optimization_explanation": self.optimization_explanation, "best_runtime": humanize_runtime(self.best_runtime), @@ -29,10 +36,7 @@ def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]: "speedup_x": self.speedup_x, "speedup_pct": self.speedup_pct, "loop_count": self.winning_benchmarking_test_results.number_of_loops(), - "report_table": { - test_type.to_name(): result - for test_type, result in self.winning_behavioral_test_results.get_test_pass_fail_report_by_type().items() - }, + "report_table": report_table } diff --git a/codeflash/verification/test_results.py b/codeflash/verification/test_results.py index 28d8bfc0d..c2c99d3cf 100644 --- a/codeflash/verification/test_results.py +++ b/codeflash/verification/test_results.py @@ -30,6 +30,9 @@ class TestType(Enum): CONCOLIC_COVERAGE_TEST = 5 INIT_STATE_TEST = 6 + def should_display(self) -> bool: + return self != TestType.INIT_STATE_TEST + def to_name(self) -> str: if self == TestType.INIT_STATE_TEST: return "" From 488d4fac00df885cdd3ef7c2f13662bc5b5a93b3 Mon Sep 17 00:00:00 2001 From: davidgirdwood1 <162387981+davidgirdwood1@users.noreply.github.com> Date: Mon, 24 Feb 2025 16:10:50 -0800 Subject: [PATCH 12/20] Use the to_name() --- codeflash/github/PrComment.py | 2 +- codeflash/verification/test_results.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/codeflash/github/PrComment.py b/codeflash/github/PrComment.py index a9b11b5c2..a6f6aa892 100644 --- a/codeflash/github/PrComment.py +++ b/codeflash/github/PrComment.py @@ -24,7 +24,7 @@ def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]: report_table = { test_type.to_name(): result for test_type, result in self.winning_behavioral_test_results.get_test_pass_fail_report_by_type().items() - if test_type.should_display_result() + if test_type.to_name() } return { diff --git a/codeflash/verification/test_results.py b/codeflash/verification/test_results.py index c2c99d3cf..28d8bfc0d 100644 --- a/codeflash/verification/test_results.py +++ b/codeflash/verification/test_results.py @@ -30,9 +30,6 @@ class TestType(Enum): CONCOLIC_COVERAGE_TEST = 5 INIT_STATE_TEST = 6 - def should_display(self) -> bool: - return self != TestType.INIT_STATE_TEST - def to_name(self) -> str: if self == TestType.INIT_STATE_TEST: return "" From 02cfca07e3705e0b96b64ade06e02d4944abce1b Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Tue, 25 Feb 2025 18:00:10 -0800 Subject: [PATCH 13/20] release/v0.10.0 --- codeflash/LICENSE | 4 ++-- codeflash/version.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/codeflash/LICENSE b/codeflash/LICENSE index 8b94a373d..d32df80d3 100644 --- a/codeflash/LICENSE +++ b/codeflash/LICENSE @@ -3,7 +3,7 @@ Business Source License 1.1 Parameters Licensor: CodeFlash Inc. -Licensed Work: Codeflash Client version 0.9.x +Licensed Work: Codeflash Client version 0.10.x The Licensed Work is (c) 2024 CodeFlash Inc. Additional Use Grant: None. Production use of the Licensed Work is only permitted @@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitte Platform. Please visit codeflash.ai for further information. -Change Date: 2029-01-06 +Change Date: 2029-02-25 Change License: MIT diff --git a/codeflash/version.py b/codeflash/version.py index b29fd20bd..55232158e 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,3 +1,3 @@ # These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`. -__version__ = "0.9.2" -__version_tuple__ = (0, 9, 2) +__version__ = "0.10.0" +__version_tuple__ = (0, 10, 0) From cc1a0b934b43046218a9cc645ec6dfabcfd4b44f Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Fri, 21 Feb 2025 17:18:46 -0800 Subject: [PATCH 14/20] fix tests failing --- .github/workflows/unit-tests.yaml | 3 ++- codeflash/verification/_auditwall.py | 17 +++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml index a1e7da8ea..cf0e30275 100644 --- a/.github/workflows/unit-tests.yaml +++ b/.github/workflows/unit-tests.yaml @@ -32,7 +32,8 @@ jobs: run: uvx poetry install --with dev - name: Unit tests - run: uvx poetry run pytest tests/ --cov --cov-report=xml + run: uvx poetry run pytest tests/ --cov --cov-report=xml -vv + # run: uvx poetry run pytest tests/ -vv - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 diff --git a/codeflash/verification/_auditwall.py b/codeflash/verification/_auditwall.py index d834ba05a..7a4b4a75f 100644 --- a/codeflash/verification/_auditwall.py +++ b/codeflash/verification/_auditwall.py @@ -19,6 +19,9 @@ def accept(event: str, args: tuple) -> None: pass +args_allow_list = {".coverage", "matplotlib.rc", "codeflash"} + + def reject(event: str, args: tuple) -> None: msg = f'codeflash has detected: {event}{args}".' raise SideEffectDetectedError(msg) @@ -40,7 +43,6 @@ def check_open(event: str, args: tuple) -> None: def check_msvcrt_open(event: str, args: tuple) -> None: - print(args) (handle, flags) = args if flags & _BLOCKED_OPEN_FLAGS: msg = f"codeflash has detected: {event}({', '.join(map(repr, args))})." @@ -66,8 +68,18 @@ def check_subprocess(event: str, args: tuple) -> None: reject(event, args) +def handle_os_remove(event: str, args: tuple) -> None: + filename = str(args[0]) + if any(pattern in filename for pattern in args_allow_list): + accept(event, args) + else: + reject(event, args) + + def check_sqlite_connect(event: str, args: tuple) -> None: - if any("codeflash_" in arg for arg in args): + if ( + event == "sqlite3.connect" and any(pattern in str(args[0]) for pattern in args_allow_list) + ) or event == "sqlite3.connect/handle": accept(event, args) else: reject(event, args) @@ -79,6 +91,7 @@ def check_sqlite_connect(event: str, args: tuple) -> None: "msvcrt.open_osfhandle": check_msvcrt_open, "sqlite3.connect": check_sqlite_connect, "sqlite3.connect/handle": check_sqlite_connect, + "os.remove": handle_os_remove, } From 68391613621465554d42505af4a17dfb96aee80b Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 26 Feb 2025 14:29:37 -0800 Subject: [PATCH 15/20] Update unit-tests.yaml --- .github/workflows/unit-tests.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml index cf0e30275..35e0a7d9b 100644 --- a/.github/workflows/unit-tests.yaml +++ b/.github/workflows/unit-tests.yaml @@ -33,7 +33,6 @@ jobs: - name: Unit tests run: uvx poetry run pytest tests/ --cov --cov-report=xml -vv - # run: uvx poetry run pytest tests/ -vv - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 From 7c4778142fc2c38c0776ab3cb110083b757873ea Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 26 Feb 2025 15:26:15 -0800 Subject: [PATCH 16/20] add license notice. --- codeflash/verification/_auditwall.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/codeflash/verification/_auditwall.py b/codeflash/verification/_auditwall.py index 7a4b4a75f..16d80f6ad 100644 --- a/codeflash/verification/_auditwall.py +++ b/codeflash/verification/_auditwall.py @@ -1,3 +1,18 @@ +# Copyright 2024 CodeFlash Inc. All rights reserved. +# +# Licensed under the Business Source License version 1.1. +# License source can be found in the LICENSE file. +# +# This file includes derived work covered by the following copyright and permission notices: +# +# Copyright Python Software Foundation +# Licensed under the Apache License, Version 2.0 (the "License"). +# http://www.apache.org/licenses/LICENSE-2.0 +# +# The PSF License Agreement +# https://docs.python.org/3/license.html#python-software-foundation-license-version-2 +# +# import importlib import os import sys From 56172db3d8c47388b5345f7bd95d32923cbadafc Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Thu, 27 Feb 2025 16:53:10 -0800 Subject: [PATCH 17/20] Update _auditwall.py --- codeflash/verification/_auditwall.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/codeflash/verification/_auditwall.py b/codeflash/verification/_auditwall.py index 16d80f6ad..1c825e2d3 100644 --- a/codeflash/verification/_auditwall.py +++ b/codeflash/verification/_auditwall.py @@ -153,14 +153,12 @@ def make_handler(event: str) -> Callable[[str, tuple], None]: # Block groups of events. event_prefix = event.split(".", 1)[0] if event_prefix in ( - "os", "fcntl", "ftplib", "glob", "imaplib", "msvcrt", "nntplib", - "os", "pathlib", "poplib", "shutil", @@ -173,6 +171,18 @@ def make_handler(event: str) -> Callable[[str, tuple], None]: "webbrowser", ): return reject + if event_prefix == "os" and event not in [ + "os.putenv", + "os.unsetenv", + "os.listdir", + "os.scandir", + "os.chdir", + "os.fwalk", + "os.getxattr", + "os.listxattr", + "os.walk", + ]: + return reject # Allow other events. return accept From c118a1623b60560b01cc1967c7e5d39e32aae901 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 17 Mar 2025 01:52:09 -0700 Subject: [PATCH 18/20] new way --- codeflash/verification/_auditwall.py | 201 ++---------------- codeflash/verification/codeflash_auditwall.py | 5 +- 2 files changed, 20 insertions(+), 186 deletions(-) diff --git a/codeflash/verification/_auditwall.py b/codeflash/verification/_auditwall.py index 1c825e2d3..3210ae7d0 100644 --- a/codeflash/verification/_auditwall.py +++ b/codeflash/verification/_auditwall.py @@ -13,79 +13,22 @@ # https://docs.python.org/3/license.html#python-software-foundation-license-version-2 # # -import importlib -import os -import sys -import traceback -from collections.abc import Generator, Iterable -from contextlib import contextmanager, suppress -from types import ModuleType -from typing import Callable, Optional +from auditwall.core import AuditWallConfig, _default_audit_wall, accept, reject -class SideEffectDetectedError(Exception): - pass - -_BLOCKED_OPEN_FLAGS = os.O_WRONLY | os.O_RDWR | os.O_APPEND | os.O_CREAT | os.O_EXCL | os.O_TRUNC - - -def accept(event: str, args: tuple) -> None: - pass - - -args_allow_list = {".coverage", "matplotlib.rc", "codeflash"} - - -def reject(event: str, args: tuple) -> None: - msg = f'codeflash has detected: {event}{args}".' - raise SideEffectDetectedError(msg) - - -def inside_module(modules: Iterable[ModuleType]) -> bool: - files = {m.__file__ for m in modules} - return any(frame.f_code.co_filename in files for frame, lineno in traceback.walk_stack(None)) - - -def check_open(event: str, args: tuple) -> None: - (filename_or_descriptor, mode, flags) = args - if filename_or_descriptor in ("/dev/null", "nul"): - # (no-op writes on unix/windows) - return - if flags & _BLOCKED_OPEN_FLAGS: - msg = f"codeflash has detected: {event}({', '.join(map(repr, args))})." - raise SideEffectDetectedError(msg) - - -def check_msvcrt_open(event: str, args: tuple) -> None: - (handle, flags) = args - if flags & _BLOCKED_OPEN_FLAGS: - msg = f"codeflash has detected: {event}({', '.join(map(repr, args))})." - raise SideEffectDetectedError(msg) - - -_MODULES_THAT_CAN_POPEN: Optional[set[ModuleType]] = None - - -def modules_with_allowed_popen(): - global _MODULES_THAT_CAN_POPEN - if _MODULES_THAT_CAN_POPEN is None: - allowed_module_names = ("_aix_support", "ctypes", "platform", "uuid") - _MODULES_THAT_CAN_POPEN = set() - for module_name in allowed_module_names: - with suppress(ImportError): - _MODULES_THAT_CAN_POPEN.add(importlib.import_module(module_name)) - return _MODULES_THAT_CAN_POPEN - - -def check_subprocess(event: str, args: tuple) -> None: - if not inside_module(modules_with_allowed_popen()): - reject(event, args) +class CodeflashAuditWallConfig(AuditWallConfig): + def __init__(self) -> None: + super().__init__() + self.allowed_write_paths = {".coverage", "matplotlib.rc", "codeflash"} def handle_os_remove(event: str, args: tuple) -> None: filename = str(args[0]) - if any(pattern in filename for pattern in args_allow_list): + if any( + pattern in filename + for pattern in _default_audit_wall.config.allowed_write_paths + ): accept(event, args) else: reject(event, args) @@ -93,131 +36,23 @@ def handle_os_remove(event: str, args: tuple) -> None: def check_sqlite_connect(event: str, args: tuple) -> None: if ( - event == "sqlite3.connect" and any(pattern in str(args[0]) for pattern in args_allow_list) + event == "sqlite3.connect" + and any( + pattern in str(args[0]) + for pattern in _default_audit_wall.config.allowed_write_paths + ) ) or event == "sqlite3.connect/handle": accept(event, args) else: reject(event, args) -_SPECIAL_HANDLERS = { - "open": check_open, - "subprocess.Popen": check_subprocess, - "msvcrt.open_osfhandle": check_msvcrt_open, +custom_handlers = { + "os.remove": handle_os_remove, "sqlite3.connect": check_sqlite_connect, "sqlite3.connect/handle": check_sqlite_connect, - "os.remove": handle_os_remove, } -def make_handler(event: str) -> Callable[[str, tuple], None]: - special_handler = _SPECIAL_HANDLERS.get(event) - if special_handler: - return special_handler - # Block certain events - if event in ( - "winreg.CreateKey", - "winreg.DeleteKey", - "winreg.DeleteValue", - "winreg.SaveKey", - "winreg.SetValue", - "winreg.DisableReflectionKey", - "winreg.EnableReflectionKey", - ): - return reject - # Allow certain events. - if event in ( - # These seem not terribly dangerous to allow: - "os.putenv", - "os.unsetenv", - "msvcrt.heapmin", - "msvcrt.kbhit", - # These involve I/O, but are hopefully non-destructive: - "glob.glob", - "msvcrt.get_osfhandle", - "msvcrt.setmode", - "os.listdir", # (important for Python's importer) - "os.scandir", # (important for Python's importer) - "os.chdir", - "os.fwalk", - "os.getxattr", - "os.listxattr", - "os.walk", - "pathlib.Path.glob", - "socket.gethostbyname", # (FastAPI TestClient uses this) - "socket.__new__", # (FastAPI TestClient uses this) - "socket.bind", # pygls's asyncio needs this on windows - "socket.connect", # pygls's asyncio needs this on windows - ): - return accept - # Block groups of events. - event_prefix = event.split(".", 1)[0] - if event_prefix in ( - "fcntl", - "ftplib", - "glob", - "imaplib", - "msvcrt", - "nntplib", - "pathlib", - "poplib", - "shutil", - "smtplib", - "socket", - "sqlite3", - "subprocess", - "telnetlib", - "urllib", - "webbrowser", - ): - return reject - if event_prefix == "os" and event not in [ - "os.putenv", - "os.unsetenv", - "os.listdir", - "os.scandir", - "os.chdir", - "os.fwalk", - "os.getxattr", - "os.listxattr", - "os.walk", - ]: - return reject - # Allow other events. - return accept - - -_HANDLERS: dict[str, Callable[[str, tuple], None]] = {} -_ENABLED = True - - -def audithook(event: str, args: tuple) -> None: - if not _ENABLED: - return - handler = _HANDLERS.get(event) - if handler is None: - handler = make_handler(event) - _HANDLERS[event] = handler - handler(event, args) - - -@contextmanager -def opened_auditwall() -> Generator: - global _ENABLED - assert _ENABLED - _ENABLED = False - try: - yield - finally: - _ENABLED = True - - -def engage_auditwall() -> None: - sys.dont_write_bytecode = True # disable .pyc file writing - sys.addaudithook(audithook) - - -def disable_auditwall() -> None: - global _ENABLED - assert _ENABLED - _ENABLED = False +_default_audit_wall.config = CodeflashAuditWallConfig() +_default_audit_wall.config.special_handlers = custom_handlers diff --git a/codeflash/verification/codeflash_auditwall.py b/codeflash/verification/codeflash_auditwall.py index b99321a04..f3e907407 100644 --- a/codeflash/verification/codeflash_auditwall.py +++ b/codeflash/verification/codeflash_auditwall.py @@ -2,14 +2,14 @@ class AuditWallTransformer(ast.NodeTransformer): - def visit_Module(self, node): + def visit_Module(self, node: ast.Module) -> ast.Module: # noqa: N802 last_import_index = -1 for i, body_node in enumerate(node.body): if isinstance(body_node, (ast.Import, ast.ImportFrom)): last_import_index = i new_import = ast.ImportFrom( - module="codeflash.verification._auditwall", names=[ast.alias(name="engage_auditwall")], level=0 + module="auditwall.core", names=[ast.alias(name="engage_auditwall")], level=0 ) function_call = ast.Expr( value=ast.Call(func=ast.Name(id="engage_auditwall", ctx=ast.Load()), args=[], keywords=[]) @@ -20,7 +20,6 @@ def visit_Module(self, node): return node - def transform_code(source_code: str) -> str: tree = ast.parse(source_code) transformer = AuditWallTransformer() From 48121e932476254a138d7a4b598047f9a2f3d26c Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 17 Mar 2025 02:00:10 -0700 Subject: [PATCH 19/20] update --- codeflash/optimization/function_optimizer.py | 430 ++++++++++++++----- codeflash/verification/test_runner.py | 5 +- tests/test_code_context_extractor.py | 1 + 3 files changed, 325 insertions(+), 111 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 27b00ae0d..4356a4e63 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -13,6 +13,7 @@ import isort import libcst as cst +from auditwall.core import SideEffectDetected from rich.console import Group from rich.panel import Panel from rich.syntax import Syntax @@ -21,7 +22,10 @@ from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash.cli_cmds.console import code_print, console, logger, progress_bar from codeflash.code_utils import env_utils -from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code +from codeflash.code_utils.code_extractor import ( + add_needed_imports_from_module, + extract_code, +) from codeflash.code_utils.code_replacer import replace_function_definitions_in_module from codeflash.code_utils.code_utils import ( cleanup_paths, @@ -36,9 +40,15 @@ TOTAL_LOOPING_TIME, ) from codeflash.code_utils.formatter import format_code, sort_imports -from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test -from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests -from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast +from codeflash.code_utils.instrument_existing_tests import ( + inject_profiling_into_existing_test, +) +from codeflash.code_utils.remove_generated_tests import ( + remove_functions_from_generated_tests, +) +from codeflash.code_utils.static_analysis import ( + get_first_top_level_function_or_method_ast, +) from codeflash.code_utils.time_utils import humanize_runtime from codeflash.context import code_context_extractor from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -58,18 +68,29 @@ TestFiles, TestingMode, ) -from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions +from codeflash.optimization.function_context import ( + get_constrained_function_context_and_helper_functions, +) from codeflash.result.create_pr import check_create_pr, existing_tests_source_for -from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic +from codeflash.result.critic import ( + coverage_critic, + performance_gain, + quantity_of_tests_critic, + speedup_critic, +) from codeflash.result.explanation import Explanation from codeflash.telemetry.posthog_cf import ph -from codeflash.verification._auditwall import SideEffectDetectedError from codeflash.verification.concolic_testing import generate_concolic_tests from codeflash.verification.equivalence import compare_test_results -from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture +from codeflash.verification.instrument_codeflash_capture import ( + instrument_codeflash_capture, +) from codeflash.verification.parse_test_output import parse_test_results from codeflash.verification.test_results import TestResults, TestType -from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests +from codeflash.verification.test_runner import ( + run_behavioral_tests, + run_benchmarking_tests, +) from codeflash.verification.verification_utils import get_test_file_path from codeflash.verification.verifier import generate_tests @@ -97,7 +118,9 @@ def __init__( ) -> None: self.project_root = test_cfg.project_root_path self.test_cfg = test_cfg - self.aiservice_client = aiservice_client if aiservice_client else AiServiceClient() + self.aiservice_client = ( + aiservice_client if aiservice_client else AiServiceClient() + ) self.function_to_optimize = function_to_optimize self.function_to_optimize_source_code = ( function_to_optimize_source_code @@ -107,19 +130,25 @@ def __init__( if not function_to_optimize_ast: original_module_ast = ast.parse(function_to_optimize_source_code) self.function_to_optimize_ast = get_first_top_level_function_or_method_ast( - function_to_optimize.function_name, function_to_optimize.parents, original_module_ast + function_to_optimize.function_name, + function_to_optimize.parents, + original_module_ast, ) else: self.function_to_optimize_ast = function_to_optimize_ast self.function_to_tests = function_to_tests if function_to_tests else {} self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None) - self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None + self.local_aiservice_client = ( + LocalAiServiceClient() if self.experiment_id else None + ) self.test_files = TestFiles(test_files=[]) self.args = args # Check defaults for these self.function_trace_id: str = str(uuid.uuid4()) - self.original_module_path = module_name_from_file_path(self.function_to_optimize.file_path, self.project_root) + self.original_module_path = module_name_from_file_path( + self.function_to_optimize.file_path, self.project_root + ) def optimize_function(self) -> Result[BestOptimization, str]: should_run_experiment = self.experiment_id is not None @@ -150,17 +179,22 @@ def optimize_function(self) -> Result[BestOptimization, str]: self.args.project_root, ) - generated_test_paths = [ get_test_file_path( - self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit" + self.test_cfg.tests_root, + self.function_to_optimize.function_name, + test_index, + test_type="unit", ) for test_index in range(N_TESTS_TO_GENERATE) ] generated_perf_test_paths = [ get_test_file_path( - self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="perf" + self.test_cfg.tests_root, + self.function_to_optimize.function_name, + test_index, + test_type="perf", ) for test_index in range(N_TESTS_TO_GENERATE) ] @@ -183,7 +217,12 @@ def optimize_function(self) -> Result[BestOptimization, str]: return Failure(generated_results.failure()) generated_tests: GeneratedTestsList optimizations_set: OptimizationSet - generated_tests, function_to_concolic_tests, concolic_test_str, optimizations_set = generated_results.unwrap() + ( + generated_tests, + function_to_concolic_tests, + concolic_test_str, + optimizations_set, + ) = generated_results.unwrap() count_tests = len(generated_tests.generated_tests) if concolic_test_str: count_tests += 1 @@ -211,29 +250,39 @@ def optimize_function(self) -> Result[BestOptimization, str]: function_to_optimize_qualified_name = self.function_to_optimize.qualified_name function_to_all_tests = { - key: self.function_to_tests.get(key, []) + function_to_concolic_tests.get(key, []) + key: self.function_to_tests.get(key, []) + + function_to_concolic_tests.get(key, []) for key in set(self.function_to_tests) | set(function_to_concolic_tests) } - instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests) + instrumented_unittests_created_for_function = self.instrument_existing_tests( + function_to_all_tests + ) # Get a dict of file_path_to_classes of fto and helpers_of_fto file_path_to_helper_classes = defaultdict(set) for function_source in code_context.helper_functions: if ( - function_source.qualified_name != self.function_to_optimize.qualified_name + function_source.qualified_name + != self.function_to_optimize.qualified_name and "." in function_source.qualified_name ): - file_path_to_helper_classes[function_source.file_path].add(function_source.qualified_name.split(".")[0]) + file_path_to_helper_classes[function_source.file_path].add( + function_source.qualified_name.split(".")[0] + ) - baseline_result = self.establish_original_code_baseline( # this needs better typing - code_context=code_context, - original_helper_code=original_helper_code, - file_path_to_helper_classes=file_path_to_helper_classes, + baseline_result = ( + self.establish_original_code_baseline( # this needs better typing + code_context=code_context, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + ) ) console.rule() paths_to_cleanup = ( - generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function) + generated_test_paths + + generated_perf_test_paths + + list(instrumented_unittests_created_for_function) ) if not is_successful(baseline_result): @@ -241,7 +290,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: return Failure(baseline_result.failure()) original_code_baseline, test_functions_to_remove = baseline_result.unwrap() - if isinstance(original_code_baseline, OriginalCodeBaseline) and not coverage_critic( + if isinstance( + original_code_baseline, OriginalCodeBaseline + ) and not coverage_critic( original_code_baseline.coverage_results, self.args.test_framework ): cleanup_paths(paths_to_cleanup) @@ -249,7 +300,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: best_optimization = None - for u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]): + for u, candidates in enumerate( + [optimizations_set.control, optimizations_set.experiment] + ): if candidates is None: continue @@ -260,10 +313,14 @@ def optimize_function(self) -> Result[BestOptimization, str]: original_helper_code=original_helper_code, file_path_to_helper_classes=file_path_to_helper_classes, ) - ph("cli-optimize-function-finished", {"function_trace_id": self.function_trace_id}) + ph( + "cli-optimize-function-finished", + {"function_trace_id": self.function_trace_id}, + ) generated_tests = remove_functions_from_generated_tests( - generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove + generated_tests=generated_tests, + test_functions_to_remove=test_functions_to_remove, ) if best_optimization: @@ -271,7 +328,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: code_print(best_optimization.candidate.source_code) console.print( Panel( - best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue" + best_optimization.candidate.explanation, + title="Best Candidate Explanation", + border_style="blue", ) ) explanation = Explanation( @@ -287,21 +346,28 @@ def optimize_function(self) -> Result[BestOptimization, str]: self.log_successful_optimization(explanation, generated_tests) self.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=best_optimization.candidate.source_code + code_context=code_context, + optimized_code=best_optimization.candidate.source_code, ) new_code, new_helper_code = self.reformat_code_and_helpers( - code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code + code_context.helper_functions, + explanation.file_path, + self.function_to_optimize_source_code, ) existing_tests = existing_tests_source_for( - self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), + self.function_to_optimize.qualified_name_with_modules_from_root( + self.project_root + ), function_to_all_tests, tests_root=self.test_cfg.tests_root, ) original_code_combined = original_helper_code.copy() - original_code_combined[explanation.file_path] = self.function_to_optimize_source_code + original_code_combined[explanation.file_path] = ( + self.function_to_optimize_source_code + ) new_code_combined = new_helper_code.copy() new_code_combined[explanation.file_path] = new_code if not self.args.no_pr: @@ -311,7 +377,10 @@ def optimize_function(self) -> Result[BestOptimization, str]: else "Coverage data not available" ) generated_tests_str = "\n\n".join( - [test.generated_original_test_source for test in generated_tests.generated_tests] + [ + test.generated_original_test_source + for test in generated_tests.generated_tests + ] ) if concolic_test_str: generated_tests_str += "\n\n" + concolic_test_str @@ -348,7 +417,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: break # need to delete only one test directory if not best_optimization: - return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") + return Failure( + f"No best optimizations found for function {self.function_to_optimize.qualified_name}" + ) return Success(best_optimization) def determine_best_candidate( @@ -374,9 +445,15 @@ def determine_best_candidate( console.rule() try: for candidate_index, candidate in enumerate(candidates, start=1): - get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) - get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) - logger.info(f"Optimization candidate {candidate_index}/{len(candidates)}:") + get_run_tmp_file( + Path(f"test_return_values_{candidate_index}.bin") + ).unlink(missing_ok=True) + get_run_tmp_file( + Path(f"test_return_values_{candidate_index}.sqlite") + ).unlink(missing_ok=True) + logger.info( + f"Optimization candidate {candidate_index}/{len(candidates)}:" + ) code_print(candidate.source_code) try: did_update = self.replace_function_and_helpers_with_optimized_code( @@ -388,10 +465,17 @@ def determine_best_candidate( ) console.rule() continue - except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: + except ( + ValueError, + SyntaxError, + cst.ParserSyntaxError, + AttributeError, + ) as e: logger.error(e) self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + self.function_to_optimize_source_code, + original_helper_code, + self.function_to_optimize.file_path, ) continue @@ -414,16 +498,23 @@ def determine_best_candidate( optimized_runtimes[candidate.optimization_id] = best_test_runtime is_correct[candidate.optimization_id] = True perf_gain = performance_gain( - original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime + original_runtime_ns=original_code_baseline.runtime, + optimized_runtime_ns=best_test_runtime, ) speedup_ratios[candidate.optimization_id] = perf_gain tree = Tree(f"Candidate #{candidate_index} - Runtime Information") if speedup_critic( - candidate_result, original_code_baseline.runtime, best_runtime_until_now + candidate_result, + original_code_baseline.runtime, + best_runtime_until_now, ) and quantity_of_tests_critic(candidate_result): - tree.add("This candidate is faster than the previous best candidate. 🚀") - tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}") + tree.add( + "This candidate is faster than the previous best candidate. 🚀" + ) + tree.add( + f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}" + ) tree.add( f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} " f"(measured over {candidate_result.max_loop_count} " @@ -452,11 +543,15 @@ def determine_best_candidate( console.rule() self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + self.function_to_optimize_source_code, + original_helper_code, + self.function_to_optimize.file_path, ) except KeyboardInterrupt as e: self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + self.function_to_optimize_source_code, + original_helper_code, + self.function_to_optimize.file_path, ) logger.exception(f"Optimization interrupted: {e}") raise @@ -470,7 +565,9 @@ def determine_best_candidate( ) return best_optimization - def log_successful_optimization(self, explanation: Explanation, generated_tests: GeneratedTestsList) -> None: + def log_successful_optimization( + self, explanation: Explanation, generated_tests: GeneratedTestsList + ) -> None: explanation_panel = Panel( f"⚡️ Optimization successful! 📄 {self.function_to_optimize.qualified_name} in {explanation.file_path}\n" f"📈 {explanation.perf_improvement_line}\n" @@ -482,7 +579,12 @@ def log_successful_optimization(self, explanation: Explanation, generated_tests: if self.args.no_pr: tests_panel = Panel( Syntax( - "\n".join([test.generated_original_test_source for test in generated_tests.generated_tests]), + "\n".join( + [ + test.generated_original_test_source + for test in generated_tests.generated_tests + ] + ), "python", line_numbers=True, ), @@ -509,7 +611,9 @@ def log_successful_optimization(self, explanation: Explanation, generated_tests: ) @staticmethod - def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, str], path: Path) -> None: + def write_code_and_helpers( + original_code: str, original_helper_code: dict[Path, str], path: Path + ) -> None: with path.open("w", encoding="utf8") as f: f.write(original_code) for module_abspath in original_helper_code: @@ -530,7 +634,9 @@ def reformat_code_and_helpers( new_helper_code: dict[Path, str] = {} helper_functions_paths = {hf.file_path for hf in helper_functions} for module_abspath in helper_functions_paths: - formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath) + formatted_helper_code = format_code( + self.args.formatter_cmds, module_abspath + ) if should_sort_imports: formatted_helper_code = sort_imports(formatted_helper_code) new_helper_code[module_abspath] = formatted_helper_code @@ -547,8 +653,13 @@ def replace_function_and_helpers_with_optimized_code( ) for helper_function in code_context.helper_functions: if helper_function.jedi_definition.type != "class": - read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name) - for module_abspath, qualified_names in read_writable_functions_by_file_path.items(): + read_writable_functions_by_file_path[helper_function.file_path].add( + helper_function.qualified_name + ) + for ( + module_abspath, + qualified_names, + ) in read_writable_functions_by_file_path.items(): did_update |= replace_function_definitions_in_module( function_names=list(qualified_names), optimized_code=optimized_code, @@ -559,18 +670,23 @@ def replace_function_and_helpers_with_optimized_code( return did_update def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: - code_to_optimize, contextual_dunder_methods = extract_code([self.function_to_optimize]) + code_to_optimize, contextual_dunder_methods = extract_code( + [self.function_to_optimize] + ) if code_to_optimize is None: return Failure("Could not find function to optimize.") - (helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions( - self.function_to_optimize, self.project_root, code_to_optimize + (helper_code, helper_functions, helper_dunder_methods) = ( + get_constrained_function_context_and_helper_functions( + self.function_to_optimize, self.project_root, code_to_optimize + ) ) if self.function_to_optimize.parents: function_class = self.function_to_optimize.parents[0].name same_class_helper_methods = [ df for df in helper_functions - if df.qualified_name.count(".") > 0 and df.qualified_name.split(".")[0] == function_class + if df.qualified_name.count(".") > 0 + and df.qualified_name.split(".")[0] == function_class ] optimizable_methods = [ FunctionToOptimize( @@ -589,7 +705,9 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: dedup_optimizable_methods.append(method) added_methods.add(f"{method.file_path}.{method.qualified_name}") if len(dedup_optimizable_methods) > 1: - code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods))) + code_to_optimize, contextual_dunder_methods = extract_code( + list(reversed(dedup_optimizable_methods)) + ) if code_to_optimize is None: return Failure("Could not find function to optimize.") code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize @@ -626,23 +744,35 @@ def cleanup_leftover_test_return_values() -> None: get_run_tmp_file(Path("test_return_values_0.bin")).unlink(missing_ok=True) get_run_tmp_file(Path("test_return_values_0.sqlite")).unlink(missing_ok=True) - def instrument_existing_tests(self, function_to_all_tests: dict[str, list[FunctionCalledInTest]]) -> set[Path]: + def instrument_existing_tests( + self, function_to_all_tests: dict[str, list[FunctionCalledInTest]] + ) -> set[Path]: existing_test_files_count = 0 replay_test_files_count = 0 concolic_coverage_test_files_count = 0 unique_instrumented_test_files = set() - func_qualname = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root) + func_qualname = self.function_to_optimize.qualified_name_with_modules_from_root( + self.project_root + ) if func_qualname not in function_to_all_tests: - logger.info(f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.") + logger.info( + f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests." + ) console.rule() else: test_file_invocation_positions = defaultdict(list[FunctionCalledInTest]) for tests_in_file in function_to_all_tests.get(func_qualname): test_file_invocation_positions[ - (tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type) + ( + tests_in_file.tests_in_file.test_file, + tests_in_file.tests_in_file.test_type, + ) ].append(tests_in_file) - for (test_file, test_type), tests_in_file_list in test_file_invocation_positions.items(): + for ( + test_file, + test_type, + ), tests_in_file_list in test_file_invocation_positions.items(): path_obj_test_file = Path(test_file) if test_type == TestType.EXISTING_UNIT_TEST: existing_test_files_count += 1 @@ -721,9 +851,16 @@ def generate_tests_and_optimizations( generated_test_paths: list[Path], generated_perf_test_paths: list[Path], run_experiment: bool = False, - ) -> Result[tuple[GeneratedTestsList, dict[str, list[FunctionCalledInTest]], OptimizationSet], str]: + ) -> Result[ + tuple[ + GeneratedTestsList, dict[str, list[FunctionCalledInTest]], OptimizationSet + ], + str, + ]: assert len(generated_test_paths) == N_TESTS_TO_GENERATE - max_workers = N_TESTS_TO_GENERATE + 2 if not run_experiment else N_TESTS_TO_GENERATE + 3 + max_workers = ( + N_TESTS_TO_GENERATE + 2 if not run_experiment else N_TESTS_TO_GENERATE + 3 + ) console.rule() with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit the test generation task as future @@ -738,9 +875,17 @@ def generate_tests_and_optimizations( self.aiservice_client.optimize_python_code, read_writable_code, read_only_context_code, - self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id, + ( + self.function_trace_id[:-4] + "EXP0" + if run_experiment + else self.function_trace_id + ), N_CANDIDATES, - ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None, + ( + ExperimentMetadata(id=self.experiment_id, group="control") + if run_experiment + else None + ), ) future_candidates_exp = None @@ -751,7 +896,11 @@ def generate_tests_and_optimizations( self.function_to_optimize, self.function_to_optimize_ast, ) - futures = [*future_tests, future_optimization_candidates, future_concolic_tests] + futures = [ + *future_tests, + future_optimization_candidates, + future_concolic_tests, + ] if run_experiment: future_candidates_exp = executor.submit( self.local_aiservice_client.optimize_python_code, @@ -767,11 +916,17 @@ def generate_tests_and_optimizations( concurrent.futures.wait(futures) # Retrieve results - candidates: list[OptimizedCandidate] = future_optimization_candidates.result() + candidates: list[OptimizedCandidate] = ( + future_optimization_candidates.result() + ) if not candidates: - return Failure(f"/!\\ NO OPTIMIZATIONS GENERATED for {self.function_to_optimize.function_name}") + return Failure( + f"/!\\ NO OPTIMIZATIONS GENERATED for {self.function_to_optimize.function_name}" + ) - candidates_experiment = future_candidates_exp.result() if future_candidates_exp else None + candidates_experiment = ( + future_candidates_exp.result() if future_candidates_exp else None + ) # Process test generation results @@ -796,10 +951,18 @@ def generate_tests_and_optimizations( ) ) if not tests: - logger.warning(f"Failed to generate and instrument tests for {self.function_to_optimize.function_name}") - return Failure(f"/!\\ NO TESTS GENERATED for {self.function_to_optimize.function_name}") - function_to_concolic_tests, concolic_test_str = future_concolic_tests.result() - logger.info(f"Generated {len(tests)} tests for {self.function_to_optimize.function_name}") + logger.warning( + f"Failed to generate and instrument tests for {self.function_to_optimize.function_name}" + ) + return Failure( + f"/!\\ NO TESTS GENERATED for {self.function_to_optimize.function_name}" + ) + function_to_concolic_tests, concolic_test_str = ( + future_concolic_tests.result() + ) + logger.info( + f"Generated {len(tests)} tests for {self.function_to_optimize.function_name}" + ) console.rule() generated_tests = GeneratedTestsList(generated_tests=tests) @@ -819,8 +982,13 @@ def establish_original_code_baseline( file_path_to_helper_classes: dict[Path, set[str]], ) -> Result[tuple[OriginalCodeBaseline, list[str]], str]: # For the original function - run the tests and get the runtime, plus coverage - with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"): - assert (test_framework := self.args.test_framework) in ["pytest", "unittest"] + with progress_bar( + f"Establishing original code baseline for {self.function_to_optimize.function_name}" + ): + assert (test_framework := self.args.test_framework) in [ + "pytest", + "unittest", + ] success = True test_env = os.environ.copy() @@ -836,7 +1004,9 @@ def establish_original_code_baseline( # Instrument codeflash capture try: instrument_codeflash_capture( - self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root + self.function_to_optimize, + file_path_to_helper_classes, + self.test_cfg.tests_root, ) behavioral_results, coverage_results = self.run_and_parse_tests( testing_type=TestingMode.BEHAVIOR, @@ -847,19 +1017,25 @@ def establish_original_code_baseline( enable_coverage=test_framework == "pytest", code_context=code_context, ) - except SideEffectDetectedError as e: - return Failure(f"Side effect detected in original code: {e}, skipping optimization.") + except SideEffectDetected as e: + return Failure( + f"Side effect detected in original code: {e}, skipping optimization." + ) finally: # Remove codeflash capture self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + self.function_to_optimize_source_code, + original_helper_code, + self.function_to_optimize.file_path, ) if not behavioral_results: logger.warning( f"Couldn't run any tests for original function {self.function_to_optimize.function_name}. SKIPPING OPTIMIZING THIS FUNCTION." ) console.rule() - return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.") + return Failure( + "Failed to establish a baseline for the original code - bevhavioral tests failed." + ) if not coverage_critic(coverage_results, self.args.test_framework): return Failure("The threshold for test coverage was not met.") if test_framework == "pytest": @@ -901,11 +1077,16 @@ def establish_original_code_baseline( ) console.rule() - total_timing = benchmarking_results.total_passed_runtime() # caution: doesn't handle the loop index + total_timing = ( + benchmarking_results.total_passed_runtime() + ) # caution: doesn't handle the loop index functions_to_remove = [ result.id.test_function_name for result in behavioral_results - if (result.test_type == TestType.GENERATED_REGRESSION and not result.did_pass) + if ( + result.test_type == TestType.GENERATED_REGRESSION + and not result.did_pass + ) ] if total_timing == 0: logger.warning( @@ -914,13 +1095,17 @@ def establish_original_code_baseline( console.rule() success = False if not total_timing: - logger.warning("Failed to run the tests for the original function, skipping optimization") + logger.warning( + "Failed to run the tests for the original function, skipping optimization" + ) console.rule() success = False if not success: return Failure("Failed to establish a baseline for the original code.") - loop_count = max([int(result.loop_index) for result in benchmarking_results.test_results]) + loop_count = max( + [int(result.loop_index) for result in benchmarking_results.test_results] + ) logger.info( f"Original code summed runtime measured over {loop_count} loop{'s' if loop_count > 1 else ''}: " f"{humanize_runtime(total_timing)} per full loop" @@ -959,17 +1144,27 @@ def run_optimized_candidate( else: test_env["PYTHONPATH"] += os.pathsep + str(self.project_root) - get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True) - get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True) + get_run_tmp_file( + Path(f"test_return_values_{optimization_candidate_index}.sqlite") + ).unlink(missing_ok=True) + get_run_tmp_file( + Path(f"test_return_values_{optimization_candidate_index}.sqlite") + ).unlink(missing_ok=True) # Instrument codeflash capture - candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8") + candidate_fto_code = Path(self.function_to_optimize.file_path).read_text( + "utf-8" + ) candidate_helper_code = {} for module_abspath in original_helper_code: - candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8") + candidate_helper_code[module_abspath] = Path(module_abspath).read_text( + "utf-8" + ) try: instrument_codeflash_capture( - self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root + self.function_to_optimize, + file_path_to_helper_classes, + self.test_cfg.tests_root, ) candidate_behavior_results, _ = self.run_and_parse_tests( @@ -983,7 +1178,9 @@ def run_optimized_candidate( # Remove instrumentation finally: self.write_code_and_helpers( - candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path + candidate_fto_code, + candidate_helper_code, + self.function_to_optimize.file_path, ) console.print( TestResults.report_to_tree( @@ -993,13 +1190,19 @@ def run_optimized_candidate( ) console.rule() - if compare_test_results(baseline_results.behavioral_test_results, candidate_behavior_results): + if compare_test_results( + baseline_results.behavioral_test_results, candidate_behavior_results + ): logger.info("Test results matched!") console.rule() else: - logger.info("Test results did not match the test results of the original code.") + logger.info( + "Test results did not match the test results of the original code." + ) console.rule() - return Failure("Test results did not match the test results of the original code.") + return Failure( + "Test results did not match the test results of the original code." + ) if test_framework == "pytest": candidate_benchmarking_results, _ = self.run_and_parse_tests( @@ -1014,7 +1217,8 @@ def run_optimized_candidate( max(all_loop_indices) if ( all_loop_indices := { - result.loop_index for result in candidate_benchmarking_results.test_results + result.loop_index + for result in candidate_benchmarking_results.test_results } ) else 0 @@ -1040,11 +1244,17 @@ def run_optimized_candidate( loop_count = i + 1 candidate_benchmarking_results.merge(unittest_loop_results) - if (total_candidate_timing := candidate_benchmarking_results.total_passed_runtime()) == 0: - logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.") + if ( + total_candidate_timing := candidate_benchmarking_results.total_passed_runtime() + ) == 0: + logger.warning( + "The overall test runtime of the optimized function is 0, couldn't run tests." + ) console.rule() - logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}") + logger.debug( + f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}" + ) return Success( OptimizedCandidateResult( max_loop_count=loop_count, @@ -1073,15 +1283,17 @@ def run_and_parse_tests( coverage_database_file = None try: if testing_type == TestingMode.BEHAVIOR: - result_file_path, run_result, coverage_database_file = run_behavioral_tests( - test_files, - test_framework=self.test_cfg.test_framework, - cwd=self.project_root, - test_env=test_env, - pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT, - pytest_cmd=self.test_cfg.pytest_cmd, - verbose=True, - enable_coverage=enable_coverage, + result_file_path, run_result, coverage_database_file = ( + run_behavioral_tests( + test_files, + test_framework=self.test_cfg.test_framework, + cwd=self.project_root, + test_env=test_env, + pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT, + pytest_cmd=self.test_cfg.pytest_cmd, + verbose=True, + enable_coverage=enable_coverage, + ) ) elif testing_type == TestingMode.PERFORMANCE: result_file_path, run_result = run_benchmarking_tests( diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index 229f0be64..ad964675e 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -7,13 +7,14 @@ from pathlib import Path from typing import TYPE_CHECKING +from auditwall.core import SideEffectDetected + from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME from codeflash.code_utils.coverage_utils import prepare_coverage_files from codeflash.models.models import TestFiles -from codeflash.verification._auditwall import SideEffectDetectedError from codeflash.verification.codeflash_auditwall import transform_code from codeflash.verification.test_results import TestType @@ -102,7 +103,7 @@ def run_behavioral_tests( match = re.search(r"codeflash has detected: (.+).", line_co) if match: msg = match.group(1) - raise SideEffectDetectedError(msg) + raise SideEffectDetected(msg) logger.debug(auditing_res.stderr) logger.debug(auditing_res.stdout) diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 7f4a94845..88b46e87c 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest + from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import FunctionParent From 875dc699e2f2151da14cf2035bcca3cd7906d0ab Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 17 Mar 2025 03:21:25 -0700 Subject: [PATCH 20/20] Update test_runner.py --- codeflash/verification/test_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index ad964675e..21813fd87 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -90,17 +90,18 @@ def run_behavioral_tests( env=pytest_test_env, timeout=600, ) + logger.info(auditing_res.stdout) if auditing_res.returncode != 0: line_co = next( ( line for line in auditing_res.stderr.splitlines() + auditing_res.stdout.splitlines() - if "codeflash.verification._auditwall.SideEffectDetectedError" in line + if "auditwall.core.SideEffectDetected" in line ), None, ) if line_co: - match = re.search(r"codeflash has detected: (.+).", line_co) + match = re.search(r"auditwall.core.SideEffectDetected: A (.+).", line_co) if match: msg = match.group(1) raise SideEffectDetected(msg)