Skip to content

Commit 7673123

Browse files
committed
benchmarking tidy up
1 parent 9f967a7 commit 7673123

File tree

11 files changed

+312
-264
lines changed

11 files changed

+312
-264
lines changed

codeflash/benchmarking/codeflash_trace.py

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import time
77
from typing import Callable
88

9+
from codeflash.cli_cmds.cli import logger
910
from codeflash.picklepatch.pickle_patcher import PicklePatcher
1011

1112

@@ -42,10 +43,8 @@ def setup(self, trace_path: str) -> None:
4243
)
4344
self._connection.commit()
4445
except Exception as e:
45-
print(f"Database setup error: {e}")
46-
if self._connection:
47-
self._connection.close()
48-
self._connection = None
46+
logger.error(f"Database setup error: {e}")
47+
self.close()
4948
raise
5049

5150
def write_function_timings(self) -> None:
@@ -63,18 +62,17 @@ def write_function_timings(self) -> None:
6362

6463
try:
6564
cur = self._connection.cursor()
66-
# Insert data into the benchmark_function_timings table
6765
cur.executemany(
6866
"INSERT INTO benchmark_function_timings"
6967
"(function_name, class_name, module_name, file_path, benchmark_function_name, "
7068
"benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) "
7169
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
72-
self.function_calls_data
70+
self.function_calls_data,
7371
)
7472
self._connection.commit()
7573
self.function_calls_data = []
7674
except Exception as e:
77-
print(f"Error writing to function timings database: {e}")
75+
logger.error(f"Error writing to function timings database: {e}")
7876
if self._connection:
7977
self._connection.rollback()
8078
raise
@@ -100,9 +98,10 @@ def __call__(self, func: Callable) -> Callable:
10098
The wrapped function
10199
102100
"""
103-
func_id = (func.__module__,func.__name__)
101+
func_id = (func.__module__, func.__name__)
102+
104103
@functools.wraps(func)
105-
def wrapper(*args, **kwargs):
104+
def wrapper(*args: tuple, **kwargs: dict) -> object:
106105
# Initialize thread-local active functions set if it doesn't exist
107106
if not hasattr(self._thread_local, "active_functions"):
108107
self._thread_local.active_functions = set()
@@ -123,25 +122,33 @@ def wrapper(*args, **kwargs):
123122
if os.environ.get("CODEFLASH_BENCHMARKING", "False") == "False":
124123
self._thread_local.active_functions.remove(func_id)
125124
return result
126-
# Get benchmark info from environment
125+
127126
benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
128127
benchmark_module_path = os.environ.get("CODEFLASH_BENCHMARK_MODULE_PATH", "")
129128
benchmark_line_number = os.environ.get("CODEFLASH_BENCHMARK_LINE_NUMBER", "")
130-
# Get class name
131129
class_name = ""
132130
qualname = func.__qualname__
133131
if "." in qualname:
134132
class_name = qualname.split(".")[0]
135133

136-
# Limit pickle count so memory does not explode
137134
if self.function_call_count > self.pickle_count_limit:
138-
print("Pickle limit reached")
135+
logger.debug("CodeflashTrace: Pickle limit reached")
139136
self._thread_local.active_functions.remove(func_id)
140137
overhead_time = time.thread_time_ns() - end_time
141138
self.function_calls_data.append(
142-
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
143-
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
144-
overhead_time, None, None)
139+
(
140+
func.__name__,
141+
class_name,
142+
func.__module__,
143+
func.__code__.co_filename,
144+
benchmark_function_name,
145+
benchmark_module_path,
146+
benchmark_line_number,
147+
execution_time,
148+
overhead_time,
149+
None,
150+
None,
151+
)
145152
)
146153
return result
147154

@@ -150,30 +157,50 @@ def wrapper(*args, **kwargs):
150157
pickled_args = PicklePatcher.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
151158
pickled_kwargs = PicklePatcher.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
152159
except Exception as e:
153-
print(f"Error pickling arguments for function {func.__name__}: {e}")
160+
logger.debug(f"CodeflashTrace: Error pickling arguments for function {func.__name__}: {e}")
154161
# Add to the list of function calls without pickled args. Used for timing info only
155162
self._thread_local.active_functions.remove(func_id)
156163
overhead_time = time.thread_time_ns() - end_time
157164
self.function_calls_data.append(
158-
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
159-
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
160-
overhead_time, None, None)
165+
(
166+
func.__name__,
167+
class_name,
168+
func.__module__,
169+
func.__code__.co_filename,
170+
benchmark_function_name,
171+
benchmark_module_path,
172+
benchmark_line_number,
173+
execution_time,
174+
overhead_time,
175+
None,
176+
None,
177+
)
161178
)
162179
return result
163-
# Flush to database every 100 calls
164180
if len(self.function_calls_data) > 100:
165181
self.write_function_timings()
166182

167183
# Add to the list of function calls with pickled args, to be used for replay tests
168184
self._thread_local.active_functions.remove(func_id)
169185
overhead_time = time.thread_time_ns() - end_time
170186
self.function_calls_data.append(
171-
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
172-
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
173-
overhead_time, pickled_args, pickled_kwargs)
187+
(
188+
func.__name__,
189+
class_name,
190+
func.__module__,
191+
func.__code__.co_filename,
192+
benchmark_function_name,
193+
benchmark_module_path,
194+
benchmark_line_number,
195+
execution_time,
196+
overhead_time,
197+
pickled_args,
198+
pickled_kwargs,
199+
)
174200
)
175201
return result
202+
176203
return wrapper
177204

178-
# Create a singleton instance
205+
179206
codeflash_trace = CodeflashTrace()

codeflash/benchmarking/instrument_codeflash_trace.py

Lines changed: 21 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,57 +13,46 @@ def __init__(self, target_functions: set[tuple[str, str]]) -> None:
1313
self.added_codeflash_trace = False
1414
self.class_name = ""
1515
self.function_name = ""
16-
self.decorator = cst.Decorator(
17-
decorator=cst.Name(value="codeflash_trace")
18-
)
16+
self.decorator = cst.Decorator(decorator=cst.Name(value="codeflash_trace"))
1917

20-
def leave_ClassDef(self, original_node, updated_node):
18+
def leave_ClassDef(self, original_node, updated_node): # noqa: ANN001, ANN201, N802
2119
if self.class_name == original_node.name.value:
22-
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
20+
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
2321
return updated_node
2422

25-
def visit_ClassDef(self, node):
26-
if self.class_name: # Don't go into nested class
23+
def visit_ClassDef(self, node): # noqa: ANN001, ANN201, N802
24+
if self.class_name: # Don't go into nested class
2725
return False
28-
self.class_name = node.name.value
26+
self.class_name = node.name.value # noqa: RET503
2927

30-
def visit_FunctionDef(self, node):
31-
if self.function_name: # Don't go into nested function
28+
def visit_FunctionDef(self, node): # noqa: ANN001, ANN201, N802
29+
if self.function_name: # Don't go into nested function
3230
return False
33-
self.function_name = node.name.value
31+
self.function_name = node.name.value # noqa: RET503
3432

35-
def leave_FunctionDef(self, original_node, updated_node):
33+
def leave_FunctionDef(self, original_node, updated_node): # noqa: ANN001, ANN201, N802
3634
if self.function_name == original_node.name.value:
3735
self.function_name = ""
3836
if (self.class_name, original_node.name.value) in self.target_functions:
3937
# Add the new decorator after any existing decorators, so it gets executed first
40-
updated_decorators = list(updated_node.decorators) + [self.decorator]
38+
updated_decorators = [*list(updated_node.decorators), self.decorator]
4139
self.added_codeflash_trace = True
42-
return updated_node.with_changes(
43-
decorators=updated_decorators
44-
)
40+
return updated_node.with_changes(decorators=updated_decorators)
4541

4642
return updated_node
4743

48-
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
44+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002, N802
4945
# Create import statement for codeflash_trace
5046
if not self.added_codeflash_trace:
5147
return updated_node
5248
import_stmt = cst.SimpleStatementLine(
5349
body=[
5450
cst.ImportFrom(
5551
module=cst.Attribute(
56-
value=cst.Attribute(
57-
value=cst.Name(value="codeflash"),
58-
attr=cst.Name(value="benchmarking")
59-
),
60-
attr=cst.Name(value="codeflash_trace")
52+
value=cst.Attribute(value=cst.Name(value="codeflash"), attr=cst.Name(value="benchmarking")),
53+
attr=cst.Name(value="codeflash_trace"),
6154
),
62-
names=[
63-
cst.ImportAlias(
64-
name=cst.Name(value="codeflash_trace")
65-
)
66-
]
55+
names=[cst.ImportAlias(name=cst.Name(value="codeflash_trace"))],
6756
)
6857
]
6958
)
@@ -73,12 +62,13 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
7362

7463
return updated_node.with_changes(body=new_body)
7564

65+
7666
def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[FunctionToOptimize]) -> str:
7767
"""Add codeflash_trace to a function.
7868
7969
Args:
8070
code: The source code as a string
81-
function_to_optimize: The FunctionToOptimize instance containing function details
71+
functions_to_optimize: List of FunctionToOptimize instances containing function details
8272
8373
Returns:
8474
The modified source code as a string
@@ -91,27 +81,17 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
9181
class_name = function_to_optimize.parents[0].name
9282
target_functions.add((class_name, function_to_optimize.function_name))
9383

94-
transformer = AddDecoratorTransformer(
95-
target_functions = target_functions,
96-
)
84+
transformer = AddDecoratorTransformer(target_functions=target_functions)
9785

9886
module = cst.parse_module(code)
9987
modified_module = module.visit(transformer)
10088
return modified_module.code
10189

10290

103-
def instrument_codeflash_trace_decorator(
104-
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
105-
) -> None:
91+
def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]) -> None:
10692
"""Instrument codeflash_trace decorator to functions to optimize."""
10793
for file_path, functions_to_optimize in file_to_funcs_to_optimize.items():
10894
original_code = file_path.read_text(encoding="utf-8")
109-
new_code = add_codeflash_decorator_to_code(
110-
original_code,
111-
functions_to_optimize
112-
)
113-
# Modify the code
95+
new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize)
11496
modified_code = isort.code(code=new_code, float_to_top=True)
115-
116-
# Write the modified code back to the file
11797
file_path.write_text(modified_code, encoding="utf-8")

0 commit comments

Comments
 (0)