diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml index a1e7da8ea..35e0a7d9b 100644 --- a/.github/workflows/unit-tests.yaml +++ b/.github/workflows/unit-tests.yaml @@ -32,7 +32,7 @@ jobs: run: uvx poetry install --with dev - name: Unit tests - run: uvx poetry run pytest tests/ --cov --cov-report=xml + run: uvx poetry run pytest tests/ --cov --cov-report=xml -vv - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 diff --git a/codeflash/LICENSE b/codeflash/LICENSE index 8b94a373d..d32df80d3 100644 --- a/codeflash/LICENSE +++ b/codeflash/LICENSE @@ -3,7 +3,7 @@ Business Source License 1.1 Parameters Licensor: CodeFlash Inc. -Licensed Work: Codeflash Client version 0.9.x +Licensed Work: Codeflash Client version 0.10.x The Licensed Work is (c) 2024 CodeFlash Inc. Additional Use Grant: None. Production use of the Licensed Work is only permitted @@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitte Platform. Please visit codeflash.ai for further information. -Change Date: 2029-01-06 +Change Date: 2029-02-25 Change License: MIT diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 2a9b977e2..90f58f515 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -220,19 +220,26 @@ def collect_setup_info() -> SetupInfo: carousel=True, ) + git_remote = "" try: repo = Repo(str(module_root), search_parent_directories=True) git_remotes = get_git_remotes(repo) - if len(git_remotes) > 1: - git_remote = inquirer_wrapper( - inquirer.list_input, - message="What git remote do you want Codeflash to use for new Pull Requests? ", - choices=git_remotes, - default="origin", - carousel=True, - ) + if git_remotes: # Only proceed if there are remotes + if len(git_remotes) > 1: + git_remote = inquirer_wrapper( + inquirer.list_input, + message="What git remote do you want Codeflash to use for new Pull Requests? ", + choices=git_remotes, + default="origin", + carousel=True, + ) + else: + git_remote = git_remotes[0] else: - git_remote = git_remotes[0] + click.echo( + "No git remotes found. You can still use Codeflash locally, but you'll need to set up a remote " + "repository to use GitHub features." 
+ ) except InvalidGitRepositoryError: git_remote = "" @@ -587,6 +594,11 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: ) elif formatter == "don't use a formatter": formatter_cmds.append("disabled") + if formatter in ["black", "ruff"]: + try: + result = subprocess.run([formatter], capture_output=True, check=False) + except FileNotFoundError as e: + click.echo(f"⚠️ Formatter not found: {formatter}, please ensure it is installed") codeflash_section["formatter-cmds"] = formatter_cmds # Add the 'codeflash' section, ensuring 'tool' section exists tool_section = pyproject_data.get("tool", tomlkit.table()) diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index 9661e9509..3ef0f2eca 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -3,7 +3,7 @@ import logging from contextlib import contextmanager from itertools import cycle -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING from rich.console import Console from rich.logging import RichHandler @@ -13,6 +13,8 @@ from codeflash.cli_cmds.logging_config import BARE_LOGGING_FORMAT if TYPE_CHECKING: + from collections.abc import Generator + from rich.progress import TaskID DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG diff --git a/codeflash/github/PrComment.py b/codeflash/github/PrComment.py index 76a19f5e7..a6f6aa892 100644 --- a/codeflash/github/PrComment.py +++ b/codeflash/github/PrComment.py @@ -20,6 +20,13 @@ class PrComment: winning_benchmarking_test_results: TestResults def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]: + + report_table = { + test_type.to_name(): result + for test_type, result in self.winning_behavioral_test_results.get_test_pass_fail_report_by_type().items() + if test_type.to_name() + } + return { "optimization_explanation": self.optimization_explanation, "best_runtime": humanize_runtime(self.best_runtime), @@ -29,10 +36,7 @@ def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]: "speedup_x": self.speedup_x, "speedup_pct": self.speedup_pct, "loop_count": self.winning_benchmarking_test_results.number_of_loops(), - "report_table": { - test_type.to_name(): result - for test_type, result in self.winning_behavioral_test_results.get_test_pass_fail_report_by_type().items() - }, + "report_table": report_table } diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 7b067a094..4356a4e63 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -13,6 +13,7 @@ import isort import libcst as cst +from auditwall.core import SideEffectDetected from rich.console import Group from rich.panel import Panel from rich.syntax import Syntax @@ -21,7 +22,10 @@ from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash.cli_cmds.console import code_print, console, logger, progress_bar from codeflash.code_utils import env_utils -from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code +from codeflash.code_utils.code_extractor import ( + add_needed_imports_from_module, + extract_code, +) from codeflash.code_utils.code_replacer import replace_function_definitions_in_module from codeflash.code_utils.code_utils import ( cleanup_paths, @@ -36,9 +40,15 @@ TOTAL_LOOPING_TIME, ) from codeflash.code_utils.formatter import format_code, sort_imports -from codeflash.code_utils.instrument_existing_tests import 
inject_profiling_into_existing_test -from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests -from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast +from codeflash.code_utils.instrument_existing_tests import ( + inject_profiling_into_existing_test, +) +from codeflash.code_utils.remove_generated_tests import ( + remove_functions_from_generated_tests, +) +from codeflash.code_utils.static_analysis import ( + get_first_top_level_function_or_method_ast, +) from codeflash.code_utils.time_utils import humanize_runtime from codeflash.context import code_context_extractor from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -58,17 +68,29 @@ TestFiles, TestingMode, ) -from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions +from codeflash.optimization.function_context import ( + get_constrained_function_context_and_helper_functions, +) from codeflash.result.create_pr import check_create_pr, existing_tests_source_for -from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic +from codeflash.result.critic import ( + coverage_critic, + performance_gain, + quantity_of_tests_critic, + speedup_critic, +) from codeflash.result.explanation import Explanation from codeflash.telemetry.posthog_cf import ph from codeflash.verification.concolic_testing import generate_concolic_tests from codeflash.verification.equivalence import compare_test_results -from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture +from codeflash.verification.instrument_codeflash_capture import ( + instrument_codeflash_capture, +) from codeflash.verification.parse_test_output import parse_test_results from codeflash.verification.test_results import TestResults, TestType -from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests +from codeflash.verification.test_runner import ( + run_behavioral_tests, + run_benchmarking_tests, +) from codeflash.verification.verification_utils import get_test_file_path from codeflash.verification.verifier import generate_tests @@ -96,7 +118,9 @@ def __init__( ) -> None: self.project_root = test_cfg.project_root_path self.test_cfg = test_cfg - self.aiservice_client = aiservice_client if aiservice_client else AiServiceClient() + self.aiservice_client = ( + aiservice_client if aiservice_client else AiServiceClient() + ) self.function_to_optimize = function_to_optimize self.function_to_optimize_source_code = ( function_to_optimize_source_code @@ -106,19 +130,25 @@ def __init__( if not function_to_optimize_ast: original_module_ast = ast.parse(function_to_optimize_source_code) self.function_to_optimize_ast = get_first_top_level_function_or_method_ast( - function_to_optimize.function_name, function_to_optimize.parents, original_module_ast + function_to_optimize.function_name, + function_to_optimize.parents, + original_module_ast, ) else: self.function_to_optimize_ast = function_to_optimize_ast self.function_to_tests = function_to_tests if function_to_tests else {} self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None) - self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None + self.local_aiservice_client = ( + LocalAiServiceClient() if self.experiment_id else None + ) self.test_files = TestFiles(test_files=[]) self.args = args # Check defaults for these self.function_trace_id: str = str(uuid.uuid4()) - 
self.original_module_path = module_name_from_file_path(self.function_to_optimize.file_path, self.project_root) + self.original_module_path = module_name_from_file_path( + self.function_to_optimize.file_path, self.project_root + ) def optimize_function(self) -> Result[BestOptimization, str]: should_run_experiment = self.experiment_id is not None @@ -151,13 +181,20 @@ def optimize_function(self) -> Result[BestOptimization, str]: generated_test_paths = [ get_test_file_path( - self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit" + self.test_cfg.tests_root, + self.function_to_optimize.function_name, + test_index, + test_type="unit", ) for test_index in range(N_TESTS_TO_GENERATE) ] + generated_perf_test_paths = [ get_test_file_path( - self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="perf" + self.test_cfg.tests_root, + self.function_to_optimize.function_name, + test_index, + test_type="perf", ) for test_index in range(N_TESTS_TO_GENERATE) ] @@ -180,7 +217,12 @@ def optimize_function(self) -> Result[BestOptimization, str]: return Failure(generated_results.failure()) generated_tests: GeneratedTestsList optimizations_set: OptimizationSet - generated_tests, function_to_concolic_tests, concolic_test_str, optimizations_set = generated_results.unwrap() + ( + generated_tests, + function_to_concolic_tests, + concolic_test_str, + optimizations_set, + ) = generated_results.unwrap() count_tests = len(generated_tests.generated_tests) if concolic_test_str: count_tests += 1 @@ -208,29 +250,39 @@ def optimize_function(self) -> Result[BestOptimization, str]: function_to_optimize_qualified_name = self.function_to_optimize.qualified_name function_to_all_tests = { - key: self.function_to_tests.get(key, []) + function_to_concolic_tests.get(key, []) + key: self.function_to_tests.get(key, []) + + function_to_concolic_tests.get(key, []) for key in set(self.function_to_tests) | set(function_to_concolic_tests) } - instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests) + instrumented_unittests_created_for_function = self.instrument_existing_tests( + function_to_all_tests + ) # Get a dict of file_path_to_classes of fto and helpers_of_fto file_path_to_helper_classes = defaultdict(set) for function_source in code_context.helper_functions: if ( - function_source.qualified_name != self.function_to_optimize.qualified_name + function_source.qualified_name + != self.function_to_optimize.qualified_name and "." 
in function_source.qualified_name ): - file_path_to_helper_classes[function_source.file_path].add(function_source.qualified_name.split(".")[0]) + file_path_to_helper_classes[function_source.file_path].add( + function_source.qualified_name.split(".")[0] + ) - baseline_result = self.establish_original_code_baseline( # this needs better typing - code_context=code_context, - original_helper_code=original_helper_code, - file_path_to_helper_classes=file_path_to_helper_classes, + baseline_result = ( + self.establish_original_code_baseline( # this needs better typing + code_context=code_context, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + ) ) console.rule() paths_to_cleanup = ( - generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function) + generated_test_paths + + generated_perf_test_paths + + list(instrumented_unittests_created_for_function) ) if not is_successful(baseline_result): @@ -238,7 +290,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: return Failure(baseline_result.failure()) original_code_baseline, test_functions_to_remove = baseline_result.unwrap() - if isinstance(original_code_baseline, OriginalCodeBaseline) and not coverage_critic( + if isinstance( + original_code_baseline, OriginalCodeBaseline + ) and not coverage_critic( original_code_baseline.coverage_results, self.args.test_framework ): cleanup_paths(paths_to_cleanup) @@ -246,7 +300,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: best_optimization = None - for u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]): + for u, candidates in enumerate( + [optimizations_set.control, optimizations_set.experiment] + ): if candidates is None: continue @@ -257,10 +313,14 @@ def optimize_function(self) -> Result[BestOptimization, str]: original_helper_code=original_helper_code, file_path_to_helper_classes=file_path_to_helper_classes, ) - ph("cli-optimize-function-finished", {"function_trace_id": self.function_trace_id}) + ph( + "cli-optimize-function-finished", + {"function_trace_id": self.function_trace_id}, + ) generated_tests = remove_functions_from_generated_tests( - generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove + generated_tests=generated_tests, + test_functions_to_remove=test_functions_to_remove, ) if best_optimization: @@ -268,7 +328,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: code_print(best_optimization.candidate.source_code) console.print( Panel( - best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue" + best_optimization.candidate.explanation, + title="Best Candidate Explanation", + border_style="blue", ) ) explanation = Explanation( @@ -284,21 +346,28 @@ def optimize_function(self) -> Result[BestOptimization, str]: self.log_successful_optimization(explanation, generated_tests) self.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=best_optimization.candidate.source_code + code_context=code_context, + optimized_code=best_optimization.candidate.source_code, ) new_code, new_helper_code = self.reformat_code_and_helpers( - code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code + code_context.helper_functions, + explanation.file_path, + self.function_to_optimize_source_code, ) existing_tests = existing_tests_source_for( - 
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), + self.function_to_optimize.qualified_name_with_modules_from_root( + self.project_root + ), function_to_all_tests, tests_root=self.test_cfg.tests_root, ) original_code_combined = original_helper_code.copy() - original_code_combined[explanation.file_path] = self.function_to_optimize_source_code + original_code_combined[explanation.file_path] = ( + self.function_to_optimize_source_code + ) new_code_combined = new_helper_code.copy() new_code_combined[explanation.file_path] = new_code if not self.args.no_pr: @@ -308,7 +377,10 @@ def optimize_function(self) -> Result[BestOptimization, str]: else "Coverage data not available" ) generated_tests_str = "\n\n".join( - [test.generated_original_test_source for test in generated_tests.generated_tests] + [ + test.generated_original_test_source + for test in generated_tests.generated_tests + ] ) if concolic_test_str: generated_tests_str += "\n\n" + concolic_test_str @@ -345,7 +417,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: break # need to delete only one test directory if not best_optimization: - return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") + return Failure( + f"No best optimizations found for function {self.function_to_optimize.qualified_name}" + ) return Success(best_optimization) def determine_best_candidate( @@ -371,9 +445,15 @@ def determine_best_candidate( console.rule() try: for candidate_index, candidate in enumerate(candidates, start=1): - get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) - get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) - logger.info(f"Optimization candidate {candidate_index}/{len(candidates)}:") + get_run_tmp_file( + Path(f"test_return_values_{candidate_index}.bin") + ).unlink(missing_ok=True) + get_run_tmp_file( + Path(f"test_return_values_{candidate_index}.sqlite") + ).unlink(missing_ok=True) + logger.info( + f"Optimization candidate {candidate_index}/{len(candidates)}:" + ) code_print(candidate.source_code) try: did_update = self.replace_function_and_helpers_with_optimized_code( @@ -385,10 +465,17 @@ def determine_best_candidate( ) console.rule() continue - except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: + except ( + ValueError, + SyntaxError, + cst.ParserSyntaxError, + AttributeError, + ) as e: logger.error(e) self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + self.function_to_optimize_source_code, + original_helper_code, + self.function_to_optimize.file_path, ) continue @@ -411,16 +498,23 @@ def determine_best_candidate( optimized_runtimes[candidate.optimization_id] = best_test_runtime is_correct[candidate.optimization_id] = True perf_gain = performance_gain( - original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime + original_runtime_ns=original_code_baseline.runtime, + optimized_runtime_ns=best_test_runtime, ) speedup_ratios[candidate.optimization_id] = perf_gain tree = Tree(f"Candidate #{candidate_index} - Runtime Information") if speedup_critic( - candidate_result, original_code_baseline.runtime, best_runtime_until_now + candidate_result, + original_code_baseline.runtime, + best_runtime_until_now, ) and quantity_of_tests_critic(candidate_result): - tree.add("This candidate is faster than the previous best candidate. 
🚀") - tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}") + tree.add( + "This candidate is faster than the previous best candidate. 🚀" + ) + tree.add( + f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}" + ) tree.add( f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} " f"(measured over {candidate_result.max_loop_count} " @@ -449,11 +543,15 @@ def determine_best_candidate( console.rule() self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + self.function_to_optimize_source_code, + original_helper_code, + self.function_to_optimize.file_path, ) except KeyboardInterrupt as e: self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + self.function_to_optimize_source_code, + original_helper_code, + self.function_to_optimize.file_path, ) logger.exception(f"Optimization interrupted: {e}") raise @@ -467,7 +565,9 @@ def determine_best_candidate( ) return best_optimization - def log_successful_optimization(self, explanation: Explanation, generated_tests: GeneratedTestsList) -> None: + def log_successful_optimization( + self, explanation: Explanation, generated_tests: GeneratedTestsList + ) -> None: explanation_panel = Panel( f"⚡️ Optimization successful! 📄 {self.function_to_optimize.qualified_name} in {explanation.file_path}\n" f"📈 {explanation.perf_improvement_line}\n" @@ -479,7 +579,12 @@ def log_successful_optimization(self, explanation: Explanation, generated_tests: if self.args.no_pr: tests_panel = Panel( Syntax( - "\n".join([test.generated_original_test_source for test in generated_tests.generated_tests]), + "\n".join( + [ + test.generated_original_test_source + for test in generated_tests.generated_tests + ] + ), "python", line_numbers=True, ), @@ -506,7 +611,9 @@ def log_successful_optimization(self, explanation: Explanation, generated_tests: ) @staticmethod - def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, str], path: Path) -> None: + def write_code_and_helpers( + original_code: str, original_helper_code: dict[Path, str], path: Path + ) -> None: with path.open("w", encoding="utf8") as f: f.write(original_code) for module_abspath in original_helper_code: @@ -527,7 +634,9 @@ def reformat_code_and_helpers( new_helper_code: dict[Path, str] = {} helper_functions_paths = {hf.file_path for hf in helper_functions} for module_abspath in helper_functions_paths: - formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath) + formatted_helper_code = format_code( + self.args.formatter_cmds, module_abspath + ) if should_sort_imports: formatted_helper_code = sort_imports(formatted_helper_code) new_helper_code[module_abspath] = formatted_helper_code @@ -544,8 +653,13 @@ def replace_function_and_helpers_with_optimized_code( ) for helper_function in code_context.helper_functions: if helper_function.jedi_definition.type != "class": - read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name) - for module_abspath, qualified_names in read_writable_functions_by_file_path.items(): + read_writable_functions_by_file_path[helper_function.file_path].add( + helper_function.qualified_name + ) + for ( + module_abspath, + qualified_names, + ) in read_writable_functions_by_file_path.items(): did_update |= replace_function_definitions_in_module( function_names=list(qualified_names), 
optimized_code=optimized_code, @@ -556,18 +670,23 @@ def replace_function_and_helpers_with_optimized_code( return did_update def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: - code_to_optimize, contextual_dunder_methods = extract_code([self.function_to_optimize]) + code_to_optimize, contextual_dunder_methods = extract_code( + [self.function_to_optimize] + ) if code_to_optimize is None: return Failure("Could not find function to optimize.") - (helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions( - self.function_to_optimize, self.project_root, code_to_optimize + (helper_code, helper_functions, helper_dunder_methods) = ( + get_constrained_function_context_and_helper_functions( + self.function_to_optimize, self.project_root, code_to_optimize + ) ) if self.function_to_optimize.parents: function_class = self.function_to_optimize.parents[0].name same_class_helper_methods = [ df for df in helper_functions - if df.qualified_name.count(".") > 0 and df.qualified_name.split(".")[0] == function_class + if df.qualified_name.count(".") > 0 + and df.qualified_name.split(".")[0] == function_class ] optimizable_methods = [ FunctionToOptimize( @@ -586,7 +705,9 @@ def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: dedup_optimizable_methods.append(method) added_methods.add(f"{method.file_path}.{method.qualified_name}") if len(dedup_optimizable_methods) > 1: - code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods))) + code_to_optimize, contextual_dunder_methods = extract_code( + list(reversed(dedup_optimizable_methods)) + ) if code_to_optimize is None: return Failure("Could not find function to optimize.") code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize @@ -623,23 +744,35 @@ def cleanup_leftover_test_return_values() -> None: get_run_tmp_file(Path("test_return_values_0.bin")).unlink(missing_ok=True) get_run_tmp_file(Path("test_return_values_0.sqlite")).unlink(missing_ok=True) - def instrument_existing_tests(self, function_to_all_tests: dict[str, list[FunctionCalledInTest]]) -> set[Path]: + def instrument_existing_tests( + self, function_to_all_tests: dict[str, list[FunctionCalledInTest]] + ) -> set[Path]: existing_test_files_count = 0 replay_test_files_count = 0 concolic_coverage_test_files_count = 0 unique_instrumented_test_files = set() - func_qualname = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root) + func_qualname = self.function_to_optimize.qualified_name_with_modules_from_root( + self.project_root + ) if func_qualname not in function_to_all_tests: - logger.info(f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.") + logger.info( + f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests." 
+ ) console.rule() else: test_file_invocation_positions = defaultdict(list[FunctionCalledInTest]) for tests_in_file in function_to_all_tests.get(func_qualname): test_file_invocation_positions[ - (tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type) + ( + tests_in_file.tests_in_file.test_file, + tests_in_file.tests_in_file.test_type, + ) ].append(tests_in_file) - for (test_file, test_type), tests_in_file_list in test_file_invocation_positions.items(): + for ( + test_file, + test_type, + ), tests_in_file_list in test_file_invocation_positions.items(): path_obj_test_file = Path(test_file) if test_type == TestType.EXISTING_UNIT_TEST: existing_test_files_count += 1 @@ -718,9 +851,16 @@ def generate_tests_and_optimizations( generated_test_paths: list[Path], generated_perf_test_paths: list[Path], run_experiment: bool = False, - ) -> Result[tuple[GeneratedTestsList, dict[str, list[FunctionCalledInTest]], OptimizationSet], str]: + ) -> Result[ + tuple[ + GeneratedTestsList, dict[str, list[FunctionCalledInTest]], OptimizationSet + ], + str, + ]: assert len(generated_test_paths) == N_TESTS_TO_GENERATE - max_workers = N_TESTS_TO_GENERATE + 2 if not run_experiment else N_TESTS_TO_GENERATE + 3 + max_workers = ( + N_TESTS_TO_GENERATE + 2 if not run_experiment else N_TESTS_TO_GENERATE + 3 + ) console.rule() with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit the test generation task as future @@ -735,9 +875,17 @@ def generate_tests_and_optimizations( self.aiservice_client.optimize_python_code, read_writable_code, read_only_context_code, - self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id, + ( + self.function_trace_id[:-4] + "EXP0" + if run_experiment + else self.function_trace_id + ), N_CANDIDATES, - ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None, + ( + ExperimentMetadata(id=self.experiment_id, group="control") + if run_experiment + else None + ), ) future_candidates_exp = None @@ -748,7 +896,11 @@ def generate_tests_and_optimizations( self.function_to_optimize, self.function_to_optimize_ast, ) - futures = [*future_tests, future_optimization_candidates, future_concolic_tests] + futures = [ + *future_tests, + future_optimization_candidates, + future_concolic_tests, + ] if run_experiment: future_candidates_exp = executor.submit( self.local_aiservice_client.optimize_python_code, @@ -764,11 +916,17 @@ def generate_tests_and_optimizations( concurrent.futures.wait(futures) # Retrieve results - candidates: list[OptimizedCandidate] = future_optimization_candidates.result() + candidates: list[OptimizedCandidate] = ( + future_optimization_candidates.result() + ) if not candidates: - return Failure(f"/!\\ NO OPTIMIZATIONS GENERATED for {self.function_to_optimize.function_name}") + return Failure( + f"/!\\ NO OPTIMIZATIONS GENERATED for {self.function_to_optimize.function_name}" + ) - candidates_experiment = future_candidates_exp.result() if future_candidates_exp else None + candidates_experiment = ( + future_candidates_exp.result() if future_candidates_exp else None + ) # Process test generation results @@ -793,10 +951,18 @@ def generate_tests_and_optimizations( ) ) if not tests: - logger.warning(f"Failed to generate and instrument tests for {self.function_to_optimize.function_name}") - return Failure(f"/!\\ NO TESTS GENERATED for {self.function_to_optimize.function_name}") - function_to_concolic_tests, concolic_test_str = future_concolic_tests.result() - logger.info(f"Generated 
{len(tests)} tests for {self.function_to_optimize.function_name}") + logger.warning( + f"Failed to generate and instrument tests for {self.function_to_optimize.function_name}" + ) + return Failure( + f"/!\\ NO TESTS GENERATED for {self.function_to_optimize.function_name}" + ) + function_to_concolic_tests, concolic_test_str = ( + future_concolic_tests.result() + ) + logger.info( + f"Generated {len(tests)} tests for {self.function_to_optimize.function_name}" + ) console.rule() generated_tests = GeneratedTestsList(generated_tests=tests) @@ -816,8 +982,13 @@ def establish_original_code_baseline( file_path_to_helper_classes: dict[Path, set[str]], ) -> Result[tuple[OriginalCodeBaseline, list[str]], str]: # For the original function - run the tests and get the runtime, plus coverage - with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"): - assert (test_framework := self.args.test_framework) in ["pytest", "unittest"] + with progress_bar( + f"Establishing original code baseline for {self.function_to_optimize.function_name}" + ): + assert (test_framework := self.args.test_framework) in [ + "pytest", + "unittest", + ] success = True test_env = os.environ.copy() @@ -833,7 +1004,9 @@ def establish_original_code_baseline( # Instrument codeflash capture try: instrument_codeflash_capture( - self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root + self.function_to_optimize, + file_path_to_helper_classes, + self.test_cfg.tests_root, ) behavioral_results, coverage_results = self.run_and_parse_tests( testing_type=TestingMode.BEHAVIOR, @@ -844,20 +1017,26 @@ def establish_original_code_baseline( enable_coverage=test_framework == "pytest", code_context=code_context, ) + except SideEffectDetected as e: + return Failure( + f"Side effect detected in original code: {e}, skipping optimization." + ) finally: # Remove codeflash capture self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + self.function_to_optimize_source_code, + original_helper_code, + self.function_to_optimize.file_path, ) if not behavioral_results: logger.warning( f"Couldn't run any tests for original function {self.function_to_optimize.function_name}. SKIPPING OPTIMIZING THIS FUNCTION." ) console.rule() - return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.") - if not coverage_critic( - coverage_results, self.args.test_framework - ): + return Failure( + "Failed to establish a baseline for the original code - bevhavioral tests failed." 
+ ) + if not coverage_critic(coverage_results, self.args.test_framework): return Failure("The threshold for test coverage was not met.") if test_framework == "pytest": benchmarking_results, _ = self.run_and_parse_tests( @@ -898,12 +1077,16 @@ def establish_original_code_baseline( ) console.rule() - - total_timing = benchmarking_results.total_passed_runtime() # caution: doesn't handle the loop index + total_timing = ( + benchmarking_results.total_passed_runtime() + ) # caution: doesn't handle the loop index functions_to_remove = [ result.id.test_function_name for result in behavioral_results - if (result.test_type == TestType.GENERATED_REGRESSION and not result.did_pass) + if ( + result.test_type == TestType.GENERATED_REGRESSION + and not result.did_pass + ) ] if total_timing == 0: logger.warning( @@ -912,13 +1095,17 @@ def establish_original_code_baseline( console.rule() success = False if not total_timing: - logger.warning("Failed to run the tests for the original function, skipping optimization") + logger.warning( + "Failed to run the tests for the original function, skipping optimization" + ) console.rule() success = False if not success: return Failure("Failed to establish a baseline for the original code.") - loop_count = max([int(result.loop_index) for result in benchmarking_results.test_results]) + loop_count = max( + [int(result.loop_index) for result in benchmarking_results.test_results] + ) logger.info( f"Original code summed runtime measured over {loop_count} loop{'s' if loop_count > 1 else ''}: " f"{humanize_runtime(total_timing)} per full loop" @@ -957,17 +1144,27 @@ def run_optimized_candidate( else: test_env["PYTHONPATH"] += os.pathsep + str(self.project_root) - get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True) - get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True) + get_run_tmp_file( + Path(f"test_return_values_{optimization_candidate_index}.sqlite") + ).unlink(missing_ok=True) + get_run_tmp_file( + Path(f"test_return_values_{optimization_candidate_index}.sqlite") + ).unlink(missing_ok=True) # Instrument codeflash capture - candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8") + candidate_fto_code = Path(self.function_to_optimize.file_path).read_text( + "utf-8" + ) candidate_helper_code = {} for module_abspath in original_helper_code: - candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8") + candidate_helper_code[module_abspath] = Path(module_abspath).read_text( + "utf-8" + ) try: instrument_codeflash_capture( - self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root + self.function_to_optimize, + file_path_to_helper_classes, + self.test_cfg.tests_root, ) candidate_behavior_results, _ = self.run_and_parse_tests( @@ -981,7 +1178,9 @@ def run_optimized_candidate( # Remove instrumentation finally: self.write_code_and_helpers( - candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path + candidate_fto_code, + candidate_helper_code, + self.function_to_optimize.file_path, ) console.print( TestResults.report_to_tree( @@ -991,13 +1190,19 @@ def run_optimized_candidate( ) console.rule() - if compare_test_results(baseline_results.behavioral_test_results, candidate_behavior_results): + if compare_test_results( + baseline_results.behavioral_test_results, candidate_behavior_results + ): logger.info("Test results matched!") console.rule() else: - logger.info("Test results did 
not match the test results of the original code.") + logger.info( + "Test results did not match the test results of the original code." + ) console.rule() - return Failure("Test results did not match the test results of the original code.") + return Failure( + "Test results did not match the test results of the original code." + ) if test_framework == "pytest": candidate_benchmarking_results, _ = self.run_and_parse_tests( @@ -1012,7 +1217,8 @@ def run_optimized_candidate( max(all_loop_indices) if ( all_loop_indices := { - result.loop_index for result in candidate_benchmarking_results.test_results + result.loop_index + for result in candidate_benchmarking_results.test_results } ) else 0 @@ -1038,11 +1244,17 @@ def run_optimized_candidate( loop_count = i + 1 candidate_benchmarking_results.merge(unittest_loop_results) - if (total_candidate_timing := candidate_benchmarking_results.total_passed_runtime()) == 0: - logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.") + if ( + total_candidate_timing := candidate_benchmarking_results.total_passed_runtime() + ) == 0: + logger.warning( + "The overall test runtime of the optimized function is 0, couldn't run tests." + ) console.rule() - logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}") + logger.debug( + f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}" + ) return Success( OptimizedCandidateResult( max_loop_count=loop_count, @@ -1071,15 +1283,17 @@ def run_and_parse_tests( coverage_database_file = None try: if testing_type == TestingMode.BEHAVIOR: - result_file_path, run_result, coverage_database_file = run_behavioral_tests( - test_files, - test_framework=self.test_cfg.test_framework, - cwd=self.project_root, - test_env=test_env, - pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT, - pytest_cmd=self.test_cfg.pytest_cmd, - verbose=True, - enable_coverage=enable_coverage, + result_file_path, run_result, coverage_database_file = ( + run_behavioral_tests( + test_files, + test_framework=self.test_cfg.test_framework, + cwd=self.project_root, + test_env=test_env, + pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT, + pytest_cmd=self.test_cfg.pytest_cmd, + verbose=True, + enable_coverage=enable_coverage, + ) ) elif testing_type == TestingMode.PERFORMANCE: result_file_path, run_result = run_benchmarking_tests( @@ -1097,13 +1311,13 @@ def run_and_parse_tests( raise ValueError(f"Unexpected testing type: {testing_type}") except subprocess.TimeoutExpired: logger.exception( - f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error' + f"Error running tests in {', '.join(str(f) for f in test_files.test_files)}.\nTimeout Error" ) return TestResults(), None if run_result.returncode != 0 and testing_type == TestingMode.BEHAVIOR: logger.debug( - f'Nonzero return code {run_result.returncode} when running tests in ' - f'{", ".join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n' + f"Nonzero return code {run_result.returncode} when running tests in " + f"{', '.join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n" f"stdout: {run_result.stdout}\n" f"stderr: {run_result.stderr}\n" ) @@ -1149,4 +1363,3 @@ def generate_and_instrument_tests( zip(generated_test_paths, generated_perf_test_paths) ) ] - diff --git a/codeflash/verification/_auditwall.py b/codeflash/verification/_auditwall.py new file mode 100644 index 000000000..3210ae7d0 --- /dev/null +++ 
b/codeflash/verification/_auditwall.py
@@ -0,0 +1,58 @@
+# Copyright 2024 CodeFlash Inc. All rights reserved.
+#
+# Licensed under the Business Source License version 1.1.
+# License source can be found in the LICENSE file.
+#
+# This file includes derived work covered by the following copyright and permission notices:
+#
+# Copyright Python Software Foundation
+# Licensed under the Apache License, Version 2.0 (the "License").
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# The PSF License Agreement
+# https://docs.python.org/3/license.html#python-software-foundation-license-version-2
+#
+#
+
+from auditwall.core import AuditWallConfig, _default_audit_wall, accept, reject
+
+
+class CodeflashAuditWallConfig(AuditWallConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        self.allowed_write_paths = {".coverage", "matplotlib.rc", "codeflash"}
+
+
+def handle_os_remove(event: str, args: tuple) -> None:
+    filename = str(args[0])
+    if any(
+        pattern in filename
+        for pattern in _default_audit_wall.config.allowed_write_paths
+    ):
+        accept(event, args)
+    else:
+        reject(event, args)
+
+
+def check_sqlite_connect(event: str, args: tuple) -> None:
+    if (
+        event == "sqlite3.connect"
+        and any(
+            pattern in str(args[0])
+            for pattern in _default_audit_wall.config.allowed_write_paths
+        )
+    ) or event == "sqlite3.connect/handle":
+        accept(event, args)
+    else:
+        reject(event, args)
+
+
+custom_handlers = {
+    "os.remove": handle_os_remove,
+    "sqlite3.connect": check_sqlite_connect,
+    "sqlite3.connect/handle": check_sqlite_connect,
+}
+
+
+_default_audit_wall.config = CodeflashAuditWallConfig()
+_default_audit_wall.config.special_handlers = custom_handlers
diff --git a/codeflash/verification/codeflash_auditwall.py b/codeflash/verification/codeflash_auditwall.py
new file mode 100644
index 000000000..f3e907407
--- /dev/null
+++ b/codeflash/verification/codeflash_auditwall.py
@@ -0,0 +1,27 @@
+import ast
+
+
+class AuditWallTransformer(ast.NodeTransformer):
+    def visit_Module(self, node: ast.Module) -> ast.Module:  # noqa: N802
+        last_import_index = -1
+        for i, body_node in enumerate(node.body):
+            if isinstance(body_node, (ast.Import, ast.ImportFrom)):
+                last_import_index = i
+
+        new_import = ast.ImportFrom(
+            module="auditwall.core", names=[ast.alias(name="engage_auditwall")], level=0
+        )
+        function_call = ast.Expr(
+            value=ast.Call(func=ast.Name(id="engage_auditwall", ctx=ast.Load()), args=[], keywords=[])
+        )
+
+        node.body.insert(last_import_index + 1, new_import)
+        node.body.insert(last_import_index + 2, function_call)
+
+        return node
+
+def transform_code(source_code: str) -> str:
+    tree = ast.parse(source_code)
+    transformer = AuditWallTransformer()
+    new_tree = transformer.visit(tree)
+    return ast.unparse(new_tree)
diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py
index 46203e65a..21813fd87 100644
--- a/codeflash/verification/test_runner.py
+++ b/codeflash/verification/test_runner.py
@@ -1,16 +1,21 @@
 from __future__ import annotations
 
+import re
 import shlex
 import subprocess
+import tempfile
 from pathlib import Path
 from typing import TYPE_CHECKING
 
+from auditwall.core import SideEffectDetected
+
 from codeflash.cli_cmds.console import logger
 from codeflash.code_utils.code_utils import get_run_tmp_file
 from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
 from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME
 from codeflash.code_utils.coverage_utils import prepare_coverage_files
 from codeflash.models.models import TestFiles
+from codeflash.verification.codeflash_auditwall import transform_code from codeflash.verification.test_results import TestType if TYPE_CHECKING: @@ -36,78 +41,96 @@ def run_behavioral_tests( pytest_target_runtime_seconds: int = TOTAL_LOOPING_TIME, enable_coverage: bool = False, ) -> tuple[Path, subprocess.CompletedProcess, Path | None]: - if test_framework == "pytest": - test_files: list[str] = [] - for file in test_paths.test_files: - if file.test_type == TestType.REPLAY_TEST: - # TODO: Does this work for unittest framework? - test_files.extend( - [ - str(file.instrumented_behavior_file_path) + "::" + test.test_function - for test in file.tests_in_file - ] - ) - else: - test_files.append(str(file.instrumented_behavior_file_path)) - test_files = list(set(test_files)) # remove multiple calls in the same test function - pytest_cmd_list = shlex.split(pytest_cmd, posix=IS_POSIX) - - common_pytest_args = [ - "--capture=tee-sys", - f"--timeout={pytest_timeout}", - "-q", - "--codeflash_loops_scope=session", - "--codeflash_min_loops=1", - "--codeflash_max_loops=1", - f"--codeflash_seconds={pytest_target_runtime_seconds}", # TODO :This is unnecessary, update the plugin to not ask for this - ] - - result_file_path = get_run_tmp_file(Path("pytest_results.xml")) - result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"] - - pytest_test_env = test_env.copy() - pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin" + if test_framework not in ["pytest", "unittest"]: + raise ValueError(f"Unsupported test framework: {test_framework}") - if enable_coverage: - coverage_database_file, coveragercfile = prepare_coverage_files() - - cov_erase = execute_test_subprocess( - shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env - ) # this cleanup is necessary to avoid coverage data from previous runs, if there are any, - # then the current run will be appended to the previous data, which skews the results - logger.debug(cov_erase) - - results = execute_test_subprocess( - shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage run --rcfile={coveragercfile.as_posix()} -m") - + pytest_cmd_list - + common_pytest_args - + result_args - + test_files, - cwd=cwd, - env=pytest_test_env, - timeout=600, + test_files: list[str] = [] + for file in test_paths.test_files: + if file.test_type == TestType.REPLAY_TEST: + # TODO: Does this work for unittest framework? 
+ test_files.extend( + [str(file.instrumented_behavior_file_path) + "::" + test.test_function for test in file.tests_in_file] ) - logger.debug( - f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ''}""") else: - results = execute_test_subprocess( - pytest_cmd_list + common_pytest_args + result_args + test_files, - cwd=cwd, - env=pytest_test_env, - timeout=600, # TODO: Make this dynamic + test_files.append(str(file.instrumented_behavior_file_path)) + + source_code = next((file.original_source for file in test_paths.test_files if file.original_source), None) + if not source_code: + raise ValueError("No source code found for auditing") + + audit_code = transform_code(source_code) + pytest_cmd_list = shlex.split(pytest_cmd, posix=IS_POSIX) + common_pytest_args = [ + "--capture=tee-sys", + f"--timeout={pytest_timeout}", + "-q", + "--codeflash_loops_scope=session", + "--codeflash_min_loops=1", + "--codeflash_max_loops=1", + f"--codeflash_seconds={pytest_target_runtime_seconds}", + "-p", + "no:cacheprovider", + ] + + result_file_path = get_run_tmp_file(Path("pytest_results.xml")) + result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"] + + pytest_test_env = test_env.copy() + pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin" + + with tempfile.TemporaryDirectory( + dir=Path(test_paths.test_files[0].instrumented_behavior_file_path).parent + ) as temp_dir: + audited_file_path = Path(temp_dir) / "audited_code.py" + audited_file_path.write_text(audit_code, encoding="utf8") + + auditing_res = execute_test_subprocess( + pytest_cmd_list + common_pytest_args + [audited_file_path.as_posix()], + cwd=cwd, + env=pytest_test_env, + timeout=600, + ) + logger.info(auditing_res.stdout) + if auditing_res.returncode != 0: + line_co = next( + ( + line + for line in auditing_res.stderr.splitlines() + auditing_res.stdout.splitlines() + if "auditwall.core.SideEffectDetected" in line + ), + None, ) - logger.debug( - f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ''}""") - elif test_framework == "unittest": + if line_co: + match = re.search(r"auditwall.core.SideEffectDetected: A (.+).", line_co) + if match: + msg = match.group(1) + raise SideEffectDetected(msg) + logger.debug(auditing_res.stderr) + logger.debug(auditing_res.stdout) + + if test_framework == "pytest": + coverage_database_file, coveragercfile = prepare_coverage_files() + cov_erase = execute_test_subprocess( + shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env + ) + logger.debug(cov_erase) + + results = execute_test_subprocess( + shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage run --rcfile={coveragercfile.as_posix()} -m") + + pytest_cmd_list + + common_pytest_args + + result_args + + list(set(test_files)), # remove duplicates + cwd=cwd, + env=pytest_test_env, + timeout=600, + ) + else: # unittest if enable_coverage: raise ValueError("Coverage is not supported yet for unittest framework") test_env["CODEFLASH_LOOP_INDEX"] = "1" test_files = [file.instrumented_behavior_file_path for file in test_paths.test_files] result_file_path, results = run_unittest_tests(verbose, test_files, test_env, cwd) - logger.debug( - f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ''}""") - else: - raise ValueError(f"Unsupported test framework: {test_framework}") return result_file_path, results, 
coverage_database_file if enable_coverage else None diff --git a/codeflash/version.py b/codeflash/version.py index b29fd20bd..55232158e 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,3 +1,3 @@ # These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`. -__version__ = "0.9.2" -__version_tuple__ = (0, 9, 2) +__version__ = "0.10.0" +__version_tuple__ = (0, 10, 0) diff --git a/pyproject.toml b/pyproject.toml index 62587fa9d..27ebf6c1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ exclude = [ # Versions here the minimum required versions for the project. These should be as loose as possible. [tool.poetry.dependencies] -python = "^3.9" +python = ">=3.9" unidiff = ">=0.7.4" pytest = ">=7.0.0" gitpython = ">=3.1.31" diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 7f4a94845..88b46e87c 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest + from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import FunctionParent diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index 83b1efd2b..7cfd7f782 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -457,6 +457,7 @@ def __init__(self, x=2): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=test_code, ) ] ) @@ -568,6 +569,7 @@ def __init__(self, *args, **kwargs): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=test_code, ) ] ) @@ -681,6 +683,7 @@ def __init__(self, x=2): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=test_code, ) ] ) @@ -831,6 +834,7 @@ def another_helper(self): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=test_code, ) ] ) @@ -967,6 +971,7 @@ def another_helper(self): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=test_code, ) ] ) diff --git a/tests/test_instrument_all_and_run.py b/tests/test_instrument_all_and_run.py index 643d4bde7..8cfc06190 100644 --- a/tests/test_instrument_all_and_run.py +++ b/tests/test_instrument_all_and_run.py @@ -156,6 +156,7 @@ def test_sort(): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=code, ) ] ) @@ -328,6 +329,7 @@ def test_sort(): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=code, ) ] ) @@ -423,6 +425,7 @@ def sorter(self, arr): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=code, ) ] ) diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index bf7373522..311774845 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -413,6 +413,7 @@ def test_sort(): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=code, ) ] ) @@ -608,6 +609,7 @@ def test_sort_parametrized(input, expected_output): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=code, ) ] ) @@ -847,6 
+849,7 @@ def test_sort_parametrized_loop(input, expected_output): test_type=TestType.EXISTING_UNIT_TEST, ) ], + original_source=code, ) ] ) @@ -1156,6 +1159,7 @@ def test_sort(): test_type=TestType.EXISTING_UNIT_TEST, ) ], + original_source=code, ) ] ) @@ -1425,6 +1429,7 @@ def test_sort(self): test_type=TestType.EXISTING_UNIT_TEST, ) ], + original_source=code, ) ] ) @@ -1672,6 +1677,7 @@ def test_sort(self, input, expected_output): test_type=TestType.EXISTING_UNIT_TEST, ) ], + original_source=code, ) ] ) @@ -1923,6 +1929,7 @@ def test_sort(self): test_type=TestType.EXISTING_UNIT_TEST, ) ], + original_source=code, ) ] ) @@ -2172,6 +2179,7 @@ def test_sort(self, input, expected_output): test_type=TestType.EXISTING_UNIT_TEST, ) ], + original_source=code, ) ] ) diff --git a/tests/test_instrumentation_run_results_aiservice.py b/tests/test_instrumentation_run_results_aiservice.py index ee237cfca..62226b519 100644 --- a/tests/test_instrumentation_run_results_aiservice.py +++ b/tests/test_instrumentation_run_results_aiservice.py @@ -162,6 +162,7 @@ def test_single_element_list(): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=fto_path.read_text("utf-8"), ) ] ) @@ -298,6 +299,7 @@ def test_single_element_list(): test_type=test_type, original_file_path=test_path, benchmarking_file_path=test_path_perf, + original_source=fto_path.read_text("utf-8"), ) ] ) diff --git a/tests/test_test_runner.py b/tests/test_test_runner.py index 60a4be70f..09a7ffb32 100644 --- a/tests/test_test_runner.py +++ b/tests/test_test_runner.py @@ -36,7 +36,13 @@ def test_sort(self): with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp: test_files = TestFiles( - test_files=[TestFile(instrumented_behavior_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)] + test_files=[ + TestFile( + instrumented_behavior_file_path=Path(fp.name), + test_type=TestType.EXISTING_UNIT_TEST, + original_source=code, + ) + ] ) fp.write(code.encode("utf-8")) fp.flush() @@ -80,7 +86,13 @@ def test_sort(): with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp: test_files = TestFiles( - test_files=[TestFile(instrumented_behavior_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)] + test_files=[ + TestFile( + instrumented_behavior_file_path=Path(fp.name), + test_type=TestType.EXISTING_UNIT_TEST, + original_source=code, + ) + ] ) fp.write(code.encode("utf-8")) fp.flush()
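
Note on the new side-effect auditing flow: codeflash/verification/codeflash_auditwall.py rewrites a test module so that engage_auditwall() runs immediately after the last top-level import, before any test code executes. Below is a minimal sketch of what transform_code produces, assuming this branch of codeflash (and its auditwall dependency) is importable; the sample test module is hypothetical and only there to illustrate the rewrite:

    from codeflash.verification.codeflash_auditwall import transform_code

    # A hypothetical, minimal test module used only to show the injected lines.
    sample_test_source = (
        "import math\n"
        "\n"
        "def test_sqrt():\n"
        "    assert math.sqrt(4) == 2\n"
    )

    # The rewritten module begins with:
    #   import math
    #   from auditwall.core import engage_auditwall
    #   engage_auditwall()
    # followed by the original test function, so the audit hook is active
    # before any test body runs.
    print(transform_code(sample_test_source))

run_behavioral_tests then writes this audited copy into a temporary directory next to the instrumented test, runs it once under pytest, and raises SideEffectDetected (which establish_original_code_baseline turns into a skipped optimization) before the real behavioral and benchmarking runs.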
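
For background (not part of this change set): the auditwall package builds on CPython's PEP 578 audit hooks, which is why _auditwall.py keys its handlers on event names such as "os.remove" and "sqlite3.connect". The sketch below shows only that underlying stdlib mechanism, with a local stand-in exception rather than auditwall's own SideEffectDetected:

    import os
    import sys


    class BlockedSideEffect(RuntimeError):
        """Stand-in for auditwall's SideEffectDetected; illustrative only."""


    # Real, documented CPython audit event names.
    BLOCKED_EVENTS = {"os.remove", "os.rmdir", "shutil.rmtree"}


    def audit_hook(event: str, args: tuple) -> None:
        # A PEP 578 audit hook may raise to abort the audited operation.
        if event in BLOCKED_EVENTS:
            raise BlockedSideEffect(f"blocked {event} with args {args!r}")


    # Audit hooks cannot be removed once installed; they apply for the
    # rest of the interpreter's life, which is why codeflash injects the
    # hook only into a throwaway audited copy of the test module.
    sys.addaudithook(audit_hook)

    try:
        # Hypothetical path; the hook raises before the filesystem is touched.
        os.remove("some_scratch_file.txt")
    except BlockedSideEffect as exc:
        print(f"caught: {exc}")

This is only the generic mechanism; the actual allow/deny policy in this PR lives in auditwall's default handlers plus the CodeflashAuditWallConfig overrides and custom_handlers defined in _auditwall.py.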