Skip to content

Commit 1a53bed

Browse files
committed
ruff format benchmarking
1 parent 9f967a7 commit 1a53bed

File tree

7 files changed

+155
-129
lines changed

7 files changed

+155
-129
lines changed

codeflash/benchmarking/codeflash_trace.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def write_function_timings(self) -> None:
6969
"(function_name, class_name, module_name, file_path, benchmark_function_name, "
7070
"benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) "
7171
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
72-
self.function_calls_data
72+
self.function_calls_data,
7373
)
7474
self._connection.commit()
7575
self.function_calls_data = []
@@ -100,7 +100,8 @@ def __call__(self, func: Callable) -> Callable:
100100
The wrapped function
101101
102102
"""
103-
func_id = (func.__module__,func.__name__)
103+
func_id = (func.__module__, func.__name__)
104+
104105
@functools.wraps(func)
105106
def wrapper(*args, **kwargs):
106107
# Initialize thread-local active functions set if it doesn't exist
@@ -139,9 +140,19 @@ def wrapper(*args, **kwargs):
139140
self._thread_local.active_functions.remove(func_id)
140141
overhead_time = time.thread_time_ns() - end_time
141142
self.function_calls_data.append(
142-
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
143-
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
144-
overhead_time, None, None)
143+
(
144+
func.__name__,
145+
class_name,
146+
func.__module__,
147+
func.__code__.co_filename,
148+
benchmark_function_name,
149+
benchmark_module_path,
150+
benchmark_line_number,
151+
execution_time,
152+
overhead_time,
153+
None,
154+
None,
155+
)
145156
)
146157
return result
147158

@@ -155,9 +166,19 @@ def wrapper(*args, **kwargs):
155166
self._thread_local.active_functions.remove(func_id)
156167
overhead_time = time.thread_time_ns() - end_time
157168
self.function_calls_data.append(
158-
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
159-
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
160-
overhead_time, None, None)
169+
(
170+
func.__name__,
171+
class_name,
172+
func.__module__,
173+
func.__code__.co_filename,
174+
benchmark_function_name,
175+
benchmark_module_path,
176+
benchmark_line_number,
177+
execution_time,
178+
overhead_time,
179+
None,
180+
None,
181+
)
161182
)
162183
return result
163184
# Flush to database every 100 calls
@@ -168,12 +189,24 @@ def wrapper(*args, **kwargs):
168189
self._thread_local.active_functions.remove(func_id)
169190
overhead_time = time.thread_time_ns() - end_time
170191
self.function_calls_data.append(
171-
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
172-
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
173-
overhead_time, pickled_args, pickled_kwargs)
192+
(
193+
func.__name__,
194+
class_name,
195+
func.__module__,
196+
func.__code__.co_filename,
197+
benchmark_function_name,
198+
benchmark_module_path,
199+
benchmark_line_number,
200+
execution_time,
201+
overhead_time,
202+
pickled_args,
203+
pickled_kwargs,
204+
)
174205
)
175206
return result
207+
176208
return wrapper
177209

210+
178211
# Create a singleton instance
179212
codeflash_trace = CodeflashTrace()

codeflash/benchmarking/instrument_codeflash_trace.py

Lines changed: 21 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,57 +13,46 @@ def __init__(self, target_functions: set[tuple[str, str]]) -> None:
1313
self.added_codeflash_trace = False
1414
self.class_name = ""
1515
self.function_name = ""
16-
self.decorator = cst.Decorator(
17-
decorator=cst.Name(value="codeflash_trace")
18-
)
16+
self.decorator = cst.Decorator(decorator=cst.Name(value="codeflash_trace"))
1917

20-
def leave_ClassDef(self, original_node, updated_node):
18+
def leave_ClassDef(self, original_node, updated_node): # noqa: ANN001, ANN201, N802
2119
if self.class_name == original_node.name.value:
22-
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
20+
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
2321
return updated_node
2422

25-
def visit_ClassDef(self, node):
26-
if self.class_name: # Don't go into nested class
23+
def visit_ClassDef(self, node): # noqa: ANN001, ANN201, N802
24+
if self.class_name: # Don't go into nested class
2725
return False
28-
self.class_name = node.name.value
26+
self.class_name = node.name.value # noqa: RET503
2927

30-
def visit_FunctionDef(self, node):
31-
if self.function_name: # Don't go into nested function
28+
def visit_FunctionDef(self, node): # noqa: ANN001, ANN201, N802
29+
if self.function_name: # Don't go into nested function
3230
return False
33-
self.function_name = node.name.value
31+
self.function_name = node.name.value # noqa: RET503
3432

35-
def leave_FunctionDef(self, original_node, updated_node):
33+
def leave_FunctionDef(self, original_node, updated_node): # noqa: ANN001, ANN201, N802
3634
if self.function_name == original_node.name.value:
3735
self.function_name = ""
3836
if (self.class_name, original_node.name.value) in self.target_functions:
3937
# Add the new decorator after any existing decorators, so it gets executed first
40-
updated_decorators = list(updated_node.decorators) + [self.decorator]
38+
updated_decorators = [*list(updated_node.decorators), self.decorator]
4139
self.added_codeflash_trace = True
42-
return updated_node.with_changes(
43-
decorators=updated_decorators
44-
)
40+
return updated_node.with_changes(decorators=updated_decorators)
4541

4642
return updated_node
4743

48-
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
44+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002, N802
4945
# Create import statement for codeflash_trace
5046
if not self.added_codeflash_trace:
5147
return updated_node
5248
import_stmt = cst.SimpleStatementLine(
5349
body=[
5450
cst.ImportFrom(
5551
module=cst.Attribute(
56-
value=cst.Attribute(
57-
value=cst.Name(value="codeflash"),
58-
attr=cst.Name(value="benchmarking")
59-
),
60-
attr=cst.Name(value="codeflash_trace")
52+
value=cst.Attribute(value=cst.Name(value="codeflash"), attr=cst.Name(value="benchmarking")),
53+
attr=cst.Name(value="codeflash_trace"),
6154
),
62-
names=[
63-
cst.ImportAlias(
64-
name=cst.Name(value="codeflash_trace")
65-
)
66-
]
55+
names=[cst.ImportAlias(name=cst.Name(value="codeflash_trace"))],
6756
)
6857
]
6958
)
@@ -73,12 +62,13 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
7362

7463
return updated_node.with_changes(body=new_body)
7564

65+
7666
def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[FunctionToOptimize]) -> str:
7767
"""Add codeflash_trace to a function.
7868
7969
Args:
8070
code: The source code as a string
81-
function_to_optimize: The FunctionToOptimize instance containing function details
71+
functions_to_optimize: List of FunctionToOptimize instances containing function details
8272
8373
Returns:
8474
The modified source code as a string
@@ -91,25 +81,18 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
9181
class_name = function_to_optimize.parents[0].name
9282
target_functions.add((class_name, function_to_optimize.function_name))
9383

94-
transformer = AddDecoratorTransformer(
95-
target_functions = target_functions,
96-
)
84+
transformer = AddDecoratorTransformer(target_functions=target_functions)
9785

9886
module = cst.parse_module(code)
9987
modified_module = module.visit(transformer)
10088
return modified_module.code
10189

10290

103-
def instrument_codeflash_trace_decorator(
104-
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
105-
) -> None:
91+
def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]) -> None:
10692
"""Instrument codeflash_trace decorator to functions to optimize."""
10793
for file_path, functions_to_optimize in file_to_funcs_to_optimize.items():
10894
original_code = file_path.read_text(encoding="utf-8")
109-
new_code = add_codeflash_decorator_to_code(
110-
original_code,
111-
functions_to_optimize
112-
)
95+
new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize)
11396
# Modify the code
11497
modified_code = isort.code(code=new_code, float_to_top=True)
11598

codeflash/benchmarking/plugin/plugin.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self) -> None:
2020
self.project_root = None
2121
self.benchmark_timings = []
2222

23-
def setup(self, trace_path:str, project_root:str) -> None:
23+
def setup(self, trace_path: str, project_root: str) -> None:
2424
try:
2525
# Open connection
2626
self.project_root = project_root
@@ -35,7 +35,7 @@ def setup(self, trace_path:str, project_root:str) -> None:
3535
"benchmark_time_ns INTEGER)"
3636
)
3737
self._connection.commit()
38-
self.close() # Reopen only at the end of pytest session
38+
self.close() # Reopen only at the end of pytest session
3939
except Exception as e:
4040
print(f"Database setup error: {e}")
4141
if self._connection:
@@ -55,14 +55,15 @@ def write_benchmark_timings(self) -> None:
5555
# Insert data into the benchmark_timings table
5656
cur.executemany(
5757
"INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
58-
self.benchmark_timings
58+
self.benchmark_timings,
5959
)
6060
self._connection.commit()
61-
self.benchmark_timings = [] # Clear the benchmark timings list
61+
self.benchmark_timings = [] # Clear the benchmark timings list
6262
except Exception as e:
6363
print(f"Error writing to benchmark timings database: {e}")
6464
self._connection.rollback()
6565
raise
66+
6667
def close(self) -> None:
6768
if self._connection:
6869
self._connection.close()
@@ -196,12 +197,7 @@ def pytest_sessionfinish(self, session, exitstatus):
196197

197198
@staticmethod
198199
def pytest_addoption(parser):
199-
parser.addoption(
200-
"--codeflash-trace",
201-
action="store_true",
202-
default=False,
203-
help="Enable CodeFlash tracing"
204-
)
200+
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
205201

206202
@staticmethod
207203
def pytest_plugin_registered(plugin, manager):
@@ -213,9 +209,9 @@ def pytest_plugin_registered(plugin, manager):
213209
def pytest_configure(config):
214210
"""Register the benchmark marker."""
215211
config.addinivalue_line(
216-
"markers",
217-
"benchmark: mark test as a benchmark that should be run with codeflash tracing"
212+
"markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing"
218213
)
214+
219215
@staticmethod
220216
def pytest_collection_modifyitems(config, items):
221217
# Skip tests that don't have the benchmark fixture
@@ -248,16 +244,18 @@ def __call__(self, func, *args, **kwargs):
248244
if args or kwargs:
249245
# Used as benchmark(func, *args, **kwargs)
250246
return self._run_benchmark(func, *args, **kwargs)
247+
251248
# Used as @benchmark decorator
252-
def wrapped_func(*args, **kwargs):
253-
return func(*args, **kwargs)
254-
result = self._run_benchmark(func)
249+
def wrapped_func(*inner_args, **inner_kwargs):
250+
return self._run_benchmark(func, *inner_args, **inner_kwargs)
251+
255252
return wrapped_func
256253

257254
def _run_benchmark(self, func, *args, **kwargs):
258255
"""Actual benchmark implementation."""
259-
benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)),
260-
Path(codeflash_benchmark_plugin.project_root))
256+
benchmark_module_path = module_name_from_file_path(
257+
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
258+
)
261259
benchmark_function_name = self.request.node.name
262260
line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack
263261
# Set env vars
@@ -278,7 +276,8 @@ def _run_benchmark(self, func, *args, **kwargs):
278276
codeflash_trace.function_call_count = 0
279277
# Add to the benchmark timings buffer
280278
codeflash_benchmark_plugin.benchmark_timings.append(
281-
(benchmark_module_path, benchmark_function_name, line_number, end - start))
279+
(benchmark_module_path, benchmark_function_name, line_number, end - start)
280+
)
282281

283282
return result
284283

@@ -290,4 +289,5 @@ def benchmark(request):
290289

291290
return CodeFlashBenchmarkPlugin.Benchmark(request)
292291

292+
293293
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()

codeflash/benchmarking/pytest_new_process_trace_benchmarks.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,20 @@
1616
codeflash_benchmark_plugin.setup(trace_file, project_root)
1717
codeflash_trace.setup(trace_file)
1818
exitcode = pytest.main(
19-
[benchmarks_root, "--codeflash-trace", "-p", "no:benchmark","-p", "no:codspeed","-p", "no:cov-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin]
20-
) # Errors will be printed to stdout, not stderr
19+
[
20+
benchmarks_root,
21+
"--codeflash-trace",
22+
"-p",
23+
"no:benchmark",
24+
"-p",
25+
"no:codspeed",
26+
"-p",
27+
"no:cov-s",
28+
"-o",
29+
"addopts=",
30+
],
31+
plugins=[codeflash_benchmark_plugin],
32+
) # Errors will be printed to stdout, not stderr
2133

2234
except Exception as e:
2335
print(f"Failed to collect tests: {e!s}", file=sys.stderr)

0 commit comments

Comments
 (0)