diff --git a/codeflash/code_utils/time_utils.py b/codeflash/code_utils/time_utils.py index 4e43e7239..89273fe2d 100644 --- a/codeflash/code_utils/time_utils.py +++ b/codeflash/code_utils/time_utils.py @@ -53,36 +53,35 @@ def humanize_runtime(time_in_ns: int) -> str: def format_time(nanoseconds: int) -> str: """Format nanoseconds into a human-readable string with 3 significant digits when needed.""" - # Inlined significant digit check: >= 3 digits if value >= 100 + # Define conversion factors and units + if not isinstance(nanoseconds, int): + raise TypeError("Input must be an integer.") + if nanoseconds < 0: + raise ValueError("Input must be a positive integer.") + conversions = [(1_000_000_000, "s"), (1_000_000, "ms"), (1_000, "μs"), (1, "ns")] + + # Handle nanoseconds case directly (no decimal formatting needed) if nanoseconds < 1_000: return f"{nanoseconds}ns" - if nanoseconds < 1_000_000: - microseconds_int = nanoseconds // 1_000 - if microseconds_int >= 100: - return f"{microseconds_int}μs" - microseconds = nanoseconds / 1_000 - # Format with precision: 3 significant digits - if microseconds >= 100: - return f"{microseconds:.0f}μs" - if microseconds >= 10: - return f"{microseconds:.1f}μs" - return f"{microseconds:.2f}μs" - if nanoseconds < 1_000_000_000: - milliseconds_int = nanoseconds // 1_000_000 - if milliseconds_int >= 100: - return f"{milliseconds_int}ms" - milliseconds = nanoseconds / 1_000_000 - if milliseconds >= 100: - return f"{milliseconds:.0f}ms" - if milliseconds >= 10: - return f"{milliseconds:.1f}ms" - return f"{milliseconds:.2f}ms" - seconds_int = nanoseconds // 1_000_000_000 - if seconds_int >= 100: - return f"{seconds_int}s" - seconds = nanoseconds / 1_000_000_000 - if seconds >= 100: - return f"{seconds:.0f}s" - if seconds >= 10: - return f"{seconds:.1f}s" - return f"{seconds:.2f}s" + + # Find appropriate unit + for divisor, unit in conversions: + if nanoseconds >= divisor: + value = nanoseconds / divisor + int_value = nanoseconds // divisor + + # Use integer formatting for values >= 100 + if int_value >= 100: + formatted_value = f"{int_value:.0f}" + # Format with precision for 3 significant digits + elif value >= 100: + formatted_value = f"{value:.0f}" + elif value >= 10: + formatted_value = f"{value:.1f}" + else: + formatted_value = f"{value:.2f}" + + return f"{formatted_value}{unit}" + + # This should never be reached, but included for completeness + return f"{nanoseconds}ns" diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 0de33fb9e..977c72bd3 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -3,14 +3,20 @@ import ast from collections import defaultdict from dataclasses import dataclass, field -from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path +from typing import TYPE_CHECKING, Optional import libcst as cst from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_replacer import replace_function_definitions_in_module -from codeflash.models.models import CodeOptimizationContext, FunctionSource + +if TYPE_CHECKING: + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.models.models import CodeOptimizationContext, FunctionSource @dataclass @@ -493,11 +499,12 @@ def print_definitions(definitions: dict[str, UsageInfo]) -> None: def revert_unused_helper_functions( - project_root, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str] + project_root: Path, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str] ) -> None: """Revert unused helper functions back to their original definitions. Args: + project_root: project_root unused_helpers: List of unused helper functions to revert original_helper_code: Dictionary mapping file paths to their original code @@ -516,9 +523,6 @@ def revert_unused_helper_functions( for file_path, helpers_in_file in unused_helpers_by_file.items(): if file_path in original_helper_code: try: - # Read current file content - current_code = file_path.read_text(encoding="utf8") - # Get original code for this file original_code = original_helper_code[file_path] @@ -557,7 +561,6 @@ def _analyze_imports_in_optimized_code( # Precompute a two-level dict: module_name -> func_name -> [helpers] helpers_by_file_and_func = defaultdict(dict) helpers_by_file = defaultdict(list) # preserved for "import module" - helpers_append = helpers_by_file_and_func.setdefault for helper in code_context.helper_functions: jedi_type = helper.jedi_definition.type if jedi_type != "class": @@ -606,11 +609,12 @@ def _analyze_imports_in_optimized_code( def detect_unused_helper_functions( - function_to_optimize, code_context: CodeOptimizationContext, optimized_code: str + function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext, optimized_code: str ) -> list[FunctionSource]: """Detect helper functions that are no longer called by the optimized entrypoint function. Args: + function_to_optimize: The function to optimize code_context: The code optimization context containing helper functions optimized_code: The optimized code to analyze @@ -702,8 +706,9 @@ def detect_unused_helper_functions( logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code") logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}") - return unused_helpers + ret_val = unused_helpers except Exception as e: logger.debug(f"Error detecting unused helper functions: {e}") - return [] + ret_val = [] + return ret_val diff --git a/tests/test_humanize_time.py b/tests/test_humanize_time.py index 608c66b10..4021b077e 100644 --- a/tests/test_humanize_time.py +++ b/tests/test_humanize_time.py @@ -1,4 +1,5 @@ -from codeflash.code_utils.time_utils import humanize_runtime +from codeflash.code_utils.time_utils import humanize_runtime, format_time +import pytest def test_humanize_runtime(): @@ -28,3 +29,147 @@ def test_humanize_runtime(): assert humanize_runtime(12345678912345) == "3.43 hours" assert humanize_runtime(98765431298760) == "1.14 days" assert humanize_runtime(197530862597520) == "2.29 days" + + +class TestFormatTime: + """Test cases for the format_time function.""" + + def test_nanoseconds_range(self): + """Test formatting for nanoseconds (< 1,000 ns).""" + assert format_time(0) == "0ns" + assert format_time(1) == "1ns" + assert format_time(500) == "500ns" + assert format_time(999) == "999ns" + + def test_microseconds_range(self): + """Test formatting for microseconds (1,000 ns to 999,999 ns).""" + # Integer microseconds >= 100 + # assert format_time(100_000) == "100μs" + # assert format_time(500_000) == "500μs" + # assert format_time(999_000) == "999μs" + + # Decimal microseconds with varying precision + assert format_time(1_000) == "1.00μs" # 1.0 μs, 2 decimal places + assert format_time(1_500) == "1.50μs" # 1.5 μs, 2 decimal places + assert format_time(9_999) == "10.00μs" # 9.999 μs rounds to 10.00 + assert format_time(10_000) == "10.0μs" # 10.0 μs, 1 decimal place + assert format_time(15_500) == "15.5μs" # 15.5 μs, 1 decimal place + assert format_time(99_900) == "99.9μs" # 99.9 μs, 1 decimal place + + def test_milliseconds_range(self): + """Test formatting for milliseconds (1,000,000 ns to 999,999,999 ns).""" + # Integer milliseconds >= 100 + assert format_time(100_000_000) == "100ms" + assert format_time(500_000_000) == "500ms" + assert format_time(999_000_000) == "999ms" + + # Decimal milliseconds with varying precision + assert format_time(1_000_000) == "1.00ms" # 1.0 ms, 2 decimal places + assert format_time(1_500_000) == "1.50ms" # 1.5 ms, 2 decimal places + assert format_time(9_999_000) == "10.00ms" # 9.999 ms rounds to 10.00 + assert format_time(10_000_000) == "10.0ms" # 10.0 ms, 1 decimal place + assert format_time(15_500_000) == "15.5ms" # 15.5 ms, 1 decimal place + assert format_time(99_900_000) == "99.9ms" # 99.9 ms, 1 decimal place + + def test_seconds_range(self): + """Test formatting for seconds (>= 1,000,000,000 ns).""" + # Integer seconds >= 100 + assert format_time(100_000_000_000) == "100s" + assert format_time(500_000_000_000) == "500s" + assert format_time(999_000_000_000) == "999s" + + # Decimal seconds with varying precision + assert format_time(1_000_000_000) == "1.00s" # 1.0 s, 2 decimal places + assert format_time(1_500_000_000) == "1.50s" # 1.5 s, 2 decimal places + assert format_time(9_999_000_000) == "10.00s" # 9.999 s rounds to 10.00 + assert format_time(10_000_000_000) == "10.0s" # 10.0 s, 1 decimal place + assert format_time(15_500_000_000) == "15.5s" # 15.5 s, 1 decimal place + assert format_time(99_900_000_000) == "99.9s" # 99.9 s, 1 decimal place + + def test_boundary_values(self): + """Test exact boundary values between units.""" + # Boundaries between nanoseconds and microseconds + assert format_time(999) == "999ns" + assert format_time(1_000) == "1.00μs" + + # Boundaries between microseconds and milliseconds + assert format_time(999_999) == "999μs" # This might round to 1000.00μs + assert format_time(1_000_000) == "1.00ms" + + # Boundaries between milliseconds and seconds + assert format_time(999_999_999) == "999ms" # This might round to 1000.00ms + assert format_time(1_000_000_000) == "1.00s" + + def test_precision_boundaries(self): + """Test precision changes at significant digit boundaries.""" + # Microseconds precision changes + assert format_time(9_950) == "9.95μs" # 2 decimal places + assert format_time(10_000) == "10.0μs" # 1 decimal place + assert format_time(99_900) == "99.9μs" # 1 decimal place + assert format_time(100_000) == "100μs" # No decimal places + + # Milliseconds precision changes + assert format_time(9_950_000) == "9.95ms" # 2 decimal places + assert format_time(10_000_000) == "10.0ms" # 1 decimal place + assert format_time(99_900_000) == "99.9ms" # 1 decimal place + assert format_time(100_000_000) == "100ms" # No decimal places + + # Seconds precision changes + assert format_time(9_950_000_000) == "9.95s" # 2 decimal places + assert format_time(10_000_000_000) == "10.0s" # 1 decimal place + assert format_time(99_900_000_000) == "99.9s" # 1 decimal place + assert format_time(100_000_000_000) == "100s" # No decimal places + + def test_rounding_behavior(self): + """Test rounding behavior for edge cases.""" + # Test rounding in microseconds + assert format_time(1_234) == "1.23μs" + assert format_time(1_235) == "1.24μs" # Should round up + assert format_time(12_345) == "12.3μs" + assert format_time(12_350) == "12.3μs" # Should round up + + # Test rounding in milliseconds + assert format_time(1_234_000) == "1.23ms" + assert format_time(1_235_000) == "1.24ms" # Should round up + assert format_time(12_345_000) == "12.3ms" + assert format_time(12_350_000) == "12.3ms" # Should round up + + def test_large_values(self): + """Test very large nanosecond values.""" + assert format_time(3_600_000_000_000) == "3600s" # 1 hour + assert format_time(86_400_000_000_000) == "86400s" # 1 day + + @pytest.mark.parametrize("nanoseconds,expected", [ + (0, "0ns"), + (42, "42ns"), + (1_500, "1.50μs"), + (25_000, "25.0μs"), + (150_000, "150μs"), + (2_500_000, "2.50ms"), + (45_000_000, "45.0ms"), + (200_000_000, "200ms"), + (3_500_000_000, "3.50s"), + (75_000_000_000, "75.0s"), + (300_000_000_000, "300s"), + ]) + def test_parametrized_examples(self, nanoseconds, expected): + """Parametrized test with various input/output combinations.""" + assert format_time(nanoseconds) == expected + + def test_invalid_input_types(self): + """Test that function handles invalid input types appropriately.""" + with pytest.raises(TypeError): + format_time("1000") + + with pytest.raises(TypeError): + format_time(1000.5) + + with pytest.raises(TypeError): + format_time(None) + + def test_negative_values(self): + """Test behavior with negative values (if applicable).""" + # This test depends on whether your function should handle negative values + # You might want to modify based on expected behavior + with pytest.raises((ValueError, TypeError)) or pytest.warns(): + format_time(-1000) \ No newline at end of file