
Commit 1637810

works with multithreading, added test

1 parent d664040 · commit 1637810

12 files changed: +185 -63 lines changed
Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+# from code_to_optimize.bubble_sort_codeflash_trace import sorter
+from code_to_optimize.bubble_sort_codeflash_trace import sorter
+import concurrent.futures
+
+
+def multithreaded_sorter(unsorted_lists: list[list[int]]) -> list[list[int]]:
+    # Create a list to store results in the correct order
+    sorted_lists = [None] * len(unsorted_lists)
+
+    # Use ThreadPoolExecutor to manage threads
+    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+        # Submit all sorting tasks and map them to their original indices
+        future_to_index = {
+            executor.submit(sorter, unsorted_list): i
+            for i, unsorted_list in enumerate(unsorted_lists)
+        }
+
+        # Collect results as they complete
+        for future in concurrent.futures.as_completed(future_to_index):
+            index = future_to_index[future]
+            sorted_lists[index] = future.result()
+
+    return sorted_lists
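
The new helper fans each list out to a worker thread and stitches the results back into input order via the future_to_index map. A quick usage sketch, not part of the commit; it relies only on the fact (shown by the tests in this repo) that sorter returns the sorted list:

# Hypothetical usage sketch, not part of this commit.
inputs = [list(reversed(range(100))) for _ in range(8)]
results = multithreaded_sorter(inputs)
# Results come back in input order because each future is mapped to its index.
assert results == [sorted(lst) for lst in inputs]
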
Lines changed: 1 addition & 8 deletions

@@ -1,6 +1,6 @@
 import pytest

-from code_to_optimize.bubble_sort_codeflash_trace import sorter, Sorter
+from code_to_optimize.bubble_sort import sorter


 def test_sort(benchmark):
@@ -11,10 +11,3 @@ def test_sort(benchmark):
 def test_sort2():
     result = sorter(list(reversed(range(500))))
     assert result == list(range(500))
-
-def test_class_sort(benchmark):
-    obj = Sorter(list(reversed(range(100))))
-    result1 = benchmark(obj.sorter, 2)
-    result2 = benchmark(Sorter.sort_class, list(reversed(range(100))))
-    result3 = benchmark(Sorter.sort_static, list(reversed(range(100))))
-    result4 = benchmark(Sorter, [1,2,3])

code_to_optimize/tests/pytest/benchmarks/test_process_and_sort.py

Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
-from code_to_optimize.process_and_bubble_sort_codeflash_trace import compute_and_sort
-from code_to_optimize.bubble_sort_codeflash_trace import sorter
+from code_to_optimize.process_and_bubble_sort import compute_and_sort
+from code_to_optimize.bubble_sort import sorter
 def test_compute_and_sort(benchmark):
     result = benchmark(compute_and_sort, list(reversed(range(500))))
     assert result == 62208.5
Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+from code_to_optimize.bubble_sort_multithread import multithreaded_sorter
+
+def test_benchmark_sort(benchmark):
+    benchmark(multithreaded_sorter, [list(range(1000)) for i in range(10)])
Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+import pytest
+
+from code_to_optimize.bubble_sort_codeflash_trace import sorter, Sorter
+
+
+def test_sort(benchmark):
+    result = benchmark(sorter, list(reversed(range(500))))
+    assert result == list(range(500))
+
+# This should not be picked up as a benchmark test
+def test_sort2():
+    result = sorter(list(reversed(range(500))))
+    assert result == list(range(500))
+
+def test_class_sort(benchmark):
+    obj = Sorter(list(reversed(range(100))))
+    result1 = benchmark(obj.sorter, 2)
+    result2 = benchmark(Sorter.sort_class, list(reversed(range(100))))
+    result3 = benchmark(Sorter.sort_static, list(reversed(range(100))))
+    result4 = benchmark(Sorter, [1,2,3])
Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+from code_to_optimize.process_and_bubble_sort_codeflash_trace import compute_and_sort
+from code_to_optimize.bubble_sort_codeflash_trace import sorter
+def test_compute_and_sort(benchmark):
+    result = benchmark(compute_and_sort, list(reversed(range(500))))
+    assert result == 62208.5
+
+def test_no_func(benchmark):
+    benchmark(sorter, list(reversed(range(500))))

codeflash/benchmarking/codeflash_trace.py

Lines changed: 5 additions & 11 deletions

@@ -15,11 +15,6 @@ class CodeflashTrace:
     def __init__(self) -> None:
         self.function_calls_data = []

-    # def __enter__(self) -> None:
-    #     # Initialize for context manager use
-    #     self.function_calls_data = []
-    #     return self
-
     def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         # Cleanup is optional here
         pass
@@ -37,15 +32,14 @@ def __call__(self, func: Callable) -> Callable:
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
             # Measure execution time
-            start_time = time.perf_counter_ns()
+            start_time = time.thread_time_ns()
             result = func(*args, **kwargs)
-            end_time = time.perf_counter_ns()
-
+            end_time = time.thread_time_ns()
             # Calculate execution time
             execution_time = end_time - start_time

             # Measure overhead
-            overhead_start_time = time.perf_counter_ns()
+            overhead_start_time = time.thread_time_ns()

             try:
                 # Check if currently in pytest benchmark fixture
@@ -66,7 +60,7 @@ def wrapper(*args, **kwargs):
                 if "." in qualname:
                     class_name = qualname.split(".")[0]
                 # Calculate overhead time
-                overhead_end_time = time.perf_counter_ns()
+                overhead_end_time = time.thread_time_ns()
                 overhead_time = overhead_end_time - overhead_start_time


@@ -75,7 +69,7 @@ def wrapper(*args, **kwargs):
                     benchmark_function_name, benchmark_file_name, benchmark_line_number, execution_time,
                     overhead_time, pickled_args, pickled_kwargs)
                 )
-
+                print("appended")
             except Exception as e:
                 print(f"Error in codeflash_trace: {e}")
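
The substantive change in this file is swapping time.perf_counter_ns() for time.thread_time_ns(): the latter counts CPU time of the calling thread only, so a function driven from a ThreadPoolExecutor is not billed for time spent sleeping, blocked, or waiting on other threads. A minimal standalone sketch of the difference, not part of the commit; the sleep is only there to exaggerate the gap:

import time

def compare_clocks():
    wall_start = time.perf_counter_ns()
    cpu_start = time.thread_time_ns()
    time.sleep(0.1)            # blocks the thread without using CPU
    sum(range(1_000_000))      # genuine CPU-bound work
    wall_ns = time.perf_counter_ns() - wall_start
    cpu_ns = time.thread_time_ns() - cpu_start
    # wall_ns includes the 100 ms sleep; cpu_ns counts only this thread's CPU time
    print(f"wall: {wall_ns / 1e6:.1f} ms, thread CPU: {cpu_ns / 1e6:.1f} ms")

compare_clocks()
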

codeflash/benchmarking/utils.py

Lines changed: 39 additions & 22 deletions

@@ -1,47 +1,64 @@
 from rich.console import Console
 from rich.table import Table

+from codeflash.cli_cmds.console import logger

-def print_benchmark_table(function_benchmark_timings: dict[str, dict[str, int]],
-                          total_benchmark_timings: dict[str, int]):
-    console = Console()

+def validate_and_format_benchmark_table(function_benchmark_timings: dict[str, dict[str, int]],
+                                        total_benchmark_timings: dict[str, int]) -> dict[str, list[tuple[str, float, float, float]]]:
+    function_to_result = {}
     # Process each function's benchmark data
     for func_path, test_times in function_benchmark_timings.items():
-        function_name = func_path.split(":")[-1]
-
-        # Create a table for this function
-        table = Table(title=f"Function: {function_name}", border_style="blue")
-
-        # Add columns
-        table.add_column("Benchmark Test", style="cyan", no_wrap=True)
-        table.add_column("Total Time (ms)", justify="right", style="green")
-        table.add_column("Function Time (ms)", justify="right", style="yellow")
-        table.add_column("Percentage (%)", justify="right", style="red")
-
         # Sort by percentage (highest first)
         sorted_tests = []
         for test_name, func_time in test_times.items():
             total_time = total_benchmark_timings.get(test_name, 0)
+            if func_time > total_time:
+                logger.debug(f"Skipping test {test_name} due to func_time {func_time} > total_time {total_time}")
+                # If the function time is greater than total time, likely to have multithreading / multiprocessing issues.
+                # Do not try to project the optimization impact for this function.
+                sorted_tests.append((test_name, 0.0, 0.0, 0.0))
             if total_time > 0:
                 percentage = (func_time / total_time) * 100
                 # Convert nanoseconds to milliseconds
                 func_time_ms = func_time / 1_000_000
                 total_time_ms = total_time / 1_000_000
                 sorted_tests.append((test_name, total_time_ms, func_time_ms, percentage))
-
         sorted_tests.sort(key=lambda x: x[3], reverse=True)
+        function_to_result[func_path] = sorted_tests
+    return function_to_result
+
+
+def print_benchmark_table(function_to_results: dict[str, list[tuple[str, float, float, float]]]) -> None:
+    console = Console()
+    for func_path, sorted_tests in function_to_results.items():
+        function_name = func_path.split(":")[-1]
+
+        # Create a table for this function
+        table = Table(title=f"Function: {function_name}", border_style="blue")
+
+        # Add columns
+        table.add_column("Benchmark Test", style="cyan", no_wrap=True)
+        table.add_column("Total Time (ms)", justify="right", style="green")
+        table.add_column("Function Time (ms)", justify="right", style="yellow")
+        table.add_column("Percentage (%)", justify="right", style="red")

-        # Add rows to the table
         for test_name, total_time, func_time, percentage in sorted_tests:
             benchmark_file, benchmark_func, benchmark_line = test_name.split("::")
             benchmark_name = f"{benchmark_file}::{benchmark_func}"
-            table.add_row(
-                benchmark_name,
-                f"{total_time:.3f}",
-                f"{func_time:.3f}",
-                f"{percentage:.2f}"
-            )
+            if total_time == 0.0:
+                table.add_row(
+                    benchmark_name,
+                    "N/A",
+                    "N/A",
+                    "N/A"
+                )
+            else:
+                table.add_row(
+                    benchmark_name,
+                    f"{total_time:.3f}",
+                    f"{func_time:.3f}",
+                    f"{percentage:.2f}"
+                )

         # Print the table
         console.print(table)
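
The new func_time > total_time guard handles benchmarks where the traced function's accumulated time exceeds the benchmark's own total. The in-code comment attributes this to multithreading or multiprocessing; one plausible way it happens is per-thread times summed across workers exceeding the benchmark's overall time, in which case a projected percentage would be misleading and the row is rendered as N/A instead. A rough illustration with made-up numbers:

# Made-up numbers illustrating why func_time can exceed total_time.
func_time_ns = 4 * 50_000_000    # e.g. ~50 ms accumulated in each of 4 worker threads
total_time_ns = 120_000_000      # the whole benchmark finished in ~120 ms

if func_time_ns > total_time_ns:
    # Mirrors validate_and_format_benchmark_table: don't project an impact, show N/A.
    row = ("test_multithread_sort", 0.0, 0.0, 0.0)
else:
    row = ("test_multithread_sort",
           total_time_ns / 1e6,
           func_time_ns / 1e6,
           func_time_ns / total_time_ns * 100)
print(row)
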

codeflash/optimization/function_optimizer.py

Lines changed: 2 additions & 2 deletions

@@ -87,7 +87,7 @@ def __init__(
         function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
         function_to_optimize_ast: ast.FunctionDef | None = None,
         aiservice_client: AiServiceClient | None = None,
-        function_benchmark_timings: dict[str, dict[str, int]] | None = None,
+        function_benchmark_timings: dict[str, int] | None = None,
         total_benchmark_timings: dict[str, int] | None = None,
         args: Namespace | None = None,
     ) -> None:
@@ -272,7 +272,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
             function_name=function_to_optimize_qualified_name,
             file_path=self.function_to_optimize.file_path,
             replay_performance_gain=best_optimization.replay_performance_gain if self.args.benchmark else None,
-            fto_benchmark_timings = self.function_benchmark_timings[self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)] if self.args.benchmark else None,
+            fto_benchmark_timings = self.function_benchmark_timings if self.args.benchmark else None,
             total_benchmark_timings = self.total_benchmark_timings if self.args.benchmark else None,
         )

codeflash/optimization/optimizer.py

Lines changed: 9 additions & 7 deletions

@@ -10,10 +10,9 @@
 from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
 from codeflash.benchmarking.replay_test import generate_replay_test
 from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
-from codeflash.benchmarking.utils import print_benchmark_table
+from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
 from codeflash.cli_cmds.console import console, logger, progress_bar
 from codeflash.code_utils import env_utils
-from codeflash.code_utils.code_extractor import add_needed_imports_from_module
 from codeflash.code_utils.code_replacer import normalize_code, normalize_node
 from codeflash.code_utils.code_utils import get_run_tmp_file
 from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast
@@ -115,15 +114,15 @@ def run(self) -> None:
                     instrument_codeflash_trace_decorator(fto)
                 trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace"
                 replay_tests_dir = Path(self.args.tests_root) / "codeflash_replay_tests"
-                trace_benchmarks_pytest(self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file)  # Simply run all tests that use pytest-benchmark
+                trace_benchmarks_pytest(self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file)  # Run all tests that use pytest-benchmark
                 replay_count = generate_replay_test(trace_file, replay_tests_dir)
                 if replay_count == 0:
                     logger.info(f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization")
                 else:
                     function_benchmark_timings = get_function_benchmark_timings(trace_file)
                     total_benchmark_timings = get_benchmark_timings(trace_file)
-
-                    print_benchmark_table(function_benchmark_timings, total_benchmark_timings)
+                    function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
+                    print_benchmark_table(function_to_results)
                 logger.info("Finished tracing existing benchmarks")
             except Exception as e:
                 logger.info(f"Error while tracing existing benchmarks: {e}")
@@ -213,9 +212,12 @@ def run(self) -> None:
                         f"Skipping optimization."
                     )
                     continue
-                if self.args.benchmark and function_benchmark_timings and total_benchmark_timings:
+                qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(
+                    self.args.project_root
+                )
+                if self.args.benchmark and function_benchmark_timings and qualified_name_w_module in function_benchmark_timings and total_benchmark_timings:
                     function_optimizer = self.create_function_optimizer(
-                        function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings, total_benchmark_timings
+                        function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings[qualified_name_w_module], total_benchmark_timings
                     )
                 else:
                     function_optimizer = self.create_function_optimizer(