Skip to content

Commit cb64c68

Browse files
committed
tidy up
1 parent 7697ad5 commit cb64c68

File tree

3 files changed

+19
-34
lines changed

3 files changed

+19
-34
lines changed

codeflash/benchmarking/codeflash_trace.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import time
77
from typing import Callable
88

9+
from codeflash.cli_cmds.cli import logger
910
from codeflash.picklepatch.pickle_patcher import PicklePatcher
1011

1112

@@ -42,10 +43,8 @@ def setup(self, trace_path: str) -> None:
4243
)
4344
self._connection.commit()
4445
except Exception as e:
45-
print(f"Database setup error: {e}")
46-
if self._connection:
47-
self._connection.close()
48-
self._connection = None
46+
logger.error(f"Database setup error: {e}")
47+
self.close()
4948
raise
5049

5150
def write_function_timings(self) -> None:
@@ -63,7 +62,6 @@ def write_function_timings(self) -> None:
6362

6463
try:
6564
cur = self._connection.cursor()
66-
# Insert data into the benchmark_function_timings table
6765
cur.executemany(
6866
"INSERT INTO benchmark_function_timings"
6967
"(function_name, class_name, module_name, file_path, benchmark_function_name, "
@@ -74,7 +72,7 @@ def write_function_timings(self) -> None:
7472
self._connection.commit()
7573
self.function_calls_data = []
7674
except Exception as e:
77-
print(f"Error writing to function timings database: {e}")
75+
logger.error(f"Error writing to function timings database: {e}")
7876
if self._connection:
7977
self._connection.rollback()
8078
raise
@@ -103,7 +101,7 @@ def __call__(self, func: Callable) -> Callable:
103101
func_id = (func.__module__, func.__name__)
104102

105103
@functools.wraps(func)
106-
def wrapper(*args, **kwargs):
104+
def wrapper(*args: tuple, **kwargs: dict) -> object:
107105
# Initialize thread-local active functions set if it doesn't exist
108106
if not hasattr(self._thread_local, "active_functions"):
109107
self._thread_local.active_functions = set()
@@ -124,19 +122,17 @@ def wrapper(*args, **kwargs):
124122
if os.environ.get("CODEFLASH_BENCHMARKING", "False") == "False":
125123
self._thread_local.active_functions.remove(func_id)
126124
return result
127-
# Get benchmark info from environment
125+
128126
benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
129127
benchmark_module_path = os.environ.get("CODEFLASH_BENCHMARK_MODULE_PATH", "")
130128
benchmark_line_number = os.environ.get("CODEFLASH_BENCHMARK_LINE_NUMBER", "")
131-
# Get class name
132129
class_name = ""
133130
qualname = func.__qualname__
134131
if "." in qualname:
135132
class_name = qualname.split(".")[0]
136133

137-
# Limit pickle count so memory does not explode
138134
if self.function_call_count > self.pickle_count_limit:
139-
print("Pickle limit reached")
135+
logger.debug("CodeflashTrace: Pickle limit reached")
140136
self._thread_local.active_functions.remove(func_id)
141137
overhead_time = time.thread_time_ns() - end_time
142138
self.function_calls_data.append(
@@ -161,7 +157,7 @@ def wrapper(*args, **kwargs):
161157
pickled_args = PicklePatcher.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
162158
pickled_kwargs = PicklePatcher.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
163159
except Exception as e:
164-
print(f"Error pickling arguments for function {func.__name__}: {e}")
160+
logger.debug(f"CodeflashTrace: Error pickling arguments for function {func.__name__}: {e}")
165161
# Add to the list of function calls without pickled args. Used for timing info only
166162
self._thread_local.active_functions.remove(func_id)
167163
overhead_time = time.thread_time_ns() - end_time
@@ -181,7 +177,6 @@ def wrapper(*args, **kwargs):
181177
)
182178
)
183179
return result
184-
# Flush to database every 100 calls
185180
if len(self.function_calls_data) > 100:
186181
self.write_function_timings()
187182

@@ -208,5 +203,4 @@ def wrapper(*args, **kwargs):
208203
return wrapper
209204

210205

211-
# Create a singleton instance
212206
codeflash_trace = CodeflashTrace()

codeflash/benchmarking/instrument_codeflash_trace.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,5 @@ def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, l
9393
for file_path, functions_to_optimize in file_to_funcs_to_optimize.items():
9494
original_code = file_path.read_text(encoding="utf-8")
9595
new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize)
96-
# Modify the code
9796
modified_code = isort.code(code=new_code, float_to_top=True)
98-
99-
# Write the modified code back to the file
10097
file_path.write_text(modified_code, encoding="utf-8")

codeflash/benchmarking/replay_test.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,11 @@ def get_next_arg_and_return(
2828
limit = num_to_get
2929

3030
if class_name is not None:
31-
cursor = cur.execute(
32-
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
33-
(benchmark_function_name, function_name, file_path, class_name, limit),
34-
)
31+
query = "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?" # noqa: E501
32+
cursor = cur.execute(query(benchmark_function_name, function_name, file_path, class_name, limit))
3533
else:
3634
cursor = cur.execute(
37-
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
35+
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?", # noqa: E501
3836
(benchmark_function_name, function_name, file_path, limit),
3937
)
4038

@@ -98,7 +96,7 @@ def create_trace_replay_test_code(
9896
args = pickle.loads(args_pkl)
9997
kwargs = pickle.loads(kwargs_pkl)
10098
ret = {function_name}(*args, **kwargs)
101-
"""
99+
""" # noqa: E501
102100
)
103101

104102
test_method_body = textwrap.dedent(
@@ -114,7 +112,7 @@ def create_trace_replay_test_code(
114112
else:
115113
instance = args[0] # self
116114
ret = instance{method_name}(*args[1:], **kwargs)
117-
"""
115+
""" # noqa: E501
118116
)
119117

120118
test_class_method_body = textwrap.dedent(
@@ -125,15 +123,15 @@ def create_trace_replay_test_code(
125123
if not args:
126124
raise ValueError("No arguments provided for the method.")
127125
ret = {class_name_alias}{method_name}(*args[1:], **kwargs)
128-
"""
126+
""" # noqa: E501
129127
)
130128
test_static_method_body = textwrap.dedent(
131129
"""\
132130
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
133131
args = pickle.loads(args_pkl)
134132
kwargs = pickle.loads(kwargs_pkl){filter_variables}
135133
ret = {class_name_alias}{method_name}(*args, **kwargs)
136-
"""
134+
""" # noqa: E501
137135
)
138136

139137
# Create main body
@@ -166,7 +164,7 @@ def create_trace_replay_test_code(
166164
alias = get_function_alias(module_name, class_name + "_" + function_name)
167165

168166
filter_variables = ""
169-
# filter_variables = '\n args.pop("cls", None)'
167+
# filter_variables = '\n args.pop("cls", None)' # noqa: ERA001
170168
method_name = "." + function_name if function_name != "__init__" else ""
171169
if function_properties.is_classmethod:
172170
test_body = test_class_method_body.format(
@@ -227,11 +225,9 @@ def generate_replay_test(
227225
"""
228226
count = 0
229227
try:
230-
# Connect to the database
231228
conn = sqlite3.connect(trace_file_path.as_posix())
232229
cursor = conn.cursor()
233230

234-
# Get distinct benchmark file paths
235231
cursor.execute("SELECT DISTINCT benchmark_module_path FROM benchmark_function_timings")
236232
benchmark_files = cursor.fetchall()
237233

@@ -240,15 +236,14 @@ def generate_replay_test(
240236
benchmark_module_path = benchmark_file[0]
241237
# Get all benchmarks and functions associated with this file path
242238
cursor.execute(
243-
"SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings "
239+
"SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings " # noqa: E501
244240
"WHERE benchmark_module_path = ?",
245241
(benchmark_module_path,),
246242
)
247243

248244
functions_data = []
249245
for row in cursor.fetchall():
250246
benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number = row
251-
# Add this function to our list
252247
functions_data.append(
253248
{
254249
"function_name": function_name,
@@ -278,13 +273,12 @@ def generate_replay_test(
278273
output_file = get_test_file_path(
279274
test_dir=Path(output_dir), function_name=benchmark_module_path, test_type="replay"
280275
)
281-
# Write test code to file, parents = true
282276
output_dir.mkdir(parents=True, exist_ok=True)
283277
output_file.write_text(test_code, "utf-8")
284278
count += 1
285279

286280
conn.close()
287-
except Exception as e:
288-
logger.info(f"Error generating replay tests: {e}")
281+
except Exception as e: # noqa: BLE001
282+
logger.error(f"Error generating replay test: {e}")
289283

290284
return count

0 commit comments

Comments
 (0)