From e25a185fd300ddd9f6a513aa479b5dbe9dc91f03 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Thu, 5 Jun 2025 19:31:01 -0700 Subject: [PATCH 1/9] first commit --- ...rated_tests.py => edit_generated_tests.py} | 0 codeflash/optimization/function_optimizer.py | 169 +++++++++++++++++- ...t_remove_functions_from_generated_tests.py | 3 +- 3 files changed, 163 insertions(+), 9 deletions(-) rename codeflash/code_utils/{remove_generated_tests.py => edit_generated_tests.py} (100%) diff --git a/codeflash/code_utils/remove_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py similarity index 100% rename from codeflash/code_utils/remove_generated_tests.py rename to codeflash/code_utils/edit_generated_tests.py diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 12aeff3fa..82e9d2c52 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -36,10 +36,10 @@ N_TESTS_TO_GENERATE, TOTAL_LOOPING_TIME, ) +from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests from codeflash.code_utils.formatter import format_code, sort_imports from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.code_utils.line_profile_utils import add_decorator_imports -from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast from codeflash.code_utils.time_utils import humanize_runtime from codeflash.context import code_context_extractor @@ -265,10 +265,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 }, ) - generated_tests = remove_functions_from_generated_tests( - generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove - ) - if best_optimization: logger.info("Best candidate:") code_print(best_optimization.candidate.source_code) @@ -295,8 +291,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None, ) - self.log_successful_optimization(explanation, generated_tests, exp_type) - self.replace_function_and_helpers_with_optimized_code( code_context=code_context, optimized_code=best_optimization.candidate.source_code ) @@ -321,6 +315,15 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 if original_code_baseline.coverage_results else "Coverage data not available" ) + generated_tests = remove_functions_from_generated_tests( + generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove + ) + # Add runtime comments to generated tests before creating the PR + generated_tests = self.add_runtime_comments_to_generated_tests( + generated_tests, + original_code_baseline.benchmarking_test_results, + best_optimization.winning_benchmarking_test_results, + ) generated_tests_str = "\n\n".join( [test.generated_original_test_source for test in generated_tests.generated_tests] ) @@ -345,6 +348,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 original_helper_code, self.function_to_optimize.file_path, ) + self.log_successful_optimization(explanation, generated_tests, exp_type) if not best_optimization: return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") @@ -1266,3 +1270,154 @@ def cleanup_generated_files(self) -> None: cleanup_paths(paths_to_cleanup) if hasattr(get_run_tmp_file, "tmpdir"): get_run_tmp_file.tmpdir.cleanup() + + def add_runtime_comments_to_generated_tests( + self, + generated_tests: GeneratedTestsList, + original_test_results: TestResults, + optimized_test_results: TestResults, + ) -> GeneratedTestsList: + """Add runtime performance comments to function calls in generated tests.""" + + def format_time(nanoseconds: int) -> str: + """Format nanoseconds into a human-readable string with 3 significant digits when needed.""" + + def count_significant_digits(num: int) -> int: + """Count significant digits in an integer.""" + return len(str(abs(num))) + + def format_with_precision(value: float, unit: str) -> str: + """Format a value with 3 significant digits precision.""" + if value >= 100: + return f"{value:.0f}{unit}" + if value >= 10: + return f"{value:.1f}{unit}" + return f"{value:.2f}{unit}" + + if nanoseconds < 1_000: + return f"{nanoseconds}ns" + if nanoseconds < 1_000_000: + # Convert to microseconds + microseconds_int = nanoseconds // 1_000 + if count_significant_digits(microseconds_int) >= 3: + return f"{microseconds_int}μs" + microseconds_float = nanoseconds / 1_000 + return format_with_precision(microseconds_float, "μs") + if nanoseconds < 1_000_000_000: + # Convert to milliseconds + milliseconds_int = nanoseconds // 1_000_000 + if count_significant_digits(milliseconds_int) >= 3: + return f"{milliseconds_int}ms" + milliseconds_float = nanoseconds / 1_000_000 + return format_with_precision(milliseconds_float, "ms") + # Convert to seconds + seconds_int = nanoseconds // 1_000_000_000 + if count_significant_digits(seconds_int) >= 3: + return f"{seconds_int}s" + seconds_float = nanoseconds / 1_000_000_000 + return format_with_precision(seconds_float, "s") + + # Create dictionaries for fast lookup of runtime data + original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case() + optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case() + + class RuntimeCommentTransformer(cst.CSTTransformer): + def __init__(self): + self.in_test_function = False + self.current_test_name = None + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + if node.name.value.startswith("test_"): + self.in_test_function = True + self.current_test_name = node.name.value + else: + self.in_test_function = False + self.current_test_name = None + + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.FunctionDef: + if original_node.name.value.startswith("test_"): + self.in_test_function = False + self.current_test_name = None + return updated_node + + def leave_SimpleStatementLine( + self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine + ) -> cst.SimpleStatementLine: + if not self.in_test_function or not self.current_test_name: + return updated_node + + # Look for assignment statements that assign to codeflash_output + # Handle both single statements and multiple statements on one line + codeflash_assignment_found = False + for stmt in updated_node.body: + if isinstance(stmt, cst.Assign): + if ( + len(stmt.targets) == 1 + and isinstance(stmt.targets[0].target, cst.Name) + and stmt.targets[0].target.value == "codeflash_output" + ): + codeflash_assignment_found = True + break + + if codeflash_assignment_found: + # Find matching test cases by looking for this test function name in the test results + matching_original_times = [] + matching_optimized_times = [] + + for invocation_id, runtimes in original_runtime_by_test.items(): + if invocation_id.test_function_name == self.current_test_name: + matching_original_times.extend(runtimes) + + for invocation_id, runtimes in optimized_runtime_by_test.items(): + if invocation_id.test_function_name == self.current_test_name: + matching_optimized_times.extend(runtimes) + + if matching_original_times and matching_optimized_times: + original_time = min(matching_original_times) + optimized_time = min(matching_optimized_times) + + # Create the runtime comment + comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}" + + # Add comment to the trailing whitespace + new_trailing_whitespace = cst.TrailingWhitespace( + whitespace=cst.SimpleWhitespace(" "), + comment=cst.Comment(comment_text), + newline=updated_node.trailing_whitespace.newline, + ) + + return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace) + + return updated_node + + # Process each generated test + modified_tests = [] + for test in generated_tests.generated_tests: + try: + # Parse the test source code + tree = cst.parse_module(test.generated_original_test_source) + + # Transform the tree to add runtime comments + transformer = RuntimeCommentTransformer() + modified_tree = tree.visit(transformer) + + # Convert back to source code + modified_source = modified_tree.code + + # Create a new GeneratedTests object with the modified source + modified_test = GeneratedTests( + generated_original_test_source=modified_source, + instrumented_behavior_test_source=test.instrumented_behavior_test_source, + instrumented_perf_test_source=test.instrumented_perf_test_source, + behavior_file_path=test.behavior_file_path, + perf_file_path=test.perf_file_path, + ) + modified_tests.append(modified_test) + except Exception as e: + # If parsing fails, keep the original test + logger.debug(f"Failed to add runtime comments to test: {e}") + modified_tests.append(test) + + return GeneratedTestsList(generated_tests=modified_tests) diff --git a/tests/test_remove_functions_from_generated_tests.py b/tests/test_remove_functions_from_generated_tests.py index dc2a14468..c6fd9a7aa 100644 --- a/tests/test_remove_functions_from_generated_tests.py +++ b/tests/test_remove_functions_from_generated_tests.py @@ -1,8 +1,7 @@ from pathlib import Path import pytest - -from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests +from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests from codeflash.models.models import GeneratedTests, GeneratedTestsList From bdb5ca6572d595fd53bf8604bc8428bdfe8dacc2 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Thu, 5 Jun 2025 19:51:33 -0700 Subject: [PATCH 2/9] Refactor --- codeflash/code_utils/edit_generated_tests.py | 114 ++++++++++++- codeflash/code_utils/time_utils.py | 39 +++++ codeflash/optimization/function_optimizer.py | 158 +------------------ 3 files changed, 157 insertions(+), 154 deletions(-) diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py index 25eb58965..94ac411f2 100644 --- a/codeflash/code_utils/edit_generated_tests.py +++ b/codeflash/code_utils/edit_generated_tests.py @@ -1,6 +1,10 @@ import re -from codeflash.models.models import GeneratedTestsList +import libcst as cst + +from codeflash.cli_cmds.console import logger +from codeflash.code_utils.time_utils import format_time +from codeflash.models.models import GeneratedTests, GeneratedTestsList, TestResults def remove_functions_from_generated_tests( @@ -26,3 +30,111 @@ def remove_functions_from_generated_tests( new_generated_tests.append(generated_test) return GeneratedTestsList(generated_tests=new_generated_tests) + + +def add_runtime_comments_to_generated_tests( + generated_tests: GeneratedTestsList, original_test_results: TestResults, optimized_test_results: TestResults +) -> GeneratedTestsList: + """Add runtime performance comments to function calls in generated tests.""" + # Create dictionaries for fast lookup of runtime data + original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case() + optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case() + + class RuntimeCommentTransformer(cst.CSTTransformer): + def __init__(self): + self.in_test_function = False + self.current_test_name = None + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + if node.name.value.startswith("test_"): + self.in_test_function = True + self.current_test_name = node.name.value + else: + self.in_test_function = False + self.current_test_name = None + + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + if original_node.name.value.startswith("test_"): + self.in_test_function = False + self.current_test_name = None + return updated_node + + def leave_SimpleStatementLine( + self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine + ) -> cst.SimpleStatementLine: + if not self.in_test_function or not self.current_test_name: + return updated_node + + # Look for assignment statements that assign to codeflash_output + # Handle both single statements and multiple statements on one line + codeflash_assignment_found = False + for stmt in updated_node.body: + if isinstance(stmt, cst.Assign): + if ( + len(stmt.targets) == 1 + and isinstance(stmt.targets[0].target, cst.Name) + and stmt.targets[0].target.value == "codeflash_output" + ): + codeflash_assignment_found = True + break + + if codeflash_assignment_found: + # Find matching test cases by looking for this test function name in the test results + matching_original_times = [] + matching_optimized_times = [] + + for invocation_id, runtimes in original_runtime_by_test.items(): + if invocation_id.test_function_name == self.current_test_name: + matching_original_times.extend(runtimes) + + for invocation_id, runtimes in optimized_runtime_by_test.items(): + if invocation_id.test_function_name == self.current_test_name: + matching_optimized_times.extend(runtimes) + + if matching_original_times and matching_optimized_times: + original_time = min(matching_original_times) + optimized_time = min(matching_optimized_times) + + # Create the runtime comment + comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}" + + # Add comment to the trailing whitespace + new_trailing_whitespace = cst.TrailingWhitespace( + whitespace=cst.SimpleWhitespace(" "), + comment=cst.Comment(comment_text), + newline=updated_node.trailing_whitespace.newline, + ) + + return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace) + + return updated_node + + # Process each generated test + modified_tests = [] + for test in generated_tests.generated_tests: + try: + # Parse the test source code + tree = cst.parse_module(test.generated_original_test_source) + + # Transform the tree to add runtime comments + transformer = RuntimeCommentTransformer() + modified_tree = tree.visit(transformer) + + # Convert back to source code + modified_source = modified_tree.code + + # Create a new GeneratedTests object with the modified source + modified_test = GeneratedTests( + generated_original_test_source=modified_source, + instrumented_behavior_test_source=test.instrumented_behavior_test_source, + instrumented_perf_test_source=test.instrumented_perf_test_source, + behavior_file_path=test.behavior_file_path, + perf_file_path=test.perf_file_path, + ) + modified_tests.append(modified_test) + except Exception as e: + # If parsing fails, keep the original test + logger.debug(f"Failed to add runtime comments to test: {e}") + modified_tests.append(test) + + return GeneratedTestsList(generated_tests=modified_tests) diff --git a/codeflash/code_utils/time_utils.py b/codeflash/code_utils/time_utils.py index aaf74fc93..936b1b0af 100644 --- a/codeflash/code_utils/time_utils.py +++ b/codeflash/code_utils/time_utils.py @@ -49,3 +49,42 @@ def humanize_runtime(time_in_ns: int) -> str: runtime_human = runtime_human_parts[0] return f"{runtime_human} {units}" + + +def format_time(nanoseconds: int) -> str: + """Format nanoseconds into a human-readable string with 3 significant digits when needed.""" + + def count_significant_digits(num: int) -> int: + """Count significant digits in an integer.""" + return len(str(abs(num))) + + def format_with_precision(value: float, unit: str) -> str: + """Format a value with 3 significant digits precision.""" + if value >= 100: + return f"{value:.0f}{unit}" + if value >= 10: + return f"{value:.1f}{unit}" + return f"{value:.2f}{unit}" + + if nanoseconds < 1_000: + return f"{nanoseconds}ns" + if nanoseconds < 1_000_000: + # Convert to microseconds + microseconds_int = nanoseconds // 1_000 + if count_significant_digits(microseconds_int) >= 3: + return f"{microseconds_int}μs" + microseconds_float = nanoseconds / 1_000 + return format_with_precision(microseconds_float, "μs") + if nanoseconds < 1_000_000_000: + # Convert to milliseconds + milliseconds_int = nanoseconds // 1_000_000 + if count_significant_digits(milliseconds_int) >= 3: + return f"{milliseconds_int}ms" + milliseconds_float = nanoseconds / 1_000_000 + return format_with_precision(milliseconds_float, "ms") + # Convert to seconds + seconds_int = nanoseconds // 1_000_000_000 + if count_significant_digits(seconds_int) >= 3: + return f"{seconds_int}s" + seconds_float = nanoseconds / 1_000_000_000 + return format_with_precision(seconds_float, "s") diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 82e9d2c52..a3cd11b3e 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -36,7 +36,10 @@ N_TESTS_TO_GENERATE, TOTAL_LOOPING_TIME, ) -from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests +from codeflash.code_utils.edit_generated_tests import ( + add_runtime_comments_to_generated_tests, + remove_functions_from_generated_tests, +) from codeflash.code_utils.formatter import format_code, sort_imports from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.code_utils.line_profile_utils import add_decorator_imports @@ -319,7 +322,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove ) # Add runtime comments to generated tests before creating the PR - generated_tests = self.add_runtime_comments_to_generated_tests( + generated_tests = add_runtime_comments_to_generated_tests( generated_tests, original_code_baseline.benchmarking_test_results, best_optimization.winning_benchmarking_test_results, @@ -1270,154 +1273,3 @@ def cleanup_generated_files(self) -> None: cleanup_paths(paths_to_cleanup) if hasattr(get_run_tmp_file, "tmpdir"): get_run_tmp_file.tmpdir.cleanup() - - def add_runtime_comments_to_generated_tests( - self, - generated_tests: GeneratedTestsList, - original_test_results: TestResults, - optimized_test_results: TestResults, - ) -> GeneratedTestsList: - """Add runtime performance comments to function calls in generated tests.""" - - def format_time(nanoseconds: int) -> str: - """Format nanoseconds into a human-readable string with 3 significant digits when needed.""" - - def count_significant_digits(num: int) -> int: - """Count significant digits in an integer.""" - return len(str(abs(num))) - - def format_with_precision(value: float, unit: str) -> str: - """Format a value with 3 significant digits precision.""" - if value >= 100: - return f"{value:.0f}{unit}" - if value >= 10: - return f"{value:.1f}{unit}" - return f"{value:.2f}{unit}" - - if nanoseconds < 1_000: - return f"{nanoseconds}ns" - if nanoseconds < 1_000_000: - # Convert to microseconds - microseconds_int = nanoseconds // 1_000 - if count_significant_digits(microseconds_int) >= 3: - return f"{microseconds_int}μs" - microseconds_float = nanoseconds / 1_000 - return format_with_precision(microseconds_float, "μs") - if nanoseconds < 1_000_000_000: - # Convert to milliseconds - milliseconds_int = nanoseconds // 1_000_000 - if count_significant_digits(milliseconds_int) >= 3: - return f"{milliseconds_int}ms" - milliseconds_float = nanoseconds / 1_000_000 - return format_with_precision(milliseconds_float, "ms") - # Convert to seconds - seconds_int = nanoseconds // 1_000_000_000 - if count_significant_digits(seconds_int) >= 3: - return f"{seconds_int}s" - seconds_float = nanoseconds / 1_000_000_000 - return format_with_precision(seconds_float, "s") - - # Create dictionaries for fast lookup of runtime data - original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case() - optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case() - - class RuntimeCommentTransformer(cst.CSTTransformer): - def __init__(self): - self.in_test_function = False - self.current_test_name = None - - def visit_FunctionDef(self, node: cst.FunctionDef) -> None: - if node.name.value.startswith("test_"): - self.in_test_function = True - self.current_test_name = node.name.value - else: - self.in_test_function = False - self.current_test_name = None - - def leave_FunctionDef( - self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef - ) -> cst.FunctionDef: - if original_node.name.value.startswith("test_"): - self.in_test_function = False - self.current_test_name = None - return updated_node - - def leave_SimpleStatementLine( - self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine - ) -> cst.SimpleStatementLine: - if not self.in_test_function or not self.current_test_name: - return updated_node - - # Look for assignment statements that assign to codeflash_output - # Handle both single statements and multiple statements on one line - codeflash_assignment_found = False - for stmt in updated_node.body: - if isinstance(stmt, cst.Assign): - if ( - len(stmt.targets) == 1 - and isinstance(stmt.targets[0].target, cst.Name) - and stmt.targets[0].target.value == "codeflash_output" - ): - codeflash_assignment_found = True - break - - if codeflash_assignment_found: - # Find matching test cases by looking for this test function name in the test results - matching_original_times = [] - matching_optimized_times = [] - - for invocation_id, runtimes in original_runtime_by_test.items(): - if invocation_id.test_function_name == self.current_test_name: - matching_original_times.extend(runtimes) - - for invocation_id, runtimes in optimized_runtime_by_test.items(): - if invocation_id.test_function_name == self.current_test_name: - matching_optimized_times.extend(runtimes) - - if matching_original_times and matching_optimized_times: - original_time = min(matching_original_times) - optimized_time = min(matching_optimized_times) - - # Create the runtime comment - comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}" - - # Add comment to the trailing whitespace - new_trailing_whitespace = cst.TrailingWhitespace( - whitespace=cst.SimpleWhitespace(" "), - comment=cst.Comment(comment_text), - newline=updated_node.trailing_whitespace.newline, - ) - - return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace) - - return updated_node - - # Process each generated test - modified_tests = [] - for test in generated_tests.generated_tests: - try: - # Parse the test source code - tree = cst.parse_module(test.generated_original_test_source) - - # Transform the tree to add runtime comments - transformer = RuntimeCommentTransformer() - modified_tree = tree.visit(transformer) - - # Convert back to source code - modified_source = modified_tree.code - - # Create a new GeneratedTests object with the modified source - modified_test = GeneratedTests( - generated_original_test_source=modified_source, - instrumented_behavior_test_source=test.instrumented_behavior_test_source, - instrumented_perf_test_source=test.instrumented_perf_test_source, - behavior_file_path=test.behavior_file_path, - perf_file_path=test.perf_file_path, - ) - modified_tests.append(modified_test) - except Exception as e: - # If parsing fails, keep the original test - logger.debug(f"Failed to add runtime comments to test: {e}") - modified_tests.append(test) - - return GeneratedTestsList(generated_tests=modified_tests) From d0de40b84f82362404607ddeca8810a0872ac814 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Thu, 5 Jun 2025 19:55:32 -0700 Subject: [PATCH 3/9] add missing test file --- tests/test_add_runtime_comments.py | 495 +++++++++++++++++++++++++++++ 1 file changed, 495 insertions(+) create mode 100644 tests/test_add_runtime_comments.py diff --git a/tests/test_add_runtime_comments.py b/tests/test_add_runtime_comments.py new file mode 100644 index 000000000..d392d140e --- /dev/null +++ b/tests/test_add_runtime_comments.py @@ -0,0 +1,495 @@ +"""Tests for the add_runtime_comments_to_generated_tests functionality.""" + +from pathlib import Path +from unittest.mock import Mock + +from codeflash.code_utils.edit_generated_tests import add_runtime_comments_to_generated_tests +from codeflash.models.models import ( + FunctionTestInvocation, + GeneratedTests, + GeneratedTestsList, + InvocationId, + TestResults, + TestType, + VerificationType, +) +from codeflash.optimization.function_optimizer import FunctionOptimizer + + +class TestAddRuntimeComments: + """Test cases for add_runtime_comments_to_generated_tests method.""" + + def setup_method(self): + """Set up test fixtures.""" + # Create a mock FunctionOptimizer with minimal required attributes + self.optimizer = Mock(spec=FunctionOptimizer) + # We need to use the real implementation of the method + self.optimizer.add_runtime_comments_to_generated_tests = add_runtime_comments_to_generated_tests.__get__( + self.optimizer, FunctionOptimizer + ) + + def create_test_invocation( + self, test_function_name: str, runtime: int, loop_index: int = 1, iteration_id: str = "1", did_pass: bool = True + ) -> FunctionTestInvocation: + """Helper to create test invocation objects.""" + return FunctionTestInvocation( + loop_index=loop_index, + id=InvocationId( + test_module_path="test_module", + test_class_name=None, + test_function_name=test_function_name, + function_getting_tested="test_function", + iteration_id=iteration_id, + ), + file_name=Path("test.py"), + did_pass=did_pass, + runtime=runtime, + test_framework="pytest", + test_type=TestType.GENERATED_REGRESSION, + return_value=None, + timed_out=False, + verification_type=VerificationType.FUNCTION_CALL, + ) + + def test_basic_runtime_comment_addition(self): + """Test basic functionality of adding runtime comments.""" + # Create test source code + test_source = """def test_bubble_sort(): + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + # Add test invocations with different runtimes + original_invocation = self.create_test_invocation("test_bubble_sort", 500_000) # 500μs + optimized_invocation = self.create_test_invocation("test_bubble_sort", 300_000) # 300μs + + original_test_results.add(original_invocation) + optimized_test_results.add(optimized_invocation) + + # Test the functionality + result = self.optimizer.add_runtime_comments_to_generated_tests( + generated_tests, original_test_results, optimized_test_results + ) + + # Check that comments were added + modified_source = result.generated_tests[0].generated_original_test_source + assert "# 500μs -> 300μs" in modified_source + assert "codeflash_output = bubble_sort([3, 1, 2]) # 500μs -> 300μs" in modified_source + + def test_multiple_test_functions(self): + """Test handling multiple test functions in the same file.""" + test_source = """def test_bubble_sort(): + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] + +def test_quick_sort(): + codeflash_output = quick_sort([5, 2, 8]) + assert codeflash_output == [2, 5, 8] + +def helper_function(): + return "not a test" +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results for both functions + original_test_results = TestResults() + optimized_test_results = TestResults() + + # Add test invocations for both test functions + original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) + original_test_results.add(self.create_test_invocation("test_quick_sort", 800_000)) + + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) + optimized_test_results.add(self.create_test_invocation("test_quick_sort", 600_000)) + + # Test the functionality + result = self.optimizer.add_runtime_comments_to_generated_tests( + generated_tests, original_test_results, optimized_test_results + ) + + modified_source = result.generated_tests[0].generated_original_test_source + + # Check that comments were added to both test functions + assert "# 500μs -> 300μs" in modified_source + assert "# 800μs -> 600μs" in modified_source + # Helper function should not have comments + assert ( + "helper_function():" in modified_source + and "# " not in modified_source.split("helper_function():")[1].split("\n")[0] + ) + + def test_different_time_formats(self): + """Test that different time ranges are formatted correctly with new precision rules.""" + test_cases = [ + (999, 500, "999ns -> 500ns"), # nanoseconds + (25_000, 18_000, "25.0μs -> 18.0μs"), # microseconds with precision + (500_000, 300_000, "500μs -> 300μs"), # microseconds full integers + (1_500_000, 800_000, "1.50ms -> 800μs"), # milliseconds with precision + (365_000_000, 290_000_000, "365ms -> 290ms"), # milliseconds full integers + (2_000_000_000, 1_500_000_000, "2.00s -> 1.50s"), # seconds with precision + ] + + for original_time, optimized_time, expected_comment in test_cases: + test_source = """def test_function(): + codeflash_output = some_function() + assert codeflash_output is not None +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + original_test_results.add(self.create_test_invocation("test_function", original_time)) + optimized_test_results.add(self.create_test_invocation("test_function", optimized_time)) + + # Test the functionality + result = self.optimizer.add_runtime_comments_to_generated_tests( + generated_tests, original_test_results, optimized_test_results + ) + + modified_source = result.generated_tests[0].generated_original_test_source + assert f"# {expected_comment}" in modified_source + + def test_missing_test_results(self): + """Test behavior when test results are missing for a test function.""" + test_source = """def test_bubble_sort(): + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create empty test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + # Test the functionality + result = self.optimizer.add_runtime_comments_to_generated_tests( + generated_tests, original_test_results, optimized_test_results + ) + + # Check that no comments were added + modified_source = result.generated_tests[0].generated_original_test_source + assert modified_source == test_source # Should be unchanged + + def test_partial_test_results(self): + """Test behavior when only one set of test results is available.""" + test_source = """def test_bubble_sort(): + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results with only original data + original_test_results = TestResults() + optimized_test_results = TestResults() + + original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) + # No optimized results + + # Test the functionality + result = self.optimizer.add_runtime_comments_to_generated_tests( + generated_tests, original_test_results, optimized_test_results + ) + + # Check that no comments were added + modified_source = result.generated_tests[0].generated_original_test_source + assert modified_source == test_source # Should be unchanged + + def test_multiple_runtimes_uses_minimum(self): + """Test that when multiple runtimes exist, the minimum is used.""" + test_source = """def test_bubble_sort(): + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results with multiple loop iterations + original_test_results = TestResults() + optimized_test_results = TestResults() + + # Add multiple runs with different runtimes + original_test_results.add(self.create_test_invocation("test_bubble_sort", 600_000, loop_index=1)) + original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000, loop_index=2)) + original_test_results.add(self.create_test_invocation("test_bubble_sort", 550_000, loop_index=3)) + + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 350_000, loop_index=1)) + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000, loop_index=2)) + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 320_000, loop_index=3)) + + # Test the functionality + result = self.optimizer.add_runtime_comments_to_generated_tests( + generated_tests, original_test_results, optimized_test_results + ) + + # Check that minimum times were used (500μs -> 300μs) + modified_source = result.generated_tests[0].generated_original_test_source + assert "# 500μs -> 300μs" in modified_source + + def test_no_codeflash_output_assignment(self): + """Test behavior when test doesn't have codeflash_output assignment.""" + test_source = """def test_bubble_sort(): + result = bubble_sort([3, 1, 2]) + assert result == [1, 2, 3] +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) + + # Test the functionality + result = self.optimizer.add_runtime_comments_to_generated_tests( + generated_tests, original_test_results, optimized_test_results + ) + + # Check that no comments were added (no codeflash_output assignment) + modified_source = result.generated_tests[0].generated_original_test_source + assert modified_source == test_source # Should be unchanged + + def test_invalid_python_code_handling(self): + """Test behavior when test source code is invalid Python.""" + test_source = """def test_bubble_sort(: + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] +""" # Invalid syntax: extra colon + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) + + # Test the functionality - should handle parse error gracefully + result = self.optimizer.add_runtime_comments_to_generated_tests( + generated_tests, original_test_results, optimized_test_results + ) + + # Check that original test is preserved when parsing fails + modified_source = result.generated_tests[0].generated_original_test_source + assert modified_source == test_source # Should be unchanged due to parse error + + def test_multiple_generated_tests(self): + """Test handling multiple generated test objects.""" + test_source_1 = """def test_bubble_sort(): + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] +""" + + test_source_2 = """def test_quick_sort(): + codeflash_output = quick_sort([5, 2, 8]) + assert codeflash_output == [2, 5, 8] +""" + + generated_test_1 = GeneratedTests( + generated_original_test_source=test_source_1, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior_1.py"), + perf_file_path=Path("test_perf_1.py"), + ) + + generated_test_2 = GeneratedTests( + generated_original_test_source=test_source_2, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior_2.py"), + perf_file_path=Path("test_perf_2.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test_1, generated_test_2]) + + # Create test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) + original_test_results.add(self.create_test_invocation("test_quick_sort", 800_000)) + + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) + optimized_test_results.add(self.create_test_invocation("test_quick_sort", 600_000)) + + # Test the functionality + result = self.optimizer.add_runtime_comments_to_generated_tests( + generated_tests, original_test_results, optimized_test_results + ) + + # Check that comments were added to both test files + modified_source_1 = result.generated_tests[0].generated_original_test_source + modified_source_2 = result.generated_tests[1].generated_original_test_source + + assert "# 500μs -> 300μs" in modified_source_1 + assert "# 800μs -> 600μs" in modified_source_2 + + def test_preserved_test_attributes(self): + """Test that other test attributes are preserved during modification.""" + test_source = """def test_bubble_sort(): + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] +""" + + original_behavior_source = "behavior test source" + original_perf_source = "perf test source" + original_behavior_path = Path("test_behavior.py") + original_perf_path = Path("test_perf.py") + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source=original_behavior_source, + instrumented_perf_test_source=original_perf_source, + behavior_file_path=original_behavior_path, + perf_file_path=original_perf_path, + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) + + # Test the functionality + result = self.optimizer.add_runtime_comments_to_generated_tests( + generated_tests, original_test_results, optimized_test_results + ) + + # Check that other attributes are preserved + modified_test = result.generated_tests[0] + assert modified_test.instrumented_behavior_test_source == original_behavior_source + assert modified_test.instrumented_perf_test_source == original_perf_source + assert modified_test.behavior_file_path == original_behavior_path + assert modified_test.perf_file_path == original_perf_path + + # Check that only the generated_original_test_source was modified + assert "# 500μs -> 300μs" in modified_test.generated_original_test_source + + def test_multistatement_line_handling(self): + """Test that runtime comments work correctly with multiple statements on one line.""" + test_source = """def test_mutation_of_input(): + # Test that the input list is mutated in-place and returned + arr = [3, 1, 2] + codeflash_output = sorter(arr); result = codeflash_output + assert result == [1, 2, 3] + assert arr == [1, 2, 3] # Input should be mutated +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + original_test_results.add(self.create_test_invocation("test_mutation_of_input", 19_000)) # 19μs + optimized_test_results.add(self.create_test_invocation("test_mutation_of_input", 14_000)) # 14μs + + # Test the functionality + result = self.optimizer.add_runtime_comments_to_generated_tests( + generated_tests, original_test_results, optimized_test_results + ) + + # Check that comments were added to the correct line + modified_source = result.generated_tests[0].generated_original_test_source + assert "# 19.0μs -> 14.0μs" in modified_source + + # Verify the comment is on the line with codeflash_output assignment + lines = modified_source.split("\n") + codeflash_line = None + for line in lines: + if "codeflash_output = sorter(arr)" in line: + codeflash_line = line + break + + assert codeflash_line is not None, "Could not find codeflash_output assignment line" + assert "# 19.0μs -> 14.0μs" in codeflash_line, f"Comment not found in the correct line: {codeflash_line}" From 291ac2c294f56090b61f4d0cfe1ce97f30e593b6 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Thu, 5 Jun 2025 20:01:09 -0700 Subject: [PATCH 4/9] fix test ruff reformat and fix linting --- codeflash/benchmarking/codeflash_trace.py | 4 ++ .../instrument_codeflash_trace.py | 2 + codeflash/benchmarking/plugin/plugin.py | 4 ++ codeflash/benchmarking/replay_test.py | 4 ++ codeflash/benchmarking/utils.py | 2 + codeflash/cli_cmds/cmd_init.py | 2 +- codeflash/cli_cmds/logging_config.py | 4 +- codeflash/code_utils/checkpoint.py | 4 +- codeflash/code_utils/line_profile_utils.py | 3 ++ codeflash/context/code_context_extractor.py | 19 ++++--- .../context/unused_definition_remover.py | 3 ++ codeflash/picklepatch/pickle_patcher.py | 16 ++++++ codeflash/picklepatch/pickle_placeholder.py | 1 + codeflash/tracing/profile_stats.py | 2 +- .../parse_line_profile_test_output.py | 2 +- tests/test_add_runtime_comments.py | 53 ++++--------------- 16 files changed, 71 insertions(+), 54 deletions(-) diff --git a/codeflash/benchmarking/codeflash_trace.py b/codeflash/benchmarking/codeflash_trace.py index 20743fd56..249acdeb3 100644 --- a/codeflash/benchmarking/codeflash_trace.py +++ b/codeflash/benchmarking/codeflash_trace.py @@ -25,6 +25,7 @@ def setup(self, trace_path: str) -> None: """Set up the database connection for direct writing. Args: + ---- trace_path: Path to the trace database file """ @@ -52,6 +53,7 @@ def write_function_timings(self) -> None: """Write function call data directly to the database. Args: + ---- data: List of function call data tuples to write """ @@ -94,9 +96,11 @@ def __call__(self, func: Callable) -> Callable: """Use as a decorator to trace function execution. Args: + ---- func: The function to be decorated Returns: + ------- The wrapped function """ diff --git a/codeflash/benchmarking/instrument_codeflash_trace.py b/codeflash/benchmarking/instrument_codeflash_trace.py index 761e91f71..04b12018a 100644 --- a/codeflash/benchmarking/instrument_codeflash_trace.py +++ b/codeflash/benchmarking/instrument_codeflash_trace.py @@ -76,10 +76,12 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct """Add codeflash_trace to a function. Args: + ---- code: The source code as a string functions_to_optimize: List of FunctionToOptimize instances containing function details Returns: + ------- The modified source code as a string """ diff --git a/codeflash/benchmarking/plugin/plugin.py b/codeflash/benchmarking/plugin/plugin.py index 45fabef14..6516fba38 100644 --- a/codeflash/benchmarking/plugin/plugin.py +++ b/codeflash/benchmarking/plugin/plugin.py @@ -74,9 +74,11 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark """Process the trace file and extract timing data for all functions. Args: + ---- trace_path: Path to the trace file Returns: + ------- A nested dictionary where: - Outer keys are module_name.qualified_name (module.class.function) - Inner keys are of type BenchmarkKey @@ -132,9 +134,11 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]: """Extract total benchmark timings from trace files. Args: + ---- trace_path: Path to the trace file Returns: + ------- A dictionary mapping where: - Keys are of type BenchmarkKey - Values are total benchmark timing in milliseconds (with overhead subtracted) diff --git a/codeflash/benchmarking/replay_test.py b/codeflash/benchmarking/replay_test.py index c44649632..c2e1889db 100644 --- a/codeflash/benchmarking/replay_test.py +++ b/codeflash/benchmarking/replay_test.py @@ -55,12 +55,14 @@ def create_trace_replay_test_code( """Create a replay test for functions based on trace data. Args: + ---- trace_file: Path to the SQLite database file functions_data: List of dictionaries with function info extracted from DB test_framework: 'pytest' or 'unittest' max_run_count: Maximum number of runs to include in the test Returns: + ------- A string containing the test code """ @@ -218,12 +220,14 @@ def generate_replay_test( """Generate multiple replay tests from the traced function calls, grouped by benchmark. Args: + ---- trace_file_path: Path to the SQLite database file output_dir: Directory to write the generated tests (if None, only returns the code) test_framework: 'pytest' or 'unittest' max_run_count: Maximum number of runs to include per function Returns: + ------- Dictionary mapping benchmark names to generated test code """ diff --git a/codeflash/benchmarking/utils.py b/codeflash/benchmarking/utils.py index 5dae99444..db89c4c33 100644 --- a/codeflash/benchmarking/utils.py +++ b/codeflash/benchmarking/utils.py @@ -83,11 +83,13 @@ def process_benchmark_data( """Process benchmark data and generate detailed benchmark information. Args: + ---- replay_performance_gain: The performance gain from replay fto_benchmark_timings: Function to optimize benchmark timings total_benchmark_timings: Total benchmark timings Returns: + ------- ProcessedBenchmarkInfo containing processed benchmark details """ diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index bfe600fa4..b8a038a82 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -34,7 +34,7 @@ from argparse import Namespace CODEFLASH_LOGO: str = ( - f"{LF}" # noqa: ISC003 + f"{LF}" r" _ ___ _ _ " + f"{LF}" r" | | / __)| | | | " + f"{LF}" r" ____ ___ _ | | ____ | |__ | | ____ ___ | | _ " + f"{LF}" diff --git a/codeflash/cli_cmds/logging_config.py b/codeflash/cli_cmds/logging_config.py index 8bd4a48d9..e546836fc 100644 --- a/codeflash/cli_cmds/logging_config.py +++ b/codeflash/cli_cmds/logging_config.py @@ -27,7 +27,7 @@ def set_level(level: int, *, echo_setting: bool = True) -> None: ], force=True, ) - logging.info("Verbose DEBUG logging enabled") # noqa: LOG015 + logging.info("Verbose DEBUG logging enabled") else: - logging.info("Logging level set to INFO") # noqa: LOG015 + logging.info("Logging level set to INFO") console.rule() diff --git a/codeflash/code_utils/checkpoint.py b/codeflash/code_utils/checkpoint.py index 8a333c3fe..4c69ecc58 100644 --- a/codeflash/code_utils/checkpoint.py +++ b/codeflash/code_utils/checkpoint.py @@ -47,6 +47,7 @@ def add_function_to_checkpoint( """Add a function to the checkpoint after it has been processed. Args: + ---- function_fully_qualified_name: The fully qualified name of the function status: Status of optimization (e.g., "optimized", "failed", "skipped") additional_info: Any additional information to store about the function @@ -104,7 +105,8 @@ def cleanup(self) -> None: def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dict[str, dict[str, str]]: """Get information about all processed functions, regardless of status. - Returns: + Returns + ------- Dictionary mapping function names to their processing information """ diff --git a/codeflash/code_utils/line_profile_utils.py b/codeflash/code_utils/line_profile_utils.py index 935e30356..498571578 100644 --- a/codeflash/code_utils/line_profile_utils.py +++ b/codeflash/code_utils/line_profile_utils.py @@ -24,6 +24,7 @@ def __init__(self, qualified_name: str, decorator_name: str) -> None: """Initialize the transformer. Args: + ---- qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func"). decorator_name: The name of the decorator to add. @@ -144,11 +145,13 @@ def add_decorator_to_qualified_function(module: cst.Module, qualified_name: str, """Add a decorator to a function with the exact qualified name in the source code. Args: + ---- module: The Python source code as a CST module. qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func"). decorator_name: The name of the decorator to add. Returns: + ------- The modified CST module. """ diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index ba8929343..7910bcfe5 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -3,17 +3,17 @@ import os from collections import defaultdict from itertools import chain -from pathlib import Path # noqa: TC003 +from pathlib import Path from typing import TYPE_CHECKING import libcst as cst -from libcst import CSTNode # noqa: TC002 +from libcst import CSTNode from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names -from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001 +from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import ( CodeContextType, CodeOptimizationContext, @@ -150,6 +150,7 @@ def extract_code_string_context_from_files( imports, and combines them. Args: + ---- helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions project_root_path: Root path of the project @@ -157,6 +158,7 @@ def extract_code_string_context_from_files( code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN) Returns: + ------- CodeString containing the extracted code context with necessary imports """ # noqa: D205 @@ -257,6 +259,7 @@ def extract_code_markdown_context_from_files( imports, and combines them into a structured markdown format. Args: + ---- helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions project_root_path: Root path of the project @@ -264,6 +267,7 @@ def extract_code_markdown_context_from_files( code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN) Returns: + ------- CodeStringsMarkdown containing the extracted code context with necessary imports, formatted for inclusion in markdown @@ -502,7 +506,8 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911 ) -> tuple[cst.CSTNode | None, bool]: """Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions. - Returns: + Returns + ------- (filtered_node, found_target): filtered_node: The modified CST node or None if it should be removed. found_target: True if a target function was found in this node's subtree. @@ -586,7 +591,8 @@ def prune_cst_for_read_only_code( # noqa: PLR0911 ) -> tuple[cst.CSTNode | None, bool]: """Recursively filter the node for read-only context. - Returns: + Returns + ------- (filtered_node, found_target): filtered_node: The modified CST node or None if it should be removed. found_target: True if a target function was found in this node's subtree. @@ -690,7 +696,8 @@ def prune_cst_for_testgen_code( # noqa: PLR0911 ) -> tuple[cst.CSTNode | None, bool]: """Recursively filter the node for testgen context. - Returns: + Returns + ------- (filtered_node, found_target): filtered_node: The modified CST node or None if it should be removed. found_target: True if a target function was found in this node's subtree. diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 86835e128..53e249495 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -311,10 +311,12 @@ def remove_unused_definitions_recursively( # noqa: PLR0911 """Recursively filter the node to remove unused definitions. Args: + ---- node: The CST node to process definitions: Dictionary of definition info Returns: + ------- (filtered_node, used_by_function): filtered_node: The modified CST node or None if it should be removed used_by_function: True if this node or any child is used by qualified functions @@ -450,6 +452,7 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na If a class is referenced by a qualified function, we keep the entire class. Args: + ---- code: The code to process qualified_function_names: Set of function names to keep. For methods, use format 'classname.methodname' diff --git a/codeflash/picklepatch/pickle_patcher.py b/codeflash/picklepatch/pickle_patcher.py index 0e08756ab..3f3236f76 100644 --- a/codeflash/picklepatch/pickle_patcher.py +++ b/codeflash/picklepatch/pickle_patcher.py @@ -30,12 +30,14 @@ def dumps(obj: object, protocol: int | None = None, max_depth: int = 100, **kwar """Safely pickle an object, replacing unpicklable parts with placeholders. Args: + ---- obj: The object to pickle protocol: The pickle protocol version to use max_depth: Maximum recursion depth **kwargs: Additional arguments for pickle/dill.dumps Returns: + ------- bytes: Pickled data with placeholders for unpicklable objects """ @@ -46,9 +48,11 @@ def loads(pickled_data: bytes) -> object: """Unpickle data that may contain placeholders. Args: + ---- pickled_data: Pickled data with possible placeholders Returns: + ------- The unpickled object with placeholders for unpicklable parts """ @@ -59,11 +63,13 @@ def _create_placeholder(obj: object, error_msg: str, path: list[str]) -> PickleP """Create a placeholder for an unpicklable object. Args: + ---- obj: The original unpicklable object error_msg: Error message explaining why it couldn't be pickled path: Path to this object in the object graph Returns: + ------- PicklePlaceholder: A placeholder object """ @@ -91,12 +97,14 @@ def _pickle( """Try to pickle an object using pickle first, then dill. If both fail, create a placeholder. Args: + ---- obj: The object to pickle path: Path to this object in the object graph protocol: The pickle protocol version to use **kwargs: Additional arguments for pickle/dill.dumps Returns: + ------- tuple: (success, result) where success is a boolean and result is either: - Pickled bytes if successful - Error message if not successful @@ -123,6 +131,7 @@ def _recursive_pickle( # noqa: PLR0911 """Recursively try to pickle an object, replacing unpicklable parts with placeholders. Args: + ---- obj: The object to pickle max_depth: Maximum recursion depth path: Current path in the object graph @@ -130,6 +139,7 @@ def _recursive_pickle( # noqa: PLR0911 **kwargs: Additional arguments for pickle/dill.dumps Returns: + ------- bytes: Pickled data with placeholders for unpicklable objects """ @@ -185,6 +195,7 @@ def _handle_dict( """Handle pickling for dictionary objects. Args: + ---- obj_dict: The dictionary to pickle max_depth: Maximum recursion depth error_msg: Error message from the original pickling attempt @@ -193,6 +204,7 @@ def _handle_dict( **kwargs: Additional arguments for pickle/dill.dumps Returns: + ------- bytes: Pickled data with placeholders for unpicklable objects """ @@ -249,6 +261,7 @@ def _handle_sequence( """Handle pickling for sequence types (list, tuple, set). Args: + ---- obj_seq: The sequence to pickle max_depth: Maximum recursion depth error_msg: Error message from the original pickling attempt @@ -257,6 +270,7 @@ def _handle_sequence( **kwargs: Additional arguments for pickle/dill.dumps Returns: + ------- bytes: Pickled data with placeholders for unpicklable objects """ @@ -305,6 +319,7 @@ def _handle_object( """Handle pickling for custom objects with __dict__. Args: + ---- obj: The object to pickle max_depth: Maximum recursion depth error_msg: Error message from the original pickling attempt @@ -313,6 +328,7 @@ def _handle_object( **kwargs: Additional arguments for pickle/dill.dumps Returns: + ------- bytes: Pickled data with placeholders for unpicklable objects """ diff --git a/codeflash/picklepatch/pickle_placeholder.py b/codeflash/picklepatch/pickle_placeholder.py index 50e9c5aa3..4268a9146 100644 --- a/codeflash/picklepatch/pickle_placeholder.py +++ b/codeflash/picklepatch/pickle_placeholder.py @@ -18,6 +18,7 @@ def __init__(self, obj_type: str, obj_str: str, error_msg: str, path: list[str] """Initialize a placeholder for an unpicklable object. Args: + ---- obj_type (str): The type name of the original object obj_str (str): String representation of the original object error_msg (str): The error message that occurred during pickling diff --git a/codeflash/tracing/profile_stats.py b/codeflash/tracing/profile_stats.py index c2ed7cb49..4ec19637a 100644 --- a/codeflash/tracing/profile_stats.py +++ b/codeflash/tracing/profile_stats.py @@ -55,7 +55,7 @@ def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002 print(indent, self.total_calls, "function calls", end=" ", file=self.stream) if self.total_calls != self.prim_calls: - print("(%d primitive calls)" % self.prim_calls, end=" ", file=self.stream) # noqa: UP031 + print("(%d primitive calls)" % self.prim_calls, end=" ", file=self.stream) time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit] print(f"in {self.total_tt:.3f} {time_unit}", file=self.stream) print(file=self.stream) diff --git a/codeflash/verification/parse_line_profile_test_output.py b/codeflash/verification/parse_line_profile_test_output.py index 1877c0654..33a109fae 100644 --- a/codeflash/verification/parse_line_profile_test_output.py +++ b/codeflash/verification/parse_line_profile_test_output.py @@ -46,7 +46,7 @@ def show_func( perhit_disp = "%5.1f" % (float(time) * scalar / nhits) if len(perhit_disp) > default_column_sizes["perhit"]: perhit_disp = "%5.1g" % (float(time) * scalar / nhits) - nhits_disp = "%d" % nhits # noqa: UP031 + nhits_disp = "%d" % nhits if len(nhits_disp) > default_column_sizes["hits"]: nhits_disp = f"{nhits:g}" display[lineno] = (nhits_disp, time_disp, perhit_disp, percent) diff --git a/tests/test_add_runtime_comments.py b/tests/test_add_runtime_comments.py index d392d140e..51c1ef052 100644 --- a/tests/test_add_runtime_comments.py +++ b/tests/test_add_runtime_comments.py @@ -1,7 +1,6 @@ """Tests for the add_runtime_comments_to_generated_tests functionality.""" from pathlib import Path -from unittest.mock import Mock from codeflash.code_utils.edit_generated_tests import add_runtime_comments_to_generated_tests from codeflash.models.models import ( @@ -13,21 +12,11 @@ TestType, VerificationType, ) -from codeflash.optimization.function_optimizer import FunctionOptimizer class TestAddRuntimeComments: """Test cases for add_runtime_comments_to_generated_tests method.""" - def setup_method(self): - """Set up test fixtures.""" - # Create a mock FunctionOptimizer with minimal required attributes - self.optimizer = Mock(spec=FunctionOptimizer) - # We need to use the real implementation of the method - self.optimizer.add_runtime_comments_to_generated_tests = add_runtime_comments_to_generated_tests.__get__( - self.optimizer, FunctionOptimizer - ) - def create_test_invocation( self, test_function_name: str, runtime: int, loop_index: int = 1, iteration_id: str = "1", did_pass: bool = True ) -> FunctionTestInvocation: @@ -81,9 +70,7 @@ def test_basic_runtime_comment_addition(self): optimized_test_results.add(optimized_invocation) # Test the functionality - result = self.optimizer.add_runtime_comments_to_generated_tests( - generated_tests, original_test_results, optimized_test_results - ) + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) # Check that comments were added modified_source = result.generated_tests[0].generated_original_test_source @@ -126,9 +113,7 @@ def helper_function(): optimized_test_results.add(self.create_test_invocation("test_quick_sort", 600_000)) # Test the functionality - result = self.optimizer.add_runtime_comments_to_generated_tests( - generated_tests, original_test_results, optimized_test_results - ) + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) modified_source = result.generated_tests[0].generated_original_test_source @@ -176,7 +161,7 @@ def test_different_time_formats(self): optimized_test_results.add(self.create_test_invocation("test_function", optimized_time)) # Test the functionality - result = self.optimizer.add_runtime_comments_to_generated_tests( + result = add_runtime_comments_to_generated_tests( generated_tests, original_test_results, optimized_test_results ) @@ -205,9 +190,7 @@ def test_missing_test_results(self): optimized_test_results = TestResults() # Test the functionality - result = self.optimizer.add_runtime_comments_to_generated_tests( - generated_tests, original_test_results, optimized_test_results - ) + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) # Check that no comments were added modified_source = result.generated_tests[0].generated_original_test_source @@ -238,9 +221,7 @@ def test_partial_test_results(self): # No optimized results # Test the functionality - result = self.optimizer.add_runtime_comments_to_generated_tests( - generated_tests, original_test_results, optimized_test_results - ) + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) # Check that no comments were added modified_source = result.generated_tests[0].generated_original_test_source @@ -277,9 +258,7 @@ def test_multiple_runtimes_uses_minimum(self): optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 320_000, loop_index=3)) # Test the functionality - result = self.optimizer.add_runtime_comments_to_generated_tests( - generated_tests, original_test_results, optimized_test_results - ) + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) # Check that minimum times were used (500μs -> 300μs) modified_source = result.generated_tests[0].generated_original_test_source @@ -310,9 +289,7 @@ def test_no_codeflash_output_assignment(self): optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) # Test the functionality - result = self.optimizer.add_runtime_comments_to_generated_tests( - generated_tests, original_test_results, optimized_test_results - ) + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) # Check that no comments were added (no codeflash_output assignment) modified_source = result.generated_tests[0].generated_original_test_source @@ -343,9 +320,7 @@ def test_invalid_python_code_handling(self): optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) # Test the functionality - should handle parse error gracefully - result = self.optimizer.add_runtime_comments_to_generated_tests( - generated_tests, original_test_results, optimized_test_results - ) + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) # Check that original test is preserved when parsing fails modified_source = result.generated_tests[0].generated_original_test_source @@ -392,9 +367,7 @@ def test_multiple_generated_tests(self): optimized_test_results.add(self.create_test_invocation("test_quick_sort", 600_000)) # Test the functionality - result = self.optimizer.add_runtime_comments_to_generated_tests( - generated_tests, original_test_results, optimized_test_results - ) + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) # Check that comments were added to both test files modified_source_1 = result.generated_tests[0].generated_original_test_source @@ -433,9 +406,7 @@ def test_preserved_test_attributes(self): optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) # Test the functionality - result = self.optimizer.add_runtime_comments_to_generated_tests( - generated_tests, original_test_results, optimized_test_results - ) + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) # Check that other attributes are preserved modified_test = result.generated_tests[0] @@ -475,9 +446,7 @@ def test_multistatement_line_handling(self): optimized_test_results.add(self.create_test_invocation("test_mutation_of_input", 14_000)) # 14μs # Test the functionality - result = self.optimizer.add_runtime_comments_to_generated_tests( - generated_tests, original_test_results, optimized_test_results - ) + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) # Check that comments were added to the correct line modified_source = result.generated_tests[0].generated_original_test_source From b94204015e36c3b02243341efcb7a83dab0efaef Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Thu, 5 Jun 2025 20:08:09 -0700 Subject: [PATCH 5/9] rename file name --- mypy_allowlist.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy_allowlist.txt b/mypy_allowlist.txt index 6abaa8894..6a070b606 100644 --- a/mypy_allowlist.txt +++ b/mypy_allowlist.txt @@ -29,7 +29,7 @@ codeflash/code_utils/time_utils.py codeflash/code_utils/env_utils.py codeflash/code_utils/config_consts.py codeflash/code_utils/static_analysis.py -codeflash/code_utils/remove_generated_tests.py +codeflash/code_utils/edit_generated_tests.py codeflash/cli_cmds/console_constants.py codeflash/cli_cmds/logging_config.py codeflash/cli_cmds/__init__.py From bde6e06ad52a221d8ff755c13374b1825dbd2b36 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Thu, 5 Jun 2025 20:10:41 -0700 Subject: [PATCH 6/9] pre-commit run --- codeflash/code_utils/edit_generated_tests.py | 21 ++++++------ codeflash/code_utils/time_utils.py | 36 ++++++++++++-------- codeflash/context/code_context_extractor.py | 2 +- pyproject.toml | 3 +- 4 files changed, 35 insertions(+), 27 deletions(-) diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py index 94ac411f2..d44e5d885 100644 --- a/codeflash/code_utils/edit_generated_tests.py +++ b/codeflash/code_utils/edit_generated_tests.py @@ -41,7 +41,7 @@ def add_runtime_comments_to_generated_tests( optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case() class RuntimeCommentTransformer(cst.CSTTransformer): - def __init__(self): + def __init__(self) -> None: self.in_test_function = False self.current_test_name = None @@ -60,7 +60,9 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu return updated_node def leave_SimpleStatementLine( - self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine + self, + original_node: cst.SimpleStatementLine, # noqa: ARG002 + updated_node: cst.SimpleStatementLine, ) -> cst.SimpleStatementLine: if not self.in_test_function or not self.current_test_name: return updated_node @@ -69,14 +71,13 @@ def leave_SimpleStatementLine( # Handle both single statements and multiple statements on one line codeflash_assignment_found = False for stmt in updated_node.body: - if isinstance(stmt, cst.Assign): - if ( - len(stmt.targets) == 1 - and isinstance(stmt.targets[0].target, cst.Name) - and stmt.targets[0].target.value == "codeflash_output" - ): - codeflash_assignment_found = True - break + if isinstance(stmt, cst.Assign) and ( + len(stmt.targets) == 1 + and isinstance(stmt.targets[0].target, cst.Name) + and stmt.targets[0].target.value == "codeflash_output" + ): + codeflash_assignment_found = True + break if codeflash_assignment_found: # Find matching test cases by looking for this test function name in the test results diff --git a/codeflash/code_utils/time_utils.py b/codeflash/code_utils/time_utils.py index 936b1b0af..7528441a9 100644 --- a/codeflash/code_utils/time_utils.py +++ b/codeflash/code_utils/time_utils.py @@ -66,25 +66,31 @@ def format_with_precision(value: float, unit: str) -> str: return f"{value:.1f}{unit}" return f"{value:.2f}{unit}" + result = "" if nanoseconds < 1_000: - return f"{nanoseconds}ns" - if nanoseconds < 1_000_000: + result = f"{nanoseconds}ns" + elif nanoseconds < 1_000_000: # Convert to microseconds microseconds_int = nanoseconds // 1_000 if count_significant_digits(microseconds_int) >= 3: - return f"{microseconds_int}μs" - microseconds_float = nanoseconds / 1_000 - return format_with_precision(microseconds_float, "μs") - if nanoseconds < 1_000_000_000: + result = f"{microseconds_int}μs" + else: + microseconds_float = nanoseconds / 1_000 + result = format_with_precision(microseconds_float, "μs") + elif nanoseconds < 1_000_000_000: # Convert to milliseconds milliseconds_int = nanoseconds // 1_000_000 if count_significant_digits(milliseconds_int) >= 3: - return f"{milliseconds_int}ms" - milliseconds_float = nanoseconds / 1_000_000 - return format_with_precision(milliseconds_float, "ms") - # Convert to seconds - seconds_int = nanoseconds // 1_000_000_000 - if count_significant_digits(seconds_int) >= 3: - return f"{seconds_int}s" - seconds_float = nanoseconds / 1_000_000_000 - return format_with_precision(seconds_float, "s") + result = f"{milliseconds_int}ms" + else: + milliseconds_float = nanoseconds / 1_000_000 + result = format_with_precision(milliseconds_float, "ms") + else: + # Convert to seconds + seconds_int = nanoseconds // 1_000_000_000 + if count_significant_digits(seconds_int) >= 3: + result = f"{seconds_int}s" + else: + seconds_float = nanoseconds / 1_000_000_000 + result = format_with_precision(seconds_float, "s") + return result diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 7910bcfe5..c05fba207 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -386,7 +386,7 @@ def get_function_to_optimize_as_function_source( source_code=name.get_line_code(), jedi_definition=name, ) - except Exception as e: # noqa: PERF203 + except Exception as e: logger.exception(f"Error while getting function source: {e}") continue raise ValueError( diff --git a/pyproject.toml b/pyproject.toml index c3e48f889..8f96fb6c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -191,7 +191,8 @@ ignore = [ "T201", "PGH004", "S301", - "D104" + "D104", + "PERF203" ] [tool.ruff.lint.flake8-type-checking] From 4fe83b691d6dcc98fa90a0bac02c2007d16aaf68 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Thu, 5 Jun 2025 20:16:06 -0700 Subject: [PATCH 7/9] mypy fixes --- codeflash/code_utils/edit_generated_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py index d44e5d885..4e6e31072 100644 --- a/codeflash/code_utils/edit_generated_tests.py +++ b/codeflash/code_utils/edit_generated_tests.py @@ -43,7 +43,7 @@ def add_runtime_comments_to_generated_tests( class RuntimeCommentTransformer(cst.CSTTransformer): def __init__(self) -> None: self.in_test_function = False - self.current_test_name = None + self.current_test_name: str | None = None def visit_FunctionDef(self, node: cst.FunctionDef) -> None: if node.name.value.startswith("test_"): From e6272e8fbf55682df6a9f4fbd5db2f843d317004 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Thu, 5 Jun 2025 20:29:03 -0700 Subject: [PATCH 8/9] more pre-commit fixes --- codeflash/cli_cmds/cmd_init.py | 2 +- codeflash/context/code_context_extractor.py | 7 ++++--- codeflash/tracing/profile_stats.py | 2 +- codeflash/verification/parse_line_profile_test_output.py | 2 +- pyproject.toml | 3 ++- 5 files changed, 9 insertions(+), 7 deletions(-) diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index b8a038a82..bfe600fa4 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -34,7 +34,7 @@ from argparse import Namespace CODEFLASH_LOGO: str = ( - f"{LF}" + f"{LF}" # noqa: ISC003 r" _ ___ _ _ " + f"{LF}" r" | | / __)| | | | " + f"{LF}" r" ____ ___ _ | | ____ | |__ | | ____ ___ | | _ " + f"{LF}" diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index c05fba207..934d3053b 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -3,17 +3,15 @@ import os from collections import defaultdict from itertools import chain -from pathlib import Path from typing import TYPE_CHECKING import libcst as cst -from libcst import CSTNode from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names -from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001 from codeflash.models.models import ( CodeContextType, CodeOptimizationContext, @@ -24,7 +22,10 @@ from codeflash.optimization.function_context import belongs_to_function_qualified if TYPE_CHECKING: + from pathlib import Path + from jedi.api.classes import Name + from libcst import CSTNode def get_code_optimization_context( diff --git a/codeflash/tracing/profile_stats.py b/codeflash/tracing/profile_stats.py index 4ec19637a..8e2fc5e28 100644 --- a/codeflash/tracing/profile_stats.py +++ b/codeflash/tracing/profile_stats.py @@ -55,7 +55,7 @@ def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002 print(indent, self.total_calls, "function calls", end=" ", file=self.stream) if self.total_calls != self.prim_calls: - print("(%d primitive calls)" % self.prim_calls, end=" ", file=self.stream) + print(f"({self.prim_calls:d} primitive calls)", end=" ", file=self.stream) time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit] print(f"in {self.total_tt:.3f} {time_unit}", file=self.stream) print(file=self.stream) diff --git a/codeflash/verification/parse_line_profile_test_output.py b/codeflash/verification/parse_line_profile_test_output.py index 33a109fae..1877c0654 100644 --- a/codeflash/verification/parse_line_profile_test_output.py +++ b/codeflash/verification/parse_line_profile_test_output.py @@ -46,7 +46,7 @@ def show_func( perhit_disp = "%5.1f" % (float(time) * scalar / nhits) if len(perhit_disp) > default_column_sizes["perhit"]: perhit_disp = "%5.1g" % (float(time) * scalar / nhits) - nhits_disp = "%d" % nhits + nhits_disp = "%d" % nhits # noqa: UP031 if len(nhits_disp) > default_column_sizes["hits"]: nhits_disp = f"{nhits:g}" display[lineno] = (nhits_disp, time_disp, perhit_disp, percent) diff --git a/pyproject.toml b/pyproject.toml index 8f96fb6c9..cb8f2c7d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -192,7 +192,8 @@ ignore = [ "PGH004", "S301", "D104", - "PERF203" + "PERF203", + "LOG015" ] [tool.ruff.lint.flake8-type-checking] From 2bffc7ff57bee01c52afd00ce1d74656e8859e06 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Thu, 5 Jun 2025 22:52:57 -0700 Subject: [PATCH 9/9] Update codeflash/code_utils/time_utils.py Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com> --- codeflash/code_utils/time_utils.py | 68 +++++++++++++----------------- 1 file changed, 30 insertions(+), 38 deletions(-) diff --git a/codeflash/code_utils/time_utils.py b/codeflash/code_utils/time_utils.py index 7528441a9..4e43e7239 100644 --- a/codeflash/code_utils/time_utils.py +++ b/codeflash/code_utils/time_utils.py @@ -53,44 +53,36 @@ def humanize_runtime(time_in_ns: int) -> str: def format_time(nanoseconds: int) -> str: """Format nanoseconds into a human-readable string with 3 significant digits when needed.""" - - def count_significant_digits(num: int) -> int: - """Count significant digits in an integer.""" - return len(str(abs(num))) - - def format_with_precision(value: float, unit: str) -> str: - """Format a value with 3 significant digits precision.""" - if value >= 100: - return f"{value:.0f}{unit}" - if value >= 10: - return f"{value:.1f}{unit}" - return f"{value:.2f}{unit}" - - result = "" + # Inlined significant digit check: >= 3 digits if value >= 100 if nanoseconds < 1_000: - result = f"{nanoseconds}ns" - elif nanoseconds < 1_000_000: - # Convert to microseconds + return f"{nanoseconds}ns" + if nanoseconds < 1_000_000: microseconds_int = nanoseconds // 1_000 - if count_significant_digits(microseconds_int) >= 3: - result = f"{microseconds_int}μs" - else: - microseconds_float = nanoseconds / 1_000 - result = format_with_precision(microseconds_float, "μs") - elif nanoseconds < 1_000_000_000: - # Convert to milliseconds + if microseconds_int >= 100: + return f"{microseconds_int}μs" + microseconds = nanoseconds / 1_000 + # Format with precision: 3 significant digits + if microseconds >= 100: + return f"{microseconds:.0f}μs" + if microseconds >= 10: + return f"{microseconds:.1f}μs" + return f"{microseconds:.2f}μs" + if nanoseconds < 1_000_000_000: milliseconds_int = nanoseconds // 1_000_000 - if count_significant_digits(milliseconds_int) >= 3: - result = f"{milliseconds_int}ms" - else: - milliseconds_float = nanoseconds / 1_000_000 - result = format_with_precision(milliseconds_float, "ms") - else: - # Convert to seconds - seconds_int = nanoseconds // 1_000_000_000 - if count_significant_digits(seconds_int) >= 3: - result = f"{seconds_int}s" - else: - seconds_float = nanoseconds / 1_000_000_000 - result = format_with_precision(seconds_float, "s") - return result + if milliseconds_int >= 100: + return f"{milliseconds_int}ms" + milliseconds = nanoseconds / 1_000_000 + if milliseconds >= 100: + return f"{milliseconds:.0f}ms" + if milliseconds >= 10: + return f"{milliseconds:.1f}ms" + return f"{milliseconds:.2f}ms" + seconds_int = nanoseconds // 1_000_000_000 + if seconds_int >= 100: + return f"{seconds_int}s" + seconds = nanoseconds / 1_000_000_000 + if seconds >= 100: + return f"{seconds:.0f}s" + if seconds >= 10: + return f"{seconds:.1f}s" + return f"{seconds:.2f}s"