Skip to content

Commit e3a84bc

Browse files
authored
Merge branch 'main' into mp-test-processing
2 parents 6f3f7af + 4de9323 commit e3a84bc

File tree

3 files changed

+192
-43
lines changed

3 files changed

+192
-43
lines changed

codeflash/code_utils/time_utils.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -53,36 +53,35 @@ def humanize_runtime(time_in_ns: int) -> str:
5353

5454
def format_time(nanoseconds: int) -> str:
5555
"""Format nanoseconds into a human-readable string with 3 significant digits when needed."""
56-
# Inlined significant digit check: >= 3 digits if value >= 100
56+
# Define conversion factors and units
57+
if not isinstance(nanoseconds, int):
58+
raise TypeError("Input must be an integer.")
59+
if nanoseconds < 0:
60+
raise ValueError("Input must be a positive integer.")
61+
conversions = [(1_000_000_000, "s"), (1_000_000, "ms"), (1_000, "μs"), (1, "ns")]
62+
63+
# Handle nanoseconds case directly (no decimal formatting needed)
5764
if nanoseconds < 1_000:
5865
return f"{nanoseconds}ns"
59-
if nanoseconds < 1_000_000:
60-
microseconds_int = nanoseconds // 1_000
61-
if microseconds_int >= 100:
62-
return f"{microseconds_int}μs"
63-
microseconds = nanoseconds / 1_000
64-
# Format with precision: 3 significant digits
65-
if microseconds >= 100:
66-
return f"{microseconds:.0f}μs"
67-
if microseconds >= 10:
68-
return f"{microseconds:.1f}μs"
69-
return f"{microseconds:.2f}μs"
70-
if nanoseconds < 1_000_000_000:
71-
milliseconds_int = nanoseconds // 1_000_000
72-
if milliseconds_int >= 100:
73-
return f"{milliseconds_int}ms"
74-
milliseconds = nanoseconds / 1_000_000
75-
if milliseconds >= 100:
76-
return f"{milliseconds:.0f}ms"
77-
if milliseconds >= 10:
78-
return f"{milliseconds:.1f}ms"
79-
return f"{milliseconds:.2f}ms"
80-
seconds_int = nanoseconds // 1_000_000_000
81-
if seconds_int >= 100:
82-
return f"{seconds_int}s"
83-
seconds = nanoseconds / 1_000_000_000
84-
if seconds >= 100:
85-
return f"{seconds:.0f}s"
86-
if seconds >= 10:
87-
return f"{seconds:.1f}s"
88-
return f"{seconds:.2f}s"
66+
67+
# Find appropriate unit
68+
for divisor, unit in conversions:
69+
if nanoseconds >= divisor:
70+
value = nanoseconds / divisor
71+
int_value = nanoseconds // divisor
72+
73+
# Use integer formatting for values >= 100
74+
if int_value >= 100:
75+
formatted_value = f"{int_value:.0f}"
76+
# Format with precision for 3 significant digits
77+
elif value >= 100:
78+
formatted_value = f"{value:.0f}"
79+
elif value >= 10:
80+
formatted_value = f"{value:.1f}"
81+
else:
82+
formatted_value = f"{value:.2f}"
83+
84+
return f"{formatted_value}{unit}"
85+
86+
# This should never be reached, but included for completeness
87+
return f"{nanoseconds}ns"

codeflash/context/unused_definition_remover.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,20 @@
33
import ast
44
from collections import defaultdict
55
from dataclasses import dataclass, field
6-
from pathlib import Path
7-
from typing import Optional
6+
from typing import TYPE_CHECKING
7+
8+
if TYPE_CHECKING:
9+
from pathlib import Path
10+
from typing import TYPE_CHECKING, Optional
811

912
import libcst as cst
1013

1114
from codeflash.cli_cmds.console import logger
1215
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
13-
from codeflash.models.models import CodeOptimizationContext, FunctionSource
16+
17+
if TYPE_CHECKING:
18+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
19+
from codeflash.models.models import CodeOptimizationContext, FunctionSource
1420

1521

1622
@dataclass
@@ -493,11 +499,12 @@ def print_definitions(definitions: dict[str, UsageInfo]) -> None:
493499

494500

495501
def revert_unused_helper_functions(
496-
project_root, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str]
502+
project_root: Path, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str]
497503
) -> None:
498504
"""Revert unused helper functions back to their original definitions.
499505
500506
Args:
507+
project_root: project_root
501508
unused_helpers: List of unused helper functions to revert
502509
original_helper_code: Dictionary mapping file paths to their original code
503510
@@ -516,9 +523,6 @@ def revert_unused_helper_functions(
516523
for file_path, helpers_in_file in unused_helpers_by_file.items():
517524
if file_path in original_helper_code:
518525
try:
519-
# Read current file content
520-
current_code = file_path.read_text(encoding="utf8")
521-
522526
# Get original code for this file
523527
original_code = original_helper_code[file_path]
524528

@@ -557,7 +561,6 @@ def _analyze_imports_in_optimized_code(
557561
# Precompute a two-level dict: module_name -> func_name -> [helpers]
558562
helpers_by_file_and_func = defaultdict(dict)
559563
helpers_by_file = defaultdict(list) # preserved for "import module"
560-
helpers_append = helpers_by_file_and_func.setdefault
561564
for helper in code_context.helper_functions:
562565
jedi_type = helper.jedi_definition.type
563566
if jedi_type != "class":
@@ -606,11 +609,12 @@ def _analyze_imports_in_optimized_code(
606609

607610

608611
def detect_unused_helper_functions(
609-
function_to_optimize, code_context: CodeOptimizationContext, optimized_code: str
612+
function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext, optimized_code: str
610613
) -> list[FunctionSource]:
611614
"""Detect helper functions that are no longer called by the optimized entrypoint function.
612615
613616
Args:
617+
function_to_optimize: The function to optimize
614618
code_context: The code optimization context containing helper functions
615619
optimized_code: The optimized code to analyze
616620
@@ -702,8 +706,9 @@ def detect_unused_helper_functions(
702706
logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
703707
logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}")
704708

705-
return unused_helpers
709+
ret_val = unused_helpers
706710

707711
except Exception as e:
708712
logger.debug(f"Error detecting unused helper functions: {e}")
709-
return []
713+
ret_val = []
714+
return ret_val

tests/test_humanize_time.py

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from codeflash.code_utils.time_utils import humanize_runtime
1+
from codeflash.code_utils.time_utils import humanize_runtime, format_time
2+
import pytest
23

34

45
def test_humanize_runtime():
@@ -28,3 +29,147 @@ def test_humanize_runtime():
2829
assert humanize_runtime(12345678912345) == "3.43 hours"
2930
assert humanize_runtime(98765431298760) == "1.14 days"
3031
assert humanize_runtime(197530862597520) == "2.29 days"
32+
33+
34+
class TestFormatTime:
35+
"""Test cases for the format_time function."""
36+
37+
def test_nanoseconds_range(self):
38+
"""Test formatting for nanoseconds (< 1,000 ns)."""
39+
assert format_time(0) == "0ns"
40+
assert format_time(1) == "1ns"
41+
assert format_time(500) == "500ns"
42+
assert format_time(999) == "999ns"
43+
44+
def test_microseconds_range(self):
45+
"""Test formatting for microseconds (1,000 ns to 999,999 ns)."""
46+
# Integer microseconds >= 100
47+
# assert format_time(100_000) == "100μs"
48+
# assert format_time(500_000) == "500μs"
49+
# assert format_time(999_000) == "999μs"
50+
51+
# Decimal microseconds with varying precision
52+
assert format_time(1_000) == "1.00μs" # 1.0 μs, 2 decimal places
53+
assert format_time(1_500) == "1.50μs" # 1.5 μs, 2 decimal places
54+
assert format_time(9_999) == "10.00μs" # 9.999 μs rounds to 10.00
55+
assert format_time(10_000) == "10.0μs" # 10.0 μs, 1 decimal place
56+
assert format_time(15_500) == "15.5μs" # 15.5 μs, 1 decimal place
57+
assert format_time(99_900) == "99.9μs" # 99.9 μs, 1 decimal place
58+
59+
def test_milliseconds_range(self):
60+
"""Test formatting for milliseconds (1,000,000 ns to 999,999,999 ns)."""
61+
# Integer milliseconds >= 100
62+
assert format_time(100_000_000) == "100ms"
63+
assert format_time(500_000_000) == "500ms"
64+
assert format_time(999_000_000) == "999ms"
65+
66+
# Decimal milliseconds with varying precision
67+
assert format_time(1_000_000) == "1.00ms" # 1.0 ms, 2 decimal places
68+
assert format_time(1_500_000) == "1.50ms" # 1.5 ms, 2 decimal places
69+
assert format_time(9_999_000) == "10.00ms" # 9.999 ms rounds to 10.00
70+
assert format_time(10_000_000) == "10.0ms" # 10.0 ms, 1 decimal place
71+
assert format_time(15_500_000) == "15.5ms" # 15.5 ms, 1 decimal place
72+
assert format_time(99_900_000) == "99.9ms" # 99.9 ms, 1 decimal place
73+
74+
def test_seconds_range(self):
75+
"""Test formatting for seconds (>= 1,000,000,000 ns)."""
76+
# Integer seconds >= 100
77+
assert format_time(100_000_000_000) == "100s"
78+
assert format_time(500_000_000_000) == "500s"
79+
assert format_time(999_000_000_000) == "999s"
80+
81+
# Decimal seconds with varying precision
82+
assert format_time(1_000_000_000) == "1.00s" # 1.0 s, 2 decimal places
83+
assert format_time(1_500_000_000) == "1.50s" # 1.5 s, 2 decimal places
84+
assert format_time(9_999_000_000) == "10.00s" # 9.999 s rounds to 10.00
85+
assert format_time(10_000_000_000) == "10.0s" # 10.0 s, 1 decimal place
86+
assert format_time(15_500_000_000) == "15.5s" # 15.5 s, 1 decimal place
87+
assert format_time(99_900_000_000) == "99.9s" # 99.9 s, 1 decimal place
88+
89+
def test_boundary_values(self):
90+
"""Test exact boundary values between units."""
91+
# Boundaries between nanoseconds and microseconds
92+
assert format_time(999) == "999ns"
93+
assert format_time(1_000) == "1.00μs"
94+
95+
# Boundaries between microseconds and milliseconds
96+
assert format_time(999_999) == "999μs" # This might round to 1000.00μs
97+
assert format_time(1_000_000) == "1.00ms"
98+
99+
# Boundaries between milliseconds and seconds
100+
assert format_time(999_999_999) == "999ms" # This might round to 1000.00ms
101+
assert format_time(1_000_000_000) == "1.00s"
102+
103+
def test_precision_boundaries(self):
104+
"""Test precision changes at significant digit boundaries."""
105+
# Microseconds precision changes
106+
assert format_time(9_950) == "9.95μs" # 2 decimal places
107+
assert format_time(10_000) == "10.0μs" # 1 decimal place
108+
assert format_time(99_900) == "99.9μs" # 1 decimal place
109+
assert format_time(100_000) == "100μs" # No decimal places
110+
111+
# Milliseconds precision changes
112+
assert format_time(9_950_000) == "9.95ms" # 2 decimal places
113+
assert format_time(10_000_000) == "10.0ms" # 1 decimal place
114+
assert format_time(99_900_000) == "99.9ms" # 1 decimal place
115+
assert format_time(100_000_000) == "100ms" # No decimal places
116+
117+
# Seconds precision changes
118+
assert format_time(9_950_000_000) == "9.95s" # 2 decimal places
119+
assert format_time(10_000_000_000) == "10.0s" # 1 decimal place
120+
assert format_time(99_900_000_000) == "99.9s" # 1 decimal place
121+
assert format_time(100_000_000_000) == "100s" # No decimal places
122+
123+
def test_rounding_behavior(self):
124+
"""Test rounding behavior for edge cases."""
125+
# Test rounding in microseconds
126+
assert format_time(1_234) == "1.23μs"
127+
assert format_time(1_235) == "1.24μs" # Should round up
128+
assert format_time(12_345) == "12.3μs"
129+
assert format_time(12_350) == "12.3μs" # Should round up
130+
131+
# Test rounding in milliseconds
132+
assert format_time(1_234_000) == "1.23ms"
133+
assert format_time(1_235_000) == "1.24ms" # Should round up
134+
assert format_time(12_345_000) == "12.3ms"
135+
assert format_time(12_350_000) == "12.3ms" # Should round up
136+
137+
def test_large_values(self):
138+
"""Test very large nanosecond values."""
139+
assert format_time(3_600_000_000_000) == "3600s" # 1 hour
140+
assert format_time(86_400_000_000_000) == "86400s" # 1 day
141+
142+
@pytest.mark.parametrize("nanoseconds,expected", [
143+
(0, "0ns"),
144+
(42, "42ns"),
145+
(1_500, "1.50μs"),
146+
(25_000, "25.0μs"),
147+
(150_000, "150μs"),
148+
(2_500_000, "2.50ms"),
149+
(45_000_000, "45.0ms"),
150+
(200_000_000, "200ms"),
151+
(3_500_000_000, "3.50s"),
152+
(75_000_000_000, "75.0s"),
153+
(300_000_000_000, "300s"),
154+
])
155+
def test_parametrized_examples(self, nanoseconds, expected):
156+
"""Parametrized test with various input/output combinations."""
157+
assert format_time(nanoseconds) == expected
158+
159+
def test_invalid_input_types(self):
160+
"""Test that function handles invalid input types appropriately."""
161+
with pytest.raises(TypeError):
162+
format_time("1000")
163+
164+
with pytest.raises(TypeError):
165+
format_time(1000.5)
166+
167+
with pytest.raises(TypeError):
168+
format_time(None)
169+
170+
def test_negative_values(self):
171+
"""Test behavior with negative values (if applicable)."""
172+
# This test depends on whether your function should handle negative values
173+
# You might want to modify based on expected behavior
174+
with pytest.raises((ValueError, TypeError)) or pytest.warns():
175+
format_time(-1000)

0 commit comments

Comments
 (0)