diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 000000000..bc0a20ae8 --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,19 @@ +name: Lint +on: + pull_request: + push: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint: + name: Run pre-commit hooks + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + - uses: pre-commit/action@v3.0.1 \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..9c2955e4b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.11.0" + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix, --config=pyproject.toml] + - id: ruff-format \ No newline at end of file diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index f6482898b..e3b02444a 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -1,10 +1,9 @@ from __future__ import annotations -import time - import json import os import platform +import time from typing import TYPE_CHECKING, Any import requests @@ -177,7 +176,7 @@ def optimize_python_code_line_profiler( logger.info("Generating optimized candidates…") console.rule() - if line_profiler_results=="": + if line_profiler_results == "": logger.info("No LineProfiler results were provided, Skipping optimization.") console.rule() return [] @@ -209,7 +208,6 @@ def optimize_python_code_line_profiler( console.rule() return [] - def log_results( self, function_trace_id: str, diff --git a/codeflash/benchmarking/codeflash_trace.py b/codeflash/benchmarking/codeflash_trace.py index 35232f954..10abf4583 100644 --- a/codeflash/benchmarking/codeflash_trace.py +++ b/codeflash/benchmarking/codeflash_trace.py @@ -69,7 +69,7 @@ def write_function_timings(self) -> None: "(function_name, class_name, module_name, file_path, benchmark_function_name, " "benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - self.function_calls_data + self.function_calls_data, ) self._connection.commit() self.function_calls_data = [] @@ -100,7 +100,8 @@ def __call__(self, func: Callable) -> Callable: The wrapped function """ - func_id = (func.__module__,func.__name__) + func_id = (func.__module__, func.__name__) + @functools.wraps(func) def wrapper(*args, **kwargs): # Initialize thread-local active functions set if it doesn't exist @@ -139,9 +140,19 @@ def wrapper(*args, **kwargs): self._thread_local.active_functions.remove(func_id) overhead_time = time.thread_time_ns() - end_time self.function_calls_data.append( - (func.__name__, class_name, func.__module__, func.__code__.co_filename, - benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time, - overhead_time, None, None) + ( + func.__name__, + class_name, + func.__module__, + func.__code__.co_filename, + benchmark_function_name, + benchmark_module_path, + benchmark_line_number, + execution_time, + overhead_time, + None, + None, + ) ) return result @@ -155,9 +166,19 @@ def wrapper(*args, **kwargs): self._thread_local.active_functions.remove(func_id) overhead_time = time.thread_time_ns() - end_time self.function_calls_data.append( - (func.__name__, class_name, func.__module__, func.__code__.co_filename, - benchmark_function_name, 
benchmark_module_path, benchmark_line_number, execution_time, - overhead_time, None, None) + ( + func.__name__, + class_name, + func.__module__, + func.__code__.co_filename, + benchmark_function_name, + benchmark_module_path, + benchmark_line_number, + execution_time, + overhead_time, + None, + None, + ) ) return result # Flush to database every 100 calls @@ -168,12 +189,24 @@ def wrapper(*args, **kwargs): self._thread_local.active_functions.remove(func_id) overhead_time = time.thread_time_ns() - end_time self.function_calls_data.append( - (func.__name__, class_name, func.__module__, func.__code__.co_filename, - benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time, - overhead_time, pickled_args, pickled_kwargs) + ( + func.__name__, + class_name, + func.__module__, + func.__code__.co_filename, + benchmark_function_name, + benchmark_module_path, + benchmark_line_number, + execution_time, + overhead_time, + pickled_args, + pickled_kwargs, + ) ) return result + return wrapper + # Create a singleton instance codeflash_trace = CodeflashTrace() diff --git a/codeflash/benchmarking/instrument_codeflash_trace.py b/codeflash/benchmarking/instrument_codeflash_trace.py index 044b0b0a4..2044d0997 100644 --- a/codeflash/benchmarking/instrument_codeflash_trace.py +++ b/codeflash/benchmarking/instrument_codeflash_trace.py @@ -13,22 +13,20 @@ def __init__(self, target_functions: set[tuple[str, str]]) -> None: self.added_codeflash_trace = False self.class_name = "" self.function_name = "" - self.decorator = cst.Decorator( - decorator=cst.Name(value="codeflash_trace") - ) + self.decorator = cst.Decorator(decorator=cst.Name(value="codeflash_trace")) def leave_ClassDef(self, original_node, updated_node): if self.class_name == original_node.name.value: - self.class_name = "" # Even if nested classes are not visited, this function is still called on them + self.class_name = "" # Even if nested classes are not visited, this function is still called on them return updated_node def visit_ClassDef(self, node): - if self.class_name: # Don't go into nested class + if self.class_name: # Don't go into nested class return False self.class_name = node.name.value def visit_FunctionDef(self, node): - if self.function_name: # Don't go into nested function + if self.function_name: # Don't go into nested function return False self.function_name = node.name.value @@ -39,9 +37,7 @@ def leave_FunctionDef(self, original_node, updated_node): # Add the new decorator after any existing decorators, so it gets executed first updated_decorators = list(updated_node.decorators) + [self.decorator] self.added_codeflash_trace = True - return updated_node.with_changes( - decorators=updated_decorators - ) + return updated_node.with_changes(decorators=updated_decorators) return updated_node @@ -53,17 +49,10 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c body=[ cst.ImportFrom( module=cst.Attribute( - value=cst.Attribute( - value=cst.Name(value="codeflash"), - attr=cst.Name(value="benchmarking") - ), - attr=cst.Name(value="codeflash_trace") + value=cst.Attribute(value=cst.Name(value="codeflash"), attr=cst.Name(value="benchmarking")), + attr=cst.Name(value="codeflash_trace"), ), - names=[ - cst.ImportAlias( - name=cst.Name(value="codeflash_trace") - ) - ] + names=[cst.ImportAlias(name=cst.Name(value="codeflash_trace"))], ) ] ) @@ -73,6 +62,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c return updated_node.with_changes(body=new_body) + def 
add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[FunctionToOptimize]) -> str:
     """Add codeflash_trace to a function.
 
@@ -91,25 +81,18 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
             class_name = function_to_optimize.parents[0].name
         target_functions.add((class_name, function_to_optimize.function_name))
 
-    transformer = AddDecoratorTransformer(
-        target_functions = target_functions,
-    )
+    transformer = AddDecoratorTransformer(target_functions=target_functions)
     module = cst.parse_module(code)
     modified_module = module.visit(transformer)
 
     return modified_module.code
 
 
-def instrument_codeflash_trace_decorator(
-    file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
-) -> None:
+def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]) -> None:
     """Instrument codeflash_trace decorator to functions to optimize."""
     for file_path, functions_to_optimize in file_to_funcs_to_optimize.items():
         original_code = file_path.read_text(encoding="utf-8")
-        new_code = add_codeflash_decorator_to_code(
-            original_code,
-            functions_to_optimize
-        )
+        new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize)
 
         # Modify the code
         modified_code = isort.code(code=new_code, float_to_top=True)
diff --git a/codeflash/benchmarking/plugin/plugin.py b/codeflash/benchmarking/plugin/plugin.py
index 313817041..d66af456f 100644
--- a/codeflash/benchmarking/plugin/plugin.py
+++ b/codeflash/benchmarking/plugin/plugin.py
@@ -17,10 +17,11 @@ class CodeFlashBenchmarkPlugin:
     def __init__(self) -> None:
         self._trace_path = None
         self._connection = None
+        self._cursor = None
         self.project_root = None
         self.benchmark_timings = []
 
-    def setup(self, trace_path:str, project_root:str) -> None:
+    def setup(self, trace_path: str, project_root: str) -> None:
         try:
             # Open connection
             self.project_root = project_root
@@ -35,7 +36,7 @@ def setup(self, trace_path:str, project_root:str) -> None:
                 "benchmark_time_ns INTEGER)"
             )
             self._connection.commit()
-            self.close() # Reopen only at the end of pytest session
+            self.close()  # Reopen only at the end of pytest session
         except Exception as e:
             print(f"Database setup error: {e}")
             if self._connection:
@@ -47,22 +48,32 @@ def write_benchmark_timings(self) -> None:
         if not self.benchmark_timings:
             return  # No data to write
 
-        if self._connection is None:
-            self._connection = sqlite3.connect(self._trace_path)
+        self._ensure_connection()
 
         try:
-            cur = self._connection.cursor()
             # Insert data into the benchmark_timings table
-            cur.executemany(
+            self._cursor.executemany(
                 "INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
-                self.benchmark_timings
+                self.benchmark_timings,
             )
             self._connection.commit()
-            self.benchmark_timings = []  # Clear the benchmark timings list
+            self.benchmark_timings.clear()  # Clear the benchmark timings list (reuses the list object)
         except Exception as e:
             print(f"Error writing to benchmark timings database: {e}")
             self._connection.rollback()
             raise
+
+    def _ensure_connection(self) -> None:
+        # Establish DB connection and optimize settings for faster inserts, if not already done
+        if self._connection is None:
+            self._connection = sqlite3.connect(self._trace_path)
+            self._cursor = self._connection.cursor()
+            # Speed up inserts by relaxing durability
+            self._cursor.execute("PRAGMA synchronous = OFF")
+            self._cursor.execute("PRAGMA journal_mode = MEMORY")
+        elif self._cursor is None:
+            self._cursor = self._connection.cursor()
+
     def close(self) -> None:
         if self._connection:
             self._connection.close()
@@ -196,12 +207,7 @@ def pytest_sessionfinish(self, session, exitstatus):
 
     @staticmethod
     def pytest_addoption(parser):
-        parser.addoption(
-            "--codeflash-trace",
-            action="store_true",
-            default=False,
-            help="Enable CodeFlash tracing"
-        )
+        parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
 
     @staticmethod
     def pytest_plugin_registered(plugin, manager):
@@ -213,9 +219,9 @@ def pytest_plugin_registered(plugin, manager):
     def pytest_configure(config):
         """Register the benchmark marker."""
         config.addinivalue_line(
-            "markers",
-            "benchmark: mark test as a benchmark that should be run with codeflash tracing"
+            "markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing"
         )
+
     @staticmethod
     def pytest_collection_modifyitems(config, items):
         # Skip tests that don't have the benchmark fixture
@@ -248,16 +254,19 @@ def __call__(self, func, *args, **kwargs):
             if args or kwargs:
                 # Used as benchmark(func, *args, **kwargs)
                 return self._run_benchmark(func, *args, **kwargs)
+
             # Used as @benchmark decorator
             def wrapped_func(*args, **kwargs):
                 return func(*args, **kwargs)
+
             result = self._run_benchmark(func)
             return wrapped_func
 
         def _run_benchmark(self, func, *args, **kwargs):
             """Actual benchmark implementation."""
-            benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)),
-                                                               Path(codeflash_benchmark_plugin.project_root))
+            benchmark_module_path = module_name_from_file_path(
+                Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
+            )
             benchmark_function_name = self.request.node.name
             line_number = int(str(sys._getframe(2).f_lineno))  # 2 frames up in the call stack
             # Set env vars
@@ -278,7 +287,8 @@ def _run_benchmark(self, func, *args, **kwargs):
             codeflash_trace.function_call_count = 0
             # Add to the benchmark timings buffer
             codeflash_benchmark_plugin.benchmark_timings.append(
-                (benchmark_module_path, benchmark_function_name, line_number, end - start))
+                (benchmark_module_path, benchmark_function_name, line_number, end - start)
+            )
 
             return result
 
diff --git a/codeflash/benchmarking/replay_test.py b/codeflash/benchmarking/replay_test.py
index ee1107241..043abfdc0 100644
--- a/codeflash/benchmarking/replay_test.py
+++ b/codeflash/benchmarking/replay_test.py
@@ -16,7 +16,12 @@
 
 
 def get_next_arg_and_return(
-    trace_file: str, benchmark_function_name:str, function_name: str, file_path: str, class_name: str | None = None, num_to_get: int = 256
+    trace_file: str,
+    benchmark_function_name: str,
+    function_name: str,
+    file_path: str,
+    class_name: str | None = None,
+    num_to_get: int = 256,
 ) -> Generator[Any]:
     db = sqlite3.connect(trace_file)
     cur = db.cursor()
@@ -42,10 +47,7 @@ def get_function_alias(module: str, function_name: str) -> str:
 
 
 def create_trace_replay_test_code(
-    trace_file: str,
-    functions_data: list[dict[str, Any]],
-    test_framework: str = "pytest",
-    max_run_count=256
+    trace_file: str, functions_data: list[dict[str, Any]], test_framework: str = "pytest", max_run_count=256
 ) -> str:
     """Create a replay test for functions based on trace data.
@@ -83,8 +85,9 @@ def create_trace_replay_test_code( imports += "\n".join(function_imports) - functions_to_optimize = sorted({func.get("function_name") for func in functions_data - if func.get("function_name") != "__init__"}) + functions_to_optimize = sorted( + {func.get("function_name") for func in functions_data if func.get("function_name") != "__init__"} + ) metadata = f"""functions = {functions_to_optimize} trace_file_path = r"{trace_file}" """ @@ -111,7 +114,8 @@ def create_trace_replay_test_code( else: instance = args[0] # self ret = instance{method_name}(*args[1:], **kwargs) - """) + """ + ) test_class_method_body = textwrap.dedent( """\ @@ -142,7 +146,6 @@ def create_trace_replay_test_code( self = "" for func in functions_data: - module_name = func.get("module_name") function_name = func.get("function_name") class_name = func.get("class_name") @@ -206,7 +209,10 @@ def create_trace_replay_test_code( return imports + "\n" + metadata + "\n" + test_template -def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework: str = "pytest", max_run_count: int = 100) -> int: + +def generate_replay_test( + trace_file_path: Path, output_dir: Path, test_framework: str = "pytest", max_run_count: int = 100 +) -> int: """Generate multiple replay tests from the traced function calls, grouped by benchmark. Args: @@ -226,9 +232,7 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework cursor = conn.cursor() # Get distinct benchmark file paths - cursor.execute( - "SELECT DISTINCT benchmark_module_path FROM benchmark_function_timings" - ) + cursor.execute("SELECT DISTINCT benchmark_module_path FROM benchmark_function_timings") benchmark_files = cursor.fetchall() # Generate a test for each benchmark file @@ -236,29 +240,29 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework benchmark_module_path = benchmark_file[0] # Get all benchmarks and functions associated with this file path cursor.execute( - "SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings " + "SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings " "WHERE benchmark_module_path = ?", - (benchmark_module_path,) + (benchmark_module_path,), ) functions_data = [] for row in cursor.fetchall(): benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number = row # Add this function to our list - functions_data.append({ - "function_name": function_name, - "class_name": class_name, - "file_path": file_path, - "module_name": module_name, - "benchmark_function_name": benchmark_function_name, - "benchmark_module_path": benchmark_module_path, - "benchmark_line_number": benchmark_line_number, - "function_properties": inspect_top_level_functions_or_methods( - file_name=Path(file_path), - function_or_method_name=function_name, - class_name=class_name, - ) - }) + functions_data.append( + { + "function_name": function_name, + "class_name": class_name, + "file_path": file_path, + "module_name": module_name, + "benchmark_function_name": benchmark_function_name, + "benchmark_module_path": benchmark_module_path, + "benchmark_line_number": benchmark_line_number, + "function_properties": inspect_top_level_functions_or_methods( + file_name=Path(file_path), function_or_method_name=function_name, class_name=class_name + ), + } + ) if not functions_data: logger.info(f"No 
benchmark test functions found in {benchmark_module_path}") diff --git a/codeflash/benchmarking/trace_benchmarks.py b/codeflash/benchmarking/trace_benchmarks.py index 8d14068e7..e59b06656 100644 --- a/codeflash/benchmarking/trace_benchmarks.py +++ b/codeflash/benchmarking/trace_benchmarks.py @@ -9,7 +9,9 @@ from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE -def trace_benchmarks_pytest(benchmarks_root: Path, tests_root:Path, project_root: Path, trace_file: Path, timeout:int = 300) -> None: +def trace_benchmarks_pytest( + benchmarks_root: Path, tests_root: Path, project_root: Path, trace_file: Path, timeout: int = 300 +) -> None: benchmark_env = os.environ.copy() if "PYTHONPATH" not in benchmark_env: benchmark_env["PYTHONPATH"] = str(project_root) @@ -43,6 +45,4 @@ def trace_benchmarks_pytest(benchmarks_root: Path, tests_root:Path, project_root error_section = match.group(1) if match else result.stdout else: error_section = result.stdout - logger.warning( - f"Error collecting benchmarks - Pytest Exit code: {result.returncode}, {error_section}" - ) + logger.warning(f"Error collecting benchmarks - Pytest Exit code: {result.returncode}, {error_section}") diff --git a/codeflash/benchmarking/utils.py b/codeflash/benchmarking/utils.py index da09cd57a..5dae99444 100644 --- a/codeflash/benchmarking/utils.py +++ b/codeflash/benchmarking/utils.py @@ -15,8 +15,9 @@ from codeflash.models.models import BenchmarkKey -def validate_and_format_benchmark_table(function_benchmark_timings: dict[str, dict[BenchmarkKey, int]], - total_benchmark_timings: dict[BenchmarkKey, int]) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]: +def validate_and_format_benchmark_table( + function_benchmark_timings: dict[str, dict[BenchmarkKey, int]], total_benchmark_timings: dict[BenchmarkKey, int] +) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]: function_to_result = {} # Process each function's benchmark data for func_path, test_times in function_benchmark_timings.items(): @@ -41,12 +42,11 @@ def validate_and_format_benchmark_table(function_benchmark_timings: dict[str, di def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey, float, float, float]]]) -> None: - try: terminal_width = int(shutil.get_terminal_size().columns * 0.9) except Exception: terminal_width = 120 # Fallback width - console = Console(width = terminal_width) + console = Console(width=terminal_width) for func_path, sorted_tests in function_to_results.items(): console.print() function_name = func_path.split(":")[-1] @@ -67,30 +67,18 @@ def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey test_function = benchmark_key.function_name if total_time == 0.0: - table.add_row( - module_path, - test_function, - "N/A", - "N/A", - "N/A" - ) + table.add_row(module_path, test_function, "N/A", "N/A", "N/A") else: - table.add_row( - module_path, - test_function, - f"{total_time:.3f}", - f"{func_time:.3f}", - f"{percentage:.2f}" - ) + table.add_row(module_path, test_function, f"{total_time:.3f}", f"{func_time:.3f}", f"{percentage:.2f}") # Print the table console.print(table) def process_benchmark_data( - replay_performance_gain: dict[BenchmarkKey, float], - fto_benchmark_timings: dict[BenchmarkKey, int], - total_benchmark_timings: dict[BenchmarkKey, int] + replay_performance_gain: dict[BenchmarkKey, float], + fto_benchmark_timings: dict[BenchmarkKey, int], + total_benchmark_timings: dict[BenchmarkKey, int], ) -> Optional[ProcessedBenchmarkInfo]: """Process benchmark data and generate 
detailed benchmark information. @@ -109,19 +97,25 @@ def process_benchmark_data( benchmark_details = [] for benchmark_key, og_benchmark_timing in fto_benchmark_timings.items(): - total_benchmark_timing = total_benchmark_timings.get(benchmark_key, 0) if total_benchmark_timing == 0: continue # Skip benchmarks with zero timing # Calculate expected new benchmark timing - expected_new_benchmark_timing = total_benchmark_timing - og_benchmark_timing + ( - 1 / (replay_performance_gain[benchmark_key] + 1) - ) * og_benchmark_timing + expected_new_benchmark_timing = ( + total_benchmark_timing + - og_benchmark_timing + + (1 / (replay_performance_gain[benchmark_key] + 1)) * og_benchmark_timing + ) # Calculate speedup - benchmark_speedup_percent = performance_gain(original_runtime_ns=total_benchmark_timing, optimized_runtime_ns=int(expected_new_benchmark_timing)) * 100 + benchmark_speedup_percent = ( + performance_gain( + original_runtime_ns=total_benchmark_timing, optimized_runtime_ns=int(expected_new_benchmark_timing) + ) + * 100 + ) benchmark_details.append( BenchmarkDetail( @@ -129,7 +123,7 @@ def process_benchmark_data( test_function=benchmark_key.function_name, original_timing=humanize_runtime(int(total_benchmark_timing)), expected_new_timing=humanize_runtime(int(expected_new_benchmark_timing)), - speedup_percent=benchmark_speedup_percent + speedup_percent=benchmark_speedup_percent, ) ) diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index c0e90ad4b..cce7208da 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -62,9 +62,13 @@ def parse_args() -> Namespace: ) parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose debug logs") parser.add_argument("--version", action="store_true", help="Print the version of codeflash") - parser.add_argument("--benchmark", action="store_true", help="Trace benchmark tests and calculate optimization impact on benchmarks") parser.add_argument( - "--benchmarks-root", type=str, help="Path to the directory of the project, where all the pytest-benchmark tests are located." 
+ "--benchmark", action="store_true", help="Trace benchmark tests and calculate optimization impact on benchmarks" + ) + parser.add_argument( + "--benchmarks-root", + type=str, + help="Path to the directory of the project, where all the pytest-benchmark tests are located.", ) args: Namespace = parser.parse_args() return process_and_validate_cmd_args(args) @@ -134,7 +138,9 @@ def process_pyproject_config(args: Namespace) -> Namespace: assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory" if args.benchmark: assert args.benchmarks_root is not None, "--benchmarks-root must be specified when running with --benchmark" - assert Path(args.benchmarks_root).is_dir(), f"--benchmarks-root {args.benchmarks_root} must be a valid directory" + assert Path(args.benchmarks_root).is_dir(), ( + f"--benchmarks-root {args.benchmarks_root} must be a valid directory" + ) assert Path(args.benchmarks_root).resolve().is_relative_to(Path(args.tests_root).resolve()), ( f"--benchmarks-root {args.benchmarks_root} must be a subdirectory of --tests-root {args.tests_root}" ) diff --git a/codeflash/cli_cmds/cli_common.py b/codeflash/cli_cmds/cli_common.py index 942b8f634..b8f04ec6e 100644 --- a/codeflash/cli_cmds/cli_common.py +++ b/codeflash/cli_cmds/cli_common.py @@ -74,7 +74,7 @@ def inquirer_wrapper_path(*args: str, **kwargs: str) -> dict[str, str] | None: new_kwargs["message"] = last_message new_args.append(args[0]) - return cast(dict[str, str], inquirer.prompt([inquirer.Path(*new_args, **new_kwargs)])) + return cast("dict[str, str]", inquirer.prompt([inquirer.Path(*new_args, **new_kwargs)])) def split_string_to_fit_width(string: str, width: int) -> list[str]: diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 7e6d2cd57..740506e85 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -239,7 +239,7 @@ def collect_setup_info() -> SetupInfo: else: apologize_and_exit() else: - tests_root = Path(curdir) / Path(cast(str, tests_root_answer)) + tests_root = Path(curdir) / Path(cast("str", tests_root_answer)) tests_root = tests_root.relative_to(curdir) ph("cli-tests-root-provided") @@ -302,7 +302,7 @@ def collect_setup_info() -> SetupInfo: elif benchmarks_answer == no_benchmarks_option: benchmarks_root = None else: - benchmarks_root = tests_root / Path(cast(str, benchmarks_answer)) + benchmarks_root = tests_root / Path(cast("str", benchmarks_answer)) # TODO: Implement other benchmark framework options # if benchmarks_root: @@ -354,9 +354,9 @@ def collect_setup_info() -> SetupInfo: module_root=str(module_root), tests_root=str(tests_root), benchmarks_root=str(benchmarks_root) if benchmarks_root else None, - test_framework=cast(str, test_framework), + test_framework=cast("str", test_framework), ignore_paths=ignore_paths, - formatter=cast(str, formatter), + formatter=cast("str", formatter), git_remote=str(git_remote), ) @@ -466,7 +466,7 @@ def check_for_toml_or_setup_file() -> str | None: click.echo("⏩️ Skipping pyproject.toml creation.") apologize_and_exit() click.echo() - return cast(str, project_name) + return cast("str", project_name) def install_github_actions(override_formatter_check: bool = False) -> None: diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index b4bfda3ff..fe2fdcdd1 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -12,7 +12,6 @@ MofNCompleteColumn, Progress, SpinnerColumn, - TaskProgressColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn, 
@@ -31,15 +30,7 @@ console = Console() logging.basicConfig( level=logging.INFO, - handlers=[ - RichHandler( - rich_tracebacks=True, - markup=False, - console=console, - show_path=False, - show_time=False, - ) - ], + handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)], format=BARE_LOGGING_FORMAT, ) @@ -48,9 +39,7 @@ def paneled_text( - text: str, - panel_args: dict[str, str | bool] | None = None, - text_args: dict[str, str] | None = None, + text: str, panel_args: dict[str, str | bool] | None = None, text_args: dict[str, str] | None = None ) -> None: """Print text in a panel.""" from rich.panel import Panel @@ -77,9 +66,7 @@ def code_print(code_str: str) -> None: @contextmanager -def progress_bar( - message: str, *, transient: bool = False -) -> Generator[TaskID, None, None]: +def progress_bar(message: str, *, transient: bool = False) -> Generator[TaskID, None, None]: """Display a progress bar with a spinner and elapsed time.""" progress = Progress( SpinnerColumn(next(spinners)), @@ -94,18 +81,12 @@ def progress_bar( @contextmanager -def test_files_progress_bar( - total: int, description: str -) -> Generator[tuple[Progress, TaskID], None, None]: +def test_files_progress_bar(total: int, description: str) -> Generator[tuple[Progress, TaskID], None, None]: """Progress bar for test files.""" with Progress( SpinnerColumn(next(spinners)), TextColumn("[progress.description]{task.description}"), - BarColumn( - complete_style="cyan", - finished_style="green", - pulse_style="yellow", - ), + BarColumn(complete_style="cyan", finished_style="green", pulse_style="yellow"), MofNCompleteColumn(), TimeElapsedColumn(), TimeRemainingColumn(), diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index bcbc0e29d..593f0b9cb 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -18,7 +18,8 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from typing import List, Union +from typing import List + class GlobalAssignmentCollector(cst.CSTVisitor): """Collects all global assignment statements.""" @@ -136,15 +137,11 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c # Add the new assignments for assignment in assignments_to_append: - new_statements.append( - cst.SimpleStatementLine( - [assignment], - leading_lines=[cst.EmptyLine()] - ) - ) + new_statements.append(cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])) return updated_node.with_changes(body=new_statements) + class GlobalStatementCollector(cst.CSTVisitor): """Visitor that collects all global statements (excluding imports and functions/classes).""" @@ -240,6 +237,7 @@ def find_last_import_line(target_code: str) -> int: module.visit(finder) return finder.last_import_line + class FutureAliasedImportTransformer(cst.CSTTransformer): def leave_ImportFrom( self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index ccb935f42..b47ef0b5f 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -8,7 +8,7 @@ import libcst as cst from codeflash.cli_cmds.console import logger -from codeflash.code_utils.code_extractor import add_needed_imports_from_module, add_global_assignments +from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module from codeflash.models.models 
import FunctionParent if TYPE_CHECKING: diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index 79a39168b..9a678d40f 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -87,7 +87,7 @@ def parse_config_file( "In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest." ) if len(config["formatter-cmds"]) > 0: - #see if this is happening during GitHub actions setup + # see if this is happening during GitHub actions setup if not override_formatter_check: assert config["formatter-cmds"][0] != "your-formatter $file", ( "The formatter command is not set correctly in pyproject.toml. Please set the " diff --git a/codeflash/code_utils/github_utils.py b/codeflash/code_utils/github_utils.py index 2b053a326..53398eeb0 100644 --- a/codeflash/code_utils/github_utils.py +++ b/codeflash/code_utils/github_utils.py @@ -27,5 +27,6 @@ def require_github_app_or_exit(owner: str, repo: str) -> None: ) apologize_and_exit() + def github_pr_url(owner: str, repo: str, pr_number: str) -> str: return f"https://github.com/{owner}/{repo}/pull/{pr_number}" diff --git a/codeflash/code_utils/line_profile_utils.py b/codeflash/code_utils/line_profile_utils.py index 21768cf68..497e3f65c 100644 --- a/codeflash/code_utils/line_profile_utils.py +++ b/codeflash/code_utils/line_profile_utils.py @@ -1,4 +1,5 @@ """Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)""" + from collections import defaultdict from pathlib import Path from typing import Union @@ -12,7 +13,7 @@ class LineProfilerDecoratorAdder(cst.CSTTransformer): """Transformer that adds a decorator to a function with a specific qualified name.""" - #TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure + # TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure def __init__(self, qualified_name: str, decorator_name: str): """Initialize the transformer. 
@@ -45,24 +46,19 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu function_name = original_node.name.value # Check if the current context path matches our target qualified name - if self.context_stack==self.qualified_name_parts: + if self.context_stack == self.qualified_name_parts: # Check if the decorator is already present has_decorator = any( - self._is_target_decorator(decorator.decorator) - for decorator in original_node.decorators + self._is_target_decorator(decorator.decorator) for decorator in original_node.decorators ) # Only add the decorator if it's not already there if not has_decorator: - new_decorator = cst.Decorator( - decorator=cst.Name(value=self.decorator_name) - ) + new_decorator = cst.Decorator(decorator=cst.Name(value=self.decorator_name)) # Add our new decorator to the existing decorators updated_decorators = [new_decorator] + list(updated_node.decorators) - updated_node = updated_node.with_changes( - decorators=tuple(updated_decorators) - ) + updated_node = updated_node.with_changes(decorators=tuple(updated_decorators)) # Pop the context when we leave a function self.context_stack.pop() @@ -76,8 +72,9 @@ def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cs return decorator_node.func.value == self.decorator_name return False + class ProfileEnableTransformer(cst.CSTTransformer): - def __init__(self,filename): + def __init__(self, filename): # Flag to track if we found the import statement self.found_import = False # Track indentation of the import statement @@ -86,12 +83,14 @@ def __init__(self,filename): def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom: # Check if this is the line profiler import statement - if (isinstance(original_node.module, cst.Name) and - original_node.module.value == "line_profiler" and - any(name.name.value == "profile" and - (not name.asname or name.asname.name.value == "codeflash_line_profile") - for name in original_node.names)): - + if ( + isinstance(original_node.module, cst.Name) + and original_node.module.value == "line_profiler" + and any( + name.name.value == "profile" and (not name.asname or name.asname.name.value == "codeflash_line_profile") + for name in original_node.names + ) + ): self.found_import = True # Get the indentation from the original node if hasattr(original_node, "leading_lines"): @@ -113,11 +112,15 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c if isinstance(stmt, cst.SimpleStatementLine): for small_stmt in stmt.body: if isinstance(small_stmt, cst.ImportFrom): - if (isinstance(small_stmt.module, cst.Name) and - small_stmt.module.value == "line_profiler" and - any(name.name.value == "profile" and - (not name.asname or name.asname.name.value == "codeflash_line_profile") - for name in small_stmt.names)): + if ( + isinstance(small_stmt.module, cst.Name) + and small_stmt.module.value == "line_profiler" + and any( + name.name.value == "profile" + and (not name.asname or name.asname.name.value == "codeflash_line_profile") + for name in small_stmt.names + ) + ): import_index = i break if import_index is not None: @@ -125,9 +128,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c if import_index is not None: # Create the new enable statement to insert after the import - enable_statement = cst.parse_statement( - f"codeflash_line_profile.enable(output_prefix='{self.filename}')" - ) + enable_statement = 
cst.parse_statement(f"codeflash_line_profile.enable(output_prefix='{self.filename}')") # Insert the new statement after the import statement new_body.insert(import_index + 1, enable_statement) @@ -135,6 +136,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c # Create a new module with the updated body return updated_node.with_changes(body=new_body) + def add_decorator_to_qualified_function(module, qualified_name, decorator_name): """Add a decorator to a function with the exact qualified name in the source code. @@ -156,6 +158,7 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name): # Convert the modified CST back to source code return modified_module + def add_profile_enable(original_code: str, line_profile_output_file: str) -> str: # TODO modify by using a libcst transformer module = cst.parse_module(original_code) @@ -178,9 +181,7 @@ def leave_Module(self, original_node, updated_node): import_node = cst.parse_statement(self.import_statement) # Add the import to the module's body - return updated_node.with_changes( - body=[import_node] + list(updated_node.body) - ) + return updated_node.with_changes(body=[import_node] + list(updated_node.body)) def visit_ImportFrom(self, node): # Check if the profile is already imported from line_profiler @@ -192,15 +193,15 @@ def visit_ImportFrom(self, node): def add_decorator_imports(function_to_optimize, code_context): """Adds a profile decorator to a function in a Python file and all its helper functions.""" - #self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root - #grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile + # self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root + # grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile file_paths = defaultdict(list) line_profile_output_file = get_run_tmp_file(Path("baseline_lprof")) file_paths[function_to_optimize.file_path].append(function_to_optimize.qualified_name) for elem in code_context.helper_functions: file_paths[elem.file_path].append(elem.qualified_name) - for file_path,fns_present in file_paths.items(): - #open file + for file_path, fns_present in file_paths.items(): + # open file file_contents = file_path.read_text("utf-8") # parse to cst module_node = cst.parse_module(file_contents) @@ -216,8 +217,8 @@ def add_decorator_imports(function_to_optimize, code_context): # write to file with open(file_path, "w", encoding="utf-8") as file: file.write(modified_code) - #Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files + # Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files file_contents = function_to_optimize.file_path.read_text("utf-8") - modified_code = add_profile_enable(file_contents,str(line_profile_output_file)) - function_to_optimize.file_path.write_text(modified_code,"utf-8") + modified_code = add_profile_enable(file_contents, str(line_profile_output_file)) + function_to_optimize.file_path.write_text(modified_code, "utf-8") return line_profile_output_file diff --git a/codeflash/code_utils/tabulate.py b/codeflash/code_utils/tabulate.py index c75dcd03e..6b1b827a4 100644 --- a/codeflash/code_utils/tabulate.py +++ b/codeflash/code_utils/tabulate.py @@ -2,14 +2,16 @@ """Pretty-print tabular data.""" +import dataclasses +import math +import re 
import warnings from collections import namedtuple from collections.abc import Iterable -from itertools import chain, zip_longest as izip_longest from functools import reduce -import re -import math -import dataclasses +from itertools import chain +from itertools import zip_longest as izip_longest + import wcwidth # optional wide-character (CJK) support __all__ = ["tabulate", "tabulate_formats"] @@ -59,8 +61,7 @@ def _is_separating_line_value(value): def _is_separating_line(row): row_type = type(row) is_sl = (row_type == list or row_type == str) and ( - (len(row) >= 1 and _is_separating_line_value(row[0])) - or (len(row) >= 2 and _is_separating_line_value(row[1])) + (len(row) >= 1 and _is_separating_line_value(row[0])) or (len(row) >= 2 and _is_separating_line_value(row[1])) ) return is_sl @@ -68,26 +69,28 @@ def _is_separating_line(row): def _pipe_segment_with_colons(align, colwidth): """Return a segment of a horizontal line with optional colons which - indicate column's alignment (as in `pipe` output format).""" + indicate column's alignment (as in `pipe` output format). + """ w = colwidth if align in {"right", "decimal"}: return ("-" * (w - 1)) + ":" - elif align == "center": + if align == "center": return ":" + ("-" * (w - 2)) + ":" - elif align == "left": + if align == "left": return ":" + ("-" * (w - 1)) - else: - return "-" * w + return "-" * w def _pipe_line_with_colons(colwidths, colaligns): """Return a horizontal line with optional colons to indicate column's - alignment (as in `pipe` output format).""" + alignment (as in `pipe` output format). + """ if not colaligns: # e.g. printing an empty data frame (github issue #15) colaligns = [""] * len(colwidths) segments = [_pipe_segment_with_colons(a, w) for a, w in zip(colaligns, colwidths)] return "|" + "|".join(segments) + "|" + _table_formats = { "simple": TableFormat( lineabove=Line("", "-", " ", ""), @@ -111,16 +114,12 @@ def _pipe_line_with_colons(colwidths, colaligns): ), } -tabulate_formats = list(sorted(_table_formats.keys())) +tabulate_formats = sorted(_table_formats.keys()) # The table formats for which multiline cells will be folded into subsequent # table rows. The key is the original format specified at the API. The value is # the format that will be used to represent the original format. 
-multiline_formats = { - "plain": "plain", - "pipe": "pipe", - -} +multiline_formats = {"plain": "plain", "pipe": "pipe"} _multiline_codes = re.compile(r"\r|\n|\r\n") _multiline_codes_bytes = re.compile(b"\r|\n|\r\n") @@ -152,9 +151,8 @@ def _pipe_line_with_colons(colwidths, colaligns): _ansi_codes_bytes = re.compile(_ansi_escape_pat.encode("utf8"), re.VERBOSE) _ansi_color_reset_code = "\033[0m" -_float_with_thousands_separators = re.compile( - r"^(([+-]?[0-9]{1,3})(?:,([0-9]{3}))*)?(?(1)\.[0-9]*|\.[0-9]+)?$" -) +_float_with_thousands_separators = re.compile(r"^(([+-]?[0-9]{1,3})(?:,([0-9]{3}))*)?(?(1)\.[0-9]*|\.[0-9]+)?$") + def _isnumber_with_thousands_separator(string): try: @@ -202,16 +200,12 @@ def _isint(string, inttype=int): (hasattr(string, "is_integer") or hasattr(string, "__array__")) and str(type(string)).startswith("= 0: - return len(string) - pos - 1 - else: - return -1 # no point - else: - return -1 # not a number + pos = string.rfind(".") + pos = string.lower().rfind("e") if pos < 0 else pos + if pos >= 0: + return len(string) - pos - 1 + return -1 # no point + return -1 # not a number def _padleft(width, s): @@ -281,8 +263,8 @@ def _padnone(ignore_width, s): def _strip_ansi(s): if isinstance(s, str): return _ansi_codes.sub(r"\4", s) - else: # a bytestring - return _ansi_codes_bytes.sub(r"\4", s) + # a bytestring + return _ansi_codes_bytes.sub(r"\4", s) def _visible_width(s): @@ -292,15 +274,14 @@ def _visible_width(s): len_fn = len if isinstance(s, (str, bytes)): return len_fn(_strip_ansi(s)) - else: - return len_fn(str(s)) + return len_fn(str(s)) def _is_multiline(s): if isinstance(s, str): return bool(re.search(_multiline_codes, s)) - else: # a bytestring - return bool(re.search(_multiline_codes_bytes, s)) + # a bytestring + return bool(re.search(_multiline_codes_bytes, s)) def _multiline_width(multiline_s, line_width_fn=len): @@ -384,65 +365,40 @@ def _align_column( is_multiline=False, preserve_whitespace=False, ): - strings, padfn = _align_column_choose_padfn( - strings, alignment, has_invisible, preserve_whitespace - ) - width_fn = _align_column_choose_width_fn( - has_invisible, enable_widechars, is_multiline - ) + strings, padfn = _align_column_choose_padfn(strings, alignment, has_invisible, preserve_whitespace) + width_fn = _align_column_choose_width_fn(has_invisible, enable_widechars, is_multiline) s_widths = list(map(width_fn, strings)) maxwidth = max(max(_flat_list(s_widths)), minwidth) # TODO: refactor column alignment in single-line and multiline modes if is_multiline: if not enable_widechars and not has_invisible: - padded_strings = [ - "\n".join([padfn(maxwidth, s) for s in ms.splitlines()]) - for ms in strings - ] + padded_strings = ["\n".join([padfn(maxwidth, s) for s in ms.splitlines()]) for ms in strings] else: # enable wide-character width corrections s_lens = [[len(s) for s in re.split("[\r\n]", ms)] for ms in strings] - visible_widths = [ - [maxwidth - (w - l) for w, l in zip(mw, ml)] - for mw, ml in zip(s_widths, s_lens) - ] + visible_widths = [[maxwidth - (w - l) for w, l in zip(mw, ml)] for mw, ml in zip(s_widths, s_lens)] # wcswidth and _visible_width don't count invisible characters; # padfn doesn't need to apply another correction padded_strings = [ "\n".join([padfn(w, s) for s, w in zip((ms.splitlines() or ms), mw)]) for ms, mw in zip(strings, visible_widths) ] - else: # single-line cell values - if not enable_widechars and not has_invisible: - padded_strings = [padfn(maxwidth, s) for s in strings] - else: - # enable wide-character width 
corrections - s_lens = list(map(len, strings)) - visible_widths = [maxwidth - (w - l) for w, l in zip(s_widths, s_lens)] - # wcswidth and _visible_width don't count invisible characters; - # padfn doesn't need to apply another correction - padded_strings = [padfn(w, s) for s, w in zip(strings, visible_widths)] + elif not enable_widechars and not has_invisible: + padded_strings = [padfn(maxwidth, s) for s in strings] + else: + # enable wide-character width corrections + s_lens = list(map(len, strings)) + visible_widths = [maxwidth - (w - l) for w, l in zip(s_widths, s_lens)] + # wcswidth and _visible_width don't count invisible characters; + # padfn doesn't need to apply another correction + padded_strings = [padfn(w, s) for s, w in zip(strings, visible_widths)] return padded_strings def _more_generic(type1, type2): - types = { - type(None): 0, - bool: 1, - int: 2, - float: 3, - bytes: 4, - str: 5, - } - invtypes = { - 5: str, - 4: bytes, - 3: float, - 2: int, - 1: bool, - 0: type(None), - } + types = {type(None): 0, bool: 1, int: 2, float: 3, bytes: 4, str: 5} + invtypes = {5: str, 4: bytes, 3: float, 2: int, 1: bool, 0: type(None)} moregeneric = max(types.get(type1, 5), types.get(type2, 5)) return invtypes[moregeneric] @@ -460,26 +416,20 @@ def _format(val, valtype, floatfmt, intfmt, missingval="", has_invisible=True): if valtype is str: return f"{val}" - elif valtype is int: + if valtype is int: if isinstance(val, str): val_striped = val.encode("unicode_escape").decode("utf-8") - colored = re.search( - r"(\\[xX]+[0-9a-fA-F]+\[\d+[mM]+)([0-9.]+)(\\.*)$", val_striped - ) + colored = re.search(r"(\\[xX]+[0-9a-fA-F]+\[\d+[mM]+)([0-9.]+)(\\.*)$", val_striped) if colored: total_groups = len(colored.groups()) if total_groups == 3: digits = colored.group(2) if digits.isdigit(): - val_new = ( - colored.group(1) - + format(int(digits), intfmt) - + colored.group(3) - ) + val_new = colored.group(1) + format(int(digits), intfmt) + colored.group(3) val = val_new.encode("utf-8").decode("unicode_escape") intfmt = "" return format(val, intfmt) - elif valtype is bytes: + if valtype is bytes: try: return str(val, "ascii") except (TypeError, UnicodeDecodeError): @@ -490,35 +440,29 @@ def _format(val, valtype, floatfmt, intfmt, missingval="", has_invisible=True): raw_val = _strip_ansi(val) formatted_val = format(float(raw_val), floatfmt) return val.replace(raw_val, formatted_val) - else: - if isinstance(val, str) and "," in val: - val = val.replace(",", "") # handle thousands-separators - return format(float(val), floatfmt) + if isinstance(val, str) and "," in val: + val = val.replace(",", "") # handle thousands-separators + return format(float(val), floatfmt) else: return f"{val}" -def _align_header( - header, alignment, width, visible_width, is_multiline=False, width_fn=None -): - "Pad string header to width chars given known visible_width of the header." 
+def _align_header(header, alignment, width, visible_width, is_multiline=False, width_fn=None): + """Pad string header to width chars given known visible_width of the header.""" if is_multiline: header_lines = re.split(_multiline_codes, header) - padded_lines = [ - _align_header(h, alignment, width, width_fn(h)) for h in header_lines - ] + padded_lines = [_align_header(h, alignment, width, width_fn(h)) for h in header_lines] return "\n".join(padded_lines) # else: not multiline ninvisible = len(header) - visible_width width += ninvisible if alignment == "left": return _padright(width, header) - elif alignment == "center": + if alignment == "center": return _padboth(width, header) - elif not alignment: + if not alignment: return f"{header}" - else: - return _padleft(width, header) + return _padleft(width, header) def _remove_separating_lines(rows): @@ -531,11 +475,11 @@ def _remove_separating_lines(rows): else: sans_rows.append(row) return sans_rows, separating_lines - else: - return rows, None + return rows, None + def _bool(val): - "A wrapper around standard bool() which doesn't throw on NumPy arrays" + """A wrapper around standard bool() which doesn't throw on NumPy arrays""" try: return bool(val) except ValueError: # val is likely to be a numpy array with many elements @@ -556,23 +500,18 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): index = None if hasattr(tabular_data, "keys") and hasattr(tabular_data, "values"): # dict-like and pandas.DataFrame? - if hasattr(tabular_data.values, "__call__"): + if callable(tabular_data.values): # likely a conventional dict keys = tabular_data.keys() try: - rows = list( - izip_longest(*tabular_data.values()) - ) # columns have to be transposed + rows = list(izip_longest(*tabular_data.values())) # columns have to be transposed except TypeError: # not iterable raise TypeError(err_msg) elif hasattr(tabular_data, "index"): # values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0) keys = list(tabular_data) - if ( - showindex in {"default", "always", True} - and tabular_data.index.name is not None - ): + if showindex in {"default", "always", True} and tabular_data.index.name is not None: if isinstance(tabular_data.index.name, list): keys[:0] = tabular_data.index.name else: @@ -596,19 +535,10 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): if headers == "keys" and not rows: # an empty table (issue #81) headers = [] - elif ( - headers == "keys" - and hasattr(tabular_data, "dtype") - and getattr(tabular_data.dtype, "names") - ): + elif headers == "keys" and hasattr(tabular_data, "dtype") and tabular_data.dtype.names: # numpy record array headers = tabular_data.dtype.names - elif ( - headers == "keys" - and len(rows) > 0 - and isinstance(rows[0], tuple) - and hasattr(rows[0], "_fields") - ): + elif headers == "keys" and len(rows) > 0 and isinstance(rows[0], tuple) and hasattr(rows[0], "_fields"): # namedtuple headers = list(map(str, rows[0]._fields)) elif len(rows) > 0 and hasattr(rows[0], "keys") and hasattr(rows[0], "values"): @@ -639,9 +569,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): else: headers = [] elif headers: - raise ValueError( - "headers for a list of dicts is not a dict or a keyword" - ) + raise ValueError("headers for a list of dicts is not a dict or a keyword") rows = [[row.get(k) for k in keys] for row in rows] elif ( @@ -654,11 +582,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): # print 
tabulate(cursor, headers='keys') headers = [column[0] for column in tabular_data.description] - elif ( - dataclasses is not None - and len(rows) > 0 - and dataclasses.is_dataclass(rows[0]) - ): + elif dataclasses is not None and len(rows) > 0 and dataclasses.is_dataclass(rows[0]): # Python's dataclass field_names = [field.name for field in dataclasses.fields(rows[0])] if headers == "keys": @@ -698,6 +622,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): return rows, headers, headers_pad + def _to_str(s, encoding="utf8", errors="ignore"): if isinstance(s, bytes): return s.decode(encoding=encoding, errors=errors) @@ -727,9 +652,7 @@ def tabulate( if tabular_data is None: tabular_data = [] - list_of_lists, headers, headers_pad = _normalize_tabular_data( - tabular_data, headers, showindex=showindex - ) + list_of_lists, headers, headers_pad = _normalize_tabular_data(tabular_data, headers, showindex=showindex) list_of_lists, separating_lines = _remove_separating_lines(list_of_lists) # PrettyTable formatting does not use any extra padding. @@ -771,11 +694,7 @@ def tabulate( has_invisible = _ansi_codes.search(plain_text) is not None enable_widechars = wcwidth is not None and WIDE_CHARS_MODE - if ( - not isinstance(tablefmt, TableFormat) - and tablefmt in multiline_formats - and _is_multiline(plain_text) - ): + if not isinstance(tablefmt, TableFormat) and tablefmt in multiline_formats and _is_multiline(plain_text): tablefmt = multiline_formats.get(tablefmt, tablefmt) is_multiline = True else: @@ -787,17 +706,13 @@ def tabulate( numparses = _expand_numparse(disable_numparse, len(cols)) coltypes = [_column_type(col, numparse=np) for col, np in zip(cols, numparses)] if isinstance(floatfmt, str): # old version - float_formats = len(cols) * [ - floatfmt - ] # just duplicate the string to use in each column + float_formats = len(cols) * [floatfmt] # just duplicate the string to use in each column else: # if floatfmt is list, tuple etc we have one per column float_formats = list(floatfmt) if len(float_formats) < len(cols): float_formats.extend((len(cols) - len(float_formats)) * [_DEFAULT_FLOATFMT]) if isinstance(intfmt, str): # old version - int_formats = len(cols) * [ - intfmt - ] # just duplicate the string to use in each column + int_formats = len(cols) * [intfmt] # just duplicate the string to use in each column else: # if intfmt is list, tuple etc we have one per column int_formats = list(intfmt) if len(int_formats) < len(cols): @@ -810,9 +725,7 @@ def tabulate( missing_vals.extend((len(cols) - len(missing_vals)) * [_DEFAULT_MISSINGVAL]) cols = [ [_format(v, ct, fl_fmt, int_fmt, miss_v, has_invisible) for v in c] - for c, ct, fl_fmt, int_fmt, miss_v in zip( - cols, coltypes, float_formats, int_formats, missing_vals - ) + for c, ct, fl_fmt, int_fmt, miss_v in zip(cols, coltypes, float_formats, int_formats, missing_vals) ] # align columns @@ -833,26 +746,16 @@ def tabulate( for idx, align in enumerate(colalign): if not idx < len(aligns): break - elif align != "global": + if align != "global": aligns[idx] = align - minwidths = ( - [width_fn(h) + min_padding for h in headers] if headers else [0] * len(cols) - ) + minwidths = [width_fn(h) + min_padding for h in headers] if headers else [0] * len(cols) aligns_copy = aligns.copy() # Reset alignments in copy of alignments list to "left" for 'colon_grid' format, # which enforces left alignment in the text output of the data. 
if tablefmt == "colon_grid": aligns_copy = ["left"] * len(cols) cols = [ - _align_column( - c, - a, - minw, - has_invisible, - enable_widechars, - is_multiline, - preserve_whitespace, - ) + _align_column(c, a, minw, has_invisible, enable_widechars, is_multiline, preserve_whitespace) for c, a, minw in zip(cols, aligns_copy, minwidths) ] @@ -879,14 +782,11 @@ def tabulate( hidx = headers_pad + idx if not hidx < len(aligns_headers): break - elif align == "same" and hidx < len(aligns): # same as column align + if align == "same" and hidx < len(aligns): # same as column align aligns_headers[hidx] = aligns[hidx] elif align != "global": aligns_headers[hidx] = align - minwidths = [ - max(minw, max(width_fn(cl) for cl in c)) - for minw, c in zip(minwidths, t_cols) - ] + minwidths = [max(minw, max(width_fn(cl) for cl in c)) for minw, c in zip(minwidths, t_cols)] headers = [ _align_header(h, a, minw, width_fn(h), is_multiline, width_fn) for h, a, minw in zip(headers, aligns_headers, minwidths) @@ -901,16 +801,7 @@ def tabulate( ra_default = rowalign if isinstance(rowalign, str) else None rowaligns = _expand_iterable(rowalign, len(rows), ra_default) - return _format_table( - tablefmt, - headers, - aligns_headers, - rows, - minwidths, - aligns, - is_multiline, - rowaligns=rowaligns, - ) + return _format_table(tablefmt, headers, aligns_headers, rows, minwidths, aligns, is_multiline, rowaligns=rowaligns) def _expand_numparse(disable_numparse, column_count): @@ -919,15 +810,13 @@ def _expand_numparse(disable_numparse, column_count): for index in disable_numparse: numparses[index] = False return numparses - else: - return [not disable_numparse] * column_count + return [not disable_numparse] * column_count def _expand_iterable(original, num_desired, default): if isinstance(original, Iterable) and not isinstance(original, str): return original + [default] * (num_desired - len(original)) - else: - return [default] * num_desired + return [default] * num_desired def _pad_row(cells, padding): @@ -937,8 +826,7 @@ def _pad_row(cells, padding): pad = " " * padding padded_cells = [pad + cell + pad for cell in cells] return padded_cells - else: - return cells + return cells def _build_simple_row(padded_cells, rowfmt): @@ -949,35 +837,34 @@ def _build_simple_row(padded_cells, rowfmt): def _build_row(padded_cells, colwidths, colaligns, rowfmt): if not rowfmt: return None - if hasattr(rowfmt, "__call__"): + if callable(rowfmt): return rowfmt(padded_cells, colwidths, colaligns) - else: - return _build_simple_row(padded_cells, rowfmt) + return _build_simple_row(padded_cells, rowfmt) + def _append_basic_row(lines, padded_cells, colwidths, colaligns, rowfmt, rowalign=None): # NOTE: rowalign is ignored and exists for api compatibility with _append_multiline_row lines.append(_build_row(padded_cells, colwidths, colaligns, rowfmt)) return lines + def _build_line(colwidths, colaligns, linefmt): - "Return a string which represents a horizontal line." 
+ """Return a string which represents a horizontal line.""" if not linefmt: return None - if hasattr(linefmt, "__call__"): + if callable(linefmt): return linefmt(colwidths, colaligns) - else: - begin, fill, sep, end = linefmt - cells = [fill * w for w in colwidths] - return _build_simple_row(cells, (begin, sep, end)) + begin, fill, sep, end = linefmt + cells = [fill * w for w in colwidths] + return _build_simple_row(cells, (begin, sep, end)) def _append_line(lines, colwidths, colaligns, linefmt): lines.append(_build_line(colwidths, colaligns, linefmt)) return lines -def _format_table( - fmt, headers, headersaligns, rows, colwidths, colaligns, is_multiline, rowaligns -): + +def _format_table(fmt, headers, headersaligns, rows, colwidths, colaligns, is_multiline, rowaligns): lines = [] hidden = fmt.with_header_hide if (headers and fmt.with_header_hide) else [] pad = fmt.padding @@ -1001,31 +888,13 @@ def _format_table( # initial rows with a line below for row, ralign in zip(rows[:-1], rowaligns): if row != SEPARATING_LINE: - append_row( - lines, - pad_row(row, pad), - padded_widths, - colaligns, - fmt.datarow, - rowalign=ralign, - ) + append_row(lines, pad_row(row, pad), padded_widths, colaligns, fmt.datarow, rowalign=ralign) _append_line(lines, padded_widths, colaligns, fmt.linebetweenrows) # the last row without a line below - append_row( - lines, - pad_row(rows[-1], pad), - padded_widths, - colaligns, - fmt.datarow, - rowalign=rowaligns[-1], - ) + append_row(lines, pad_row(rows[-1], pad), padded_widths, colaligns, fmt.datarow, rowalign=rowaligns[-1]) else: separating_line = ( - fmt.linebetweenrows - or fmt.linebelowheader - or fmt.linebelow - or fmt.lineabove - or Line("", "", "", "") + fmt.linebetweenrows or fmt.linebelowheader or fmt.linebelow or fmt.lineabove or Line("", "", "", "") ) for row in rows: # test to see if either the 1st column or the 2nd column (account for showindex) has @@ -1033,9 +902,7 @@ def _format_table( if _is_separating_line(row): _append_line(lines, padded_widths, colaligns, separating_line) else: - append_row( - lines, pad_row(row, pad), padded_widths, colaligns, fmt.datarow - ) + append_row(lines, pad_row(row, pad), padded_widths, colaligns, fmt.datarow) if fmt.linebelow and "linebelow" not in hidden: _append_line(lines, padded_widths, colaligns, fmt.linebelow) @@ -1043,5 +910,5 @@ def _format_table( if headers or rows: output = "\n".join(lines) return output - else: # a completely empty table - return "" + # a completely empty table + return "" diff --git a/codeflash/code_utils/time_utils.py b/codeflash/code_utils/time_utils.py index ad3b5f642..aaf74fc93 100644 --- a/codeflash/code_utils/time_utils.py +++ b/codeflash/code_utils/time_utils.py @@ -24,14 +24,14 @@ def humanize_runtime(time_in_ns: int) -> str: runtime_human = "%.3g" % (time_micro / (1000**2)) elif units in {"minutes", "minute"}: runtime_human = "%.3g" % (time_micro / (60 * 1000**2)) - elif units in {"hour", "hours"}: #hours + elif units in {"hour", "hours"}: # hours runtime_human = "%.3g" % (time_micro / (3600 * 1000**2)) - else: #days - runtime_human = "%.3g" % (time_micro / (24*3600 * 1000**2)) + else: # days + runtime_human = "%.3g" % (time_micro / (24 * 3600 * 1000**2)) runtime_human_parts = str(runtime_human).split(".") if len(runtime_human_parts[0]) == 1: - if runtime_human_parts[0]=='1' and len(runtime_human_parts)>1: - units = units+'s' + if runtime_human_parts[0] == "1" and len(runtime_human_parts) > 1: + units = units + "s" if len(runtime_human_parts) == 1: runtime_human = 
f"{runtime_human_parts[0]}.00" elif len(runtime_human_parts[1]) >= 2: diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index bfcbbaead..73e7b399e 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -303,7 +303,7 @@ def mark_as_used_recursively(self, name: str) -> None: def remove_unused_definitions_recursively( - node: cst.CSTNode, definitions: dict[str, UsageInfo] + node: cst.CSTNode, definitions: dict[str, UsageInfo] ) -> tuple[cst.CSTNode | None, bool]: """Recursively filter the node to remove unused definitions. @@ -358,7 +358,10 @@ def remove_unused_definitions_recursively( names = extract_names_from_targets(target.target) for name in names: class_var_name = f"{class_name}.{name}" - if class_var_name in definitions and definitions[class_var_name].used_by_qualified_function: + if ( + class_var_name in definitions + and definitions[class_var_name].used_by_qualified_function + ): var_used = True method_or_var_used = True break diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 8f5ba65eb..cb5c5d5a1 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -158,9 +158,9 @@ def get_functions_to_optimize( module_root: Path, previous_checkpoint_functions: dict[str, dict[str, str]] | None = None, ) -> tuple[dict[Path, list[FunctionToOptimize]], int]: - assert ( - sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1 - ), "Only one of optimize_all, replay_test, or file should be provided" + assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, ( + "Only one of optimize_all, replay_test, or file should be provided" + ) functions: dict[str, list[FunctionToOptimize]] with warnings.catch_warnings(): warnings.simplefilter(action="ignore", category=SyntaxWarning) @@ -208,7 +208,7 @@ def get_functions_to_optimize( three_min_in_ns = int(1.8e11) console.rule() logger.info( - f"It might take about {humanize_runtime(functions_count*three_min_in_ns)} to fully optimize this project. Codeflash " + f"It might take about {humanize_runtime(functions_count * three_min_in_ns)} to fully optimize this project. Codeflash " f"will keep opening pull requests as it finds optimizations." 
) return filtered_modified_functions, functions_count diff --git a/codeflash/github/PrComment.py b/codeflash/github/PrComment.py index 1e66c5608..d87b28bc2 100644 --- a/codeflash/github/PrComment.py +++ b/codeflash/github/PrComment.py @@ -1,12 +1,12 @@ from __future__ import annotations -from typing import Union, Optional + +from typing import Optional, Union from pydantic import BaseModel from pydantic.dataclasses import dataclass from codeflash.code_utils.time_utils import humanize_runtime -from codeflash.models.models import BenchmarkDetail -from codeflash.models.models import TestResults +from codeflash.models.models import BenchmarkDetail, TestResults @dataclass(frozen=True, config={"arbitrary_types_allowed": True}) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 32add0a94..b250d2474 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -59,24 +59,29 @@ class FunctionSource: def __eq__(self, other: object) -> bool: if not isinstance(other, FunctionSource): return False - return (self.file_path == other.file_path and - self.qualified_name == other.qualified_name and - self.fully_qualified_name == other.fully_qualified_name and - self.only_function_name == other.only_function_name and - self.source_code == other.source_code) + return ( + self.file_path == other.file_path + and self.qualified_name == other.qualified_name + and self.fully_qualified_name == other.fully_qualified_name + and self.only_function_name == other.only_function_name + and self.source_code == other.source_code + ) def __hash__(self) -> int: - return hash((self.file_path, self.qualified_name, self.fully_qualified_name, - self.only_function_name, self.source_code)) + return hash( + (self.file_path, self.qualified_name, self.fully_qualified_name, self.only_function_name, self.source_code) + ) + class BestOptimization(BaseModel): candidate: OptimizedCandidate helper_functions: list[FunctionSource] runtime: int - replay_performance_gain: Optional[dict[BenchmarkKey,float]] = None + replay_performance_gain: Optional[dict[BenchmarkKey, float]] = None winning_behavioral_test_results: TestResults winning_benchmarking_test_results: TestResults - winning_replay_benchmarking_test_results : Optional[TestResults] = None + winning_replay_benchmarking_test_results: Optional[TestResults] = None + @dataclass(frozen=True) class BenchmarkKey: @@ -86,6 +91,7 @@ class BenchmarkKey: def __str__(self) -> str: return f"{self.module_path}::{self.function_name}" + @dataclass class BenchmarkDetail: benchmark_name: str @@ -107,9 +113,10 @@ def to_dict(self) -> dict[str, any]: "test_function": self.test_function, "original_timing": self.original_timing, "expected_new_timing": self.expected_new_timing, - "speedup_percent": self.speedup_percent + "speedup_percent": self.speedup_percent, } + @dataclass class ProcessedBenchmarkInfo: benchmark_details: list[BenchmarkDetail] @@ -124,9 +131,9 @@ def to_string(self) -> str: return result def to_dict(self) -> dict[str, list[dict[str, any]]]: - return { - "benchmark_details": [detail.to_dict() for detail in self.benchmark_details] - } + return {"benchmark_details": [detail.to_dict() for detail in self.benchmark_details]} + + class CodeString(BaseModel): code: Annotated[str, AfterValidator(validate_python_code)] file_path: Optional[Path] = None @@ -151,7 +158,8 @@ class CodeOptimizationContext(BaseModel): read_writable_code: str = Field(min_length=1) read_only_context_code: str = "" helper_functions: list[FunctionSource] - preexisting_objects: set[tuple[str, 
tuple[FunctionParent,...]]] + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] + class CodeContextType(str, Enum): READ_WRITABLE = "READ_WRITABLE" @@ -347,6 +355,7 @@ def create_empty(cls, file_path: Path, function_name: str, code_context: CodeOpt status=CoverageStatus.NOT_FOUND, ) + @dataclass class FunctionCoverage: """Represents the coverage data for a specific function in a source file.""" @@ -364,7 +373,8 @@ class TestingMode(enum.Enum): PERFORMANCE = "performance" LINE_PROFILE = "line_profile" -#TODO this class is duplicated in codeflash_capture + +# TODO this class is duplicated in codeflash_capture class VerificationType(str, Enum): FUNCTION_CALL = ( "function_call" # Correctness verification for a test function, checks input values and output values) @@ -473,14 +483,20 @@ def merge(self, other: TestResults) -> None: raise ValueError(msg) self.test_result_idx[k] = v + original_len - def group_by_benchmarks(self, benchmark_keys:list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path) -> dict[BenchmarkKey, TestResults]: + def group_by_benchmarks( + self, benchmark_keys: list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path + ) -> dict[BenchmarkKey, TestResults]: """Group TestResults by benchmark for calculating improvements for each benchmark.""" test_results_by_benchmark = defaultdict(TestResults) benchmark_module_path = {} for benchmark_key in benchmark_keys: - benchmark_module_path[benchmark_key] = module_name_from_file_path(benchmark_replay_test_dir.resolve() / f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_", project_root) + benchmark_module_path[benchmark_key] = module_name_from_file_path( + benchmark_replay_test_dir.resolve() + / f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_", + project_root, + ) for test_result in self.test_results: - if (test_result.test_type == TestType.REPLAY_TEST): + if test_result.test_type == TestType.REPLAY_TEST: for benchmark_key, module_path in benchmark_module_path.items(): if test_result.id.test_module_path.startswith(module_path): test_results_by_benchmark[benchmark_key].add(test_result) @@ -559,7 +575,7 @@ def total_passed_runtime(self) -> int: :return: The runtime in nanoseconds. 
""" - #TODO this doesn't look at the intersection of tests of baseline and original + # TODO this doesn't look at the intersection of tests of baseline and original return sum( [min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()] ) @@ -589,7 +605,7 @@ def __eq__(self, other: object) -> bool: if len(self) != len(other): return False original_recursion_limit = sys.getrecursionlimit() - cast(TestResults, other) + cast("TestResults", other) for test_result in self: other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id) if other_test_result is None: diff --git a/codeflash/optimization/function_context.py b/codeflash/optimization/function_context.py index d55aa2dec..756c26095 100644 --- a/codeflash/optimization/function_context.py +++ b/codeflash/optimization/function_context.py @@ -31,7 +31,10 @@ def belongs_to_class(name: Name, class_name: str) -> bool: def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> bool: """Check if the given jedi Name is a direct child of the specified function, matched by qualified function name.""" try: - if name.full_name.startswith(name.module_name) and get_qualified_name(name.module_name, name.full_name) == qualified_function_name: + if ( + name.full_name.startswith(name.module_name) + and get_qualified_name(name.module_name, name.full_name) == qualified_function_name + ): # Handles function definition and recursive function calls return False if name := name.parent(): diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 56124a9cb..ab9dd1017 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -3,7 +3,6 @@ import ast import concurrent.futures import os -import shutil import subprocess import time import uuid diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index de2cc1740..8860642fa 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -18,12 +18,12 @@ from codeflash.code_utils import env_utils from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint from codeflash.code_utils.code_replacer import normalize_code, normalize_node -from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file +from codeflash.code_utils.code_utils import cleanup_paths from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast from codeflash.discovery.discover_unit_tests import discover_unit_tests from codeflash.discovery.functions_to_optimize import get_functions_to_optimize from codeflash.either import is_successful -from codeflash.models.models import BenchmarkKey, TestType, ValidCode +from codeflash.models.models import BenchmarkKey, ValidCode from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.telemetry.posthog_cf import ph from codeflash.verification.verification_utils import TestConfig @@ -266,7 +266,14 @@ def run(self) -> None: if function_optimizer: function_optimizer.cleanup_generated_files() + if self.test_cfg.concolic_test_root_dir: + cleanup_paths([self.test_cfg.concolic_test_root_dir]) + def run_with_args(args: Namespace) -> None: - optimizer = Optimizer(args) - optimizer.run() + try: + optimizer = Optimizer(args) + optimizer.run() + except KeyboardInterrupt: + logger.warning("Keyboard interrupt received. 
Exiting, please wait…") + raise SystemExit from None diff --git a/codeflash/picklepatch/pickle_patcher.py b/codeflash/picklepatch/pickle_patcher.py index cfedd28fd..ef42285be 100644 --- a/codeflash/picklepatch/pickle_patcher.py +++ b/codeflash/picklepatch/pickle_patcher.py @@ -5,7 +5,6 @@ """ import pickle -import types import dill @@ -34,6 +33,7 @@ def dumps(obj, protocol=None, max_depth=100, **kwargs): Returns: bytes: Pickled data with placeholders for unpicklable objects + """ return PicklePatcher._recursive_pickle(obj, max_depth, path=[], protocol=protocol, **kwargs) @@ -46,11 +46,12 @@ def loads(pickled_data): Returns: The unpickled object with placeholders for unpicklable parts + """ try: # We use dill for loading since it can handle everything pickle can return dill.loads(pickled_data) - except Exception as e: + except Exception: raise @staticmethod @@ -64,6 +65,7 @@ def _create_placeholder(obj, error_msg, path): Returns: PicklePlaceholder: A placeholder object + """ obj_type = type(obj) try: @@ -73,12 +75,7 @@ def _create_placeholder(obj, error_msg, path): print(f"Creating placeholder for {obj_type.__name__} at path {'->'.join(path) or 'root'}: {error_msg}") - placeholder = PicklePlaceholder( - obj_type.__name__, - obj_str, - error_msg, - path - ) + placeholder = PicklePlaceholder(obj_type.__name__, obj_str, error_msg, path) # Add this type to our known unpicklable types cache PicklePatcher._unpicklable_types.add(obj_type) @@ -98,11 +95,12 @@ def _pickle(obj, path=None, protocol=None, **kwargs): tuple: (success, result) where success is a boolean and result is either: - Pickled bytes if successful - Error message if not successful + """ # Try standard pickle first try: return True, pickle.dumps(obj, protocol=protocol, **kwargs) - except (pickle.PickleError, TypeError, AttributeError, ValueError) as e: + except (pickle.PickleError, TypeError, AttributeError, ValueError): # Then try dill (which is more powerful) try: return True, dill.dumps(obj, protocol=protocol, **kwargs) @@ -122,6 +120,7 @@ def _recursive_pickle(obj, max_depth, path=None, protocol=None, **kwargs): Returns: bytes: Pickled data with placeholders for unpicklable objects + """ if path is None: path = [] @@ -130,20 +129,12 @@ def _recursive_pickle(obj, max_depth, path=None, protocol=None, **kwargs): # Check if this type is known to be unpicklable if obj_type in PicklePatcher._unpicklable_types: - placeholder = PicklePatcher._create_placeholder( - obj, - "Known unpicklable type", - path - ) + placeholder = PicklePatcher._create_placeholder(obj, "Known unpicklable type", path) return dill.dumps(placeholder, protocol=protocol, **kwargs) # Check for max depth if max_depth <= 0: - placeholder = PicklePatcher._create_placeholder( - obj, - "Max recursion depth exceeded", - path - ) + placeholder = PicklePatcher._create_placeholder(obj, "Max recursion depth exceeded", path) return dill.dumps(placeholder, protocol=protocol, **kwargs) # Try standard pickling @@ -156,9 +147,9 @@ def _recursive_pickle(obj, max_depth, path=None, protocol=None, **kwargs): # Handle different container types if isinstance(obj, dict): return PicklePatcher._handle_dict(obj, max_depth, error_msg, path, protocol=protocol, **kwargs) - elif isinstance(obj, (list, tuple, set)): + if isinstance(obj, (list, tuple, set)): return PicklePatcher._handle_sequence(obj, max_depth, error_msg, path, protocol=protocol, **kwargs) - elif hasattr(obj, "__dict__"): + if hasattr(obj, "__dict__"): result = PicklePatcher._handle_object(obj, max_depth, error_msg, path, 
protocol=protocol, **kwargs) # If this was a failure, add the type to the cache @@ -185,12 +176,11 @@ def _handle_dict(obj_dict, max_depth, error_msg, path, protocol=None, **kwargs): Returns: bytes: Pickled data with placeholders for unpicklable objects + """ if not isinstance(obj_dict, dict): placeholder = PicklePatcher._create_placeholder( - obj_dict, - f"Expected a dictionary, got {type(obj_dict).__name__}", - path + obj_dict, f"Expected a dictionary, got {type(obj_dict).__name__}", path ) return dill.dumps(placeholder, protocol=protocol, **kwargs) @@ -223,11 +213,7 @@ def _handle_dict(obj_dict, max_depth, error_msg, path, protocol=None, **kwargs): ) value_result = dill.loads(value_bytes) except Exception as inner_e: - value_result = PicklePatcher._create_placeholder( - value, - str(inner_e), - value_path - ) + value_result = PicklePatcher._create_placeholder(value, str(inner_e), value_path) result[key_result] = value_result @@ -247,6 +233,7 @@ def _handle_sequence(obj_seq, max_depth, error_msg, path, protocol=None, **kwarg Returns: bytes: Pickled data with placeholders for unpicklable objects + """ result = [] @@ -267,11 +254,7 @@ def _handle_sequence(obj_seq, max_depth, error_msg, path, protocol=None, **kwarg result.append(dill.loads(item_bytes)) except Exception as inner_e: # If recursive pickling fails, use a placeholder - placeholder = PicklePatcher._create_placeholder( - item, - str(inner_e), - item_path - ) + placeholder = PicklePatcher._create_placeholder(item, str(inner_e), item_path) result.append(placeholder) # Convert back to the original type @@ -301,6 +284,7 @@ def _handle_object(obj, max_depth, error_msg, path, protocol=None, **kwargs): Returns: bytes: Pickled data with placeholders for unpicklable objects + """ # Try to create a new instance of the same class try: @@ -326,11 +310,7 @@ def _handle_object(obj, max_depth, error_msg, path, protocol=None, **kwargs): setattr(new_obj, attr_name, dill.loads(attr_bytes)) except Exception as inner_e: # Use placeholder for unpicklable attribute - placeholder = PicklePatcher._create_placeholder( - attr_value, - str(inner_e), - attr_path - ) + placeholder = PicklePatcher._create_placeholder(attr_value, str(inner_e), attr_path) setattr(new_obj, attr_name, placeholder) # Try to pickle the patched object @@ -343,4 +323,4 @@ def _handle_object(obj, max_depth, error_msg, path, protocol=None, **kwargs): # If we get here, just use a placeholder placeholder = PicklePatcher._create_placeholder(obj, error_msg, path) - return dill.dumps(placeholder, protocol=protocol, **kwargs) \ No newline at end of file + return dill.dumps(placeholder, protocol=protocol, **kwargs) diff --git a/codeflash/picklepatch/pickle_placeholder.py b/codeflash/picklepatch/pickle_placeholder.py index 0d730dabb..49c5d755c 100644 --- a/codeflash/picklepatch/pickle_placeholder.py +++ b/codeflash/picklepatch/pickle_placeholder.py @@ -2,7 +2,6 @@ class PicklePlaceholderAccessError(Exception): """Custom exception raised when attempting to access an unpicklable object.""" - class PicklePlaceholder: """A placeholder for an object that couldn't be pickled. 
@@ -62,10 +61,5 @@ def __reduce__(self):
         """Make sure pickling of the placeholder itself works correctly."""
         return (
             PicklePlaceholder,
-            (
-                self.__dict__["obj_type"],
-                self.__dict__["obj_str"],
-                self.__dict__["error_msg"],
-                self.__dict__["path"]
-            )
+            (self.__dict__["obj_type"], self.__dict__["obj_str"], self.__dict__["error_msg"], self.__dict__["path"]),
         )
diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py
index 502c811eb..8524d397e 100644
--- a/codeflash/result/create_pr.py
+++ b/codeflash/result/create_pr.py
@@ -78,7 +78,7 @@ def check_create_pr(
             speedup_pct=explanation.speedup_pct,
             winning_behavioral_test_results=explanation.winning_behavioral_test_results,
             winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
-            benchmark_details=explanation.benchmark_details
+            benchmark_details=explanation.benchmark_details,
         ),
         existing_tests=existing_tests_source,
         generated_tests=generated_original_test_source,
@@ -125,7 +125,7 @@ def check_create_pr(
             speedup_pct=explanation.speedup_pct,
             winning_behavioral_test_results=explanation.winning_behavioral_test_results,
             winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
-            benchmark_details=explanation.benchmark_details
+            benchmark_details=explanation.benchmark_details,
         ),
         existing_tests=existing_tests_source,
         generated_tests=generated_original_test_source,
diff --git a/codeflash/result/explanation.py b/codeflash/result/explanation.py
index c6e1fb9dc..bfb061cec 100644
--- a/codeflash/result/explanation.py
+++ b/codeflash/result/explanation.py
@@ -77,23 +77,23 @@ def to_console_string(self) -> str:
                 test_function,
                 f"{detail.original_timing}",
                 f"{detail.expected_new_timing}",
-                f"{detail.speedup_percent:.2f}%"
+                f"{detail.speedup_percent:.2f}%",
             )
 
         # Convert table to string
         string_buffer = StringIO()
         console = Console(file=string_buffer, width=terminal_width)
         console.print(table)
-        benchmark_info = cast(StringIO, console.file).getvalue() + "\n"  # Cast for mypy
+        benchmark_info = cast("StringIO", console.file).getvalue() + "\n"  # Cast for mypy
 
         return (
-            f"Optimized {self.function_name} in {self.file_path}\n"
-            f"{self.perf_improvement_line}\n"
-            f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
-            + (benchmark_info if benchmark_info else "")
-            + self.raw_explanation_message
-            + " \n\n"
-            + "The new optimized code was tested for correctness. The results are listed below.\n"
-            + f"{TestResults.report_to_string(self.winning_behavioral_test_results.get_test_pass_fail_report_by_type())}\n"
+            f"Optimized {self.function_name} in {self.file_path}\n"
+            f"{self.perf_improvement_line}\n"
+            f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
+            + (benchmark_info if benchmark_info else "")
+            + self.raw_explanation_message
+            + " \n\n"
+            + "The new optimized code was tested for correctness. The results are listed below.\n"
+            + f"{TestResults.report_to_string(self.winning_behavioral_test_results.get_test_pass_fail_report_by_type())}\n"
         )
 
     def explanation_message(self) -> str:
diff --git a/codeflash/verification/codeflash_capture.py b/codeflash/verification/codeflash_capture.py
index 4fd2cf079..2695c7539 100644
--- a/codeflash/verification/codeflash_capture.py
+++ b/codeflash/verification/codeflash_capture.py
@@ -7,10 +7,12 @@
 import os
 import sqlite3
 import time
-from pathlib import Path
 from enum import Enum
+from pathlib import Path
+
 import dill as pickle
 
+
 class VerificationType(str, Enum):
     FUNCTION_CALL = (
         "function_call"  # Correctness verification for a test function, checks input values and output values)
diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py
index 79ae7776c..53ef74054 100644
--- a/codeflash/verification/comparator.py
+++ b/codeflash/verification/comparator.py
@@ -84,7 +84,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
             frozenset,
             enum.Enum,
             type,
-            range
+            range,
         ),
     ):
         return orig == new
diff --git a/codeflash/verification/concolic_testing.py b/codeflash/verification/concolic_testing.py
index c8b6053a0..994e8d949 100644
--- a/codeflash/verification/concolic_testing.py
+++ b/codeflash/verification/concolic_testing.py
@@ -1,10 +1,9 @@
 from __future__ import annotations
 
-import time
-
 import ast
 import subprocess
 import tempfile
+import time
 from argparse import Namespace
 from pathlib import Path
diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py
index b7ce6978a..9d7f5ba2c 100644
--- a/codeflash/verification/equivalence.py
+++ b/codeflash/verification/equivalence.py
@@ -70,9 +70,12 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
                 are_equal = False
                 break
 
-        if original_test_result.test_type in {TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST, TestType.GENERATED_REGRESSION, TestType.REPLAY_TEST} and (
-            cdd_test_result.did_pass != original_test_result.did_pass
-        ):
+        if original_test_result.test_type in {
+            TestType.EXISTING_UNIT_TEST,
+            TestType.CONCOLIC_COVERAGE_TEST,
+            TestType.GENERATED_REGRESSION,
+            TestType.REPLAY_TEST,
+        } and (cdd_test_result.did_pass != original_test_result.did_pass):
             are_equal = False
             break
     sys.setrecursionlimit(original_recursion_limit)
diff --git a/codeflash/verification/parse_line_profile_test_output.py b/codeflash/verification/parse_line_profile_test_output.py
index 5e753b932..454b239b6 100644
--- a/codeflash/verification/parse_line_profile_test_output.py
+++ b/codeflash/verification/parse_line_profile_test_output.py
@@ -1,88 +1,88 @@
 """Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)"""
-import linecache
+
 import inspect
-from codeflash.code_utils.tabulate import tabulate
+import linecache
 import os
-import dill as pickle
 from pathlib import Path
 from typing import Optional
 
+import dill as pickle
+
+from codeflash.code_utils.tabulate import tabulate
+
+
 def show_func(filename, start_lineno, func_name, timings, unit):
     total_hits = sum(t[1] for t in timings)
     total_time = sum(t[2] for t in timings)
     out_table = ""
     table_rows = []
     if total_hits == 0:
-        return ''
+        return ""
     scalar = 1
     if os.path.exists(filename):
-        out_table += f'## Function: {func_name}\n'
+        out_table += f"## Function: {func_name}\n"
         # Clear the cache to ensure that we get up-to-date results.
         linecache.clearcache()
         all_lines = linecache.getlines(filename)
-        sublines = inspect.getblock(all_lines[start_lineno - 1:])
-        out_table += '## Total time: %g s\n' % (total_time * unit)
+        sublines = inspect.getblock(all_lines[start_lineno - 1 :])
+        out_table += "## Total time: %g s\n" % (total_time * unit)
 
         # Define minimum column sizes so text fits and usually looks consistent
-        default_column_sizes = {
-            'hits': 9,
-            'time': 12,
-            'perhit': 8,
-            'percent': 8,
-        }
+        default_column_sizes = {"hits": 9, "time": 12, "perhit": 8, "percent": 8}
 
         display = {}
         # Loop over each line to determine better column formatting.
         # Fallback to scientific notation if columns are larger than a threshold.
         for lineno, nhits, time in timings:
             if total_time == 0:  # Happens rarely on empty function
-                percent = ''
+                percent = ""
             else:
-                percent = '%5.1f' % (100 * time / total_time)
+                percent = "%5.1f" % (100 * time / total_time)
 
-            time_disp = '%5.1f' % (time * scalar)
-            if len(time_disp) > default_column_sizes['time']:
-                time_disp = '%5.1g' % (time * scalar)
-            perhit_disp = '%5.1f' % (float(time) * scalar / nhits)
-            if len(perhit_disp) > default_column_sizes['perhit']:
-                perhit_disp = '%5.1g' % (float(time) * scalar / nhits)
+            time_disp = "%5.1f" % (time * scalar)
+            if len(time_disp) > default_column_sizes["time"]:
+                time_disp = "%5.1g" % (time * scalar)
+            perhit_disp = "%5.1f" % (float(time) * scalar / nhits)
+            if len(perhit_disp) > default_column_sizes["perhit"]:
+                perhit_disp = "%5.1g" % (float(time) * scalar / nhits)
             nhits_disp = "%d" % nhits
-            if len(nhits_disp) > default_column_sizes['hits']:
-                nhits_disp = '%g' % nhits
+            if len(nhits_disp) > default_column_sizes["hits"]:
+                nhits_disp = "%g" % nhits
             display[lineno] = (nhits_disp, time_disp, perhit_disp, percent)
         linenos = range(start_lineno, start_lineno + len(sublines))
-        empty = ('', '', '', '')
-        table_cols = ('Hits', 'Time', 'Per Hit', '% Time', 'Line Contents')
+        empty = ("", "", "", "")
+        table_cols = ("Hits", "Time", "Per Hit", "% Time", "Line Contents")
         for lineno, line in zip(linenos, sublines):
             nhits, time, per_hit, percent = display.get(lineno, empty)
-            line_ = line.rstrip('\n').rstrip('\r')
-            if 'def' in line_ or nhits!='':
+            line_ = line.rstrip("\n").rstrip("\r")
+            if "def" in line_ or nhits != "":
                 table_rows.append((nhits, time, per_hit, percent, line_))
-            pass
-        out_table += tabulate(headers=table_cols,tabular_data=table_rows,tablefmt="pipe",colglobalalign=None, preserve_whitespace=True)
-        out_table+='\n'
+        out_table += tabulate(
+            headers=table_cols, tabular_data=table_rows, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
+        )
+        out_table += "\n"
     return out_table
 
+
 def show_text(stats: dict) -> str:
-    """ Show text for the given timings.
-    """
+    """Show text for the given timings."""
     out_table = ""
-    out_table += '# Timer unit: %g s\n' % stats['unit']
-    stats_order = sorted(stats['timings'].items())
+    out_table += "# Timer unit: %g s\n" % stats["unit"]
+    stats_order = sorted(stats["timings"].items())
     # Show detailed per-line information for each function.
     for (fn, lineno, name), timings in stats_order:
-        table_md = show_func(fn, lineno, name, stats['timings'][fn, lineno, name], stats['unit'])
+        table_md = show_func(fn, lineno, name, stats["timings"][fn, lineno, name], stats["unit"])
         out_table += table_md
     return out_table
 
+
 def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dict:
     line_profiler_output_file = line_profiler_output_file.with_suffix(".lprof")
     stats_dict = {}
     if not line_profiler_output_file.exists():
-        return {'timings':{},'unit':0, 'str_out':''}, None
-    else:
-        with open(line_profiler_output_file,'rb') as f:
-            stats = pickle.load(f)
-        stats_dict['timings'] = stats.timings
-        stats_dict['unit'] = stats.unit
-        str_out = show_text(stats_dict)
-        stats_dict['str_out'] = str_out
-        return stats_dict, None
+        return {"timings": {}, "unit": 0, "str_out": ""}, None
+    with open(line_profiler_output_file, "rb") as f:
+        stats = pickle.load(f)
+    stats_dict["timings"] = stats.timings
+    stats_dict["unit"] = stats.unit
+    str_out = show_text(stats_dict)
+    stats_dict["str_out"] = str_out
+    return stats_dict, None
diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py
index 9f78083a9..8d187f2b1 100644
--- a/codeflash/verification/verifier.py
+++ b/codeflash/verification/verifier.py
@@ -1,8 +1,7 @@
 from __future__ import annotations
 
-import time
-
 import ast
+import time
 from pathlib import Path
 from typing import TYPE_CHECKING