Skip to content

Commit 902a982

Browse files
committed
calculate importance relative to own-file time
remove unittest remnants
1 parent 03de4db commit 902a982

File tree

4 files changed

+45
-45
lines changed

4 files changed

+45
-45
lines changed

codeflash/benchmarking/function_ranker.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,11 @@ def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> lis
100100
"""Ranks and filters functions based on their ttX score and importance.
101101
102102
Filters out functions whose own_time is less than DEFAULT_IMPORTANCE_THRESHOLD
103-
of total runtime, then ranks the remaining functions by ttX score.
103+
of file-relative runtime, then ranks the remaining functions by ttX score.
104+
105+
Importance is calculated relative to functions in the same file(s) rather than
106+
total program time. This avoids filtering out functions due to test infrastructure
107+
overhead.
104108
105109
The ttX score prioritizes functions that are computationally heavy themselves
106110
or that make expensive calls to other functions.
@@ -116,9 +120,24 @@ def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> lis
116120
logger.warning("No function stats available to rank functions.")
117121
return []
118122

119-
total_program_time = sum(
120-
s["own_time_ns"] for s in self._function_stats.values() if s.get("own_time_ns", 0) > 0
121-
)
123+
# Calculate total time from functions in the same file(s) as functions to optimize
124+
if functions_to_optimize:
125+
# Get unique files from functions to optimize
126+
target_files = {func.file_path.name for func in functions_to_optimize}
127+
# Calculate total time only from functions in these files
128+
total_program_time = sum(
129+
s["own_time_ns"]
130+
for s in self._function_stats.values()
131+
if s.get("own_time_ns", 0) > 0 and any(target_file in s["filename"] for target_file in target_files)
132+
)
133+
logger.debug(
134+
f"Using file-relative importance for {len(target_files)} file(s): {target_files}. "
135+
f"Total file time: {total_program_time:,} ns"
136+
)
137+
else:
138+
total_program_time = sum(
139+
s["own_time_ns"] for s in self._function_stats.values() if s.get("own_time_ns", 0) > 0
140+
)
122141

123142
if total_program_time == 0:
124143
logger.warning("Total program time is zero, cannot determine function importance.")

codeflash/benchmarking/replay_test.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -66,28 +66,23 @@ def get_unique_test_name(module: str, function_name: str, benchmark_name: str, c
6666
def create_trace_replay_test_code(
6767
trace_file: str,
6868
functions_data: list[dict[str, Any]],
69-
test_framework: str = "pytest",
70-
max_run_count=256, # noqa: ANN001
69+
max_run_count: int = 256,
7170
) -> str:
7271
"""Create a replay test for functions based on trace data.
7372
7473
Args:
7574
----
7675
trace_file: Path to the SQLite database file
7776
functions_data: List of dictionaries with function info extracted from DB
78-
test_framework: 'pytest' or 'unittest'
7977
max_run_count: Maximum number of runs to include in the test
8078
8179
Returns:
8280
-------
8381
A string containing the test code
8482
8583
"""
86-
assert test_framework in ["pytest", "unittest"]
87-
8884
# Create Imports
89-
imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
90-
{"import unittest" if test_framework == "unittest" else ""}
85+
imports = """from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
9186
from codeflash.benchmarking.replay_test import get_next_arg_and_return
9287
"""
9388

@@ -158,13 +153,7 @@ def create_trace_replay_test_code(
158153
)
159154

160155
# Create main body
161-
162-
if test_framework == "unittest":
163-
self = "self"
164-
test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
165-
else:
166-
test_template = ""
167-
self = ""
156+
test_template = ""
168157

169158
for func in functions_data:
170159
module_name = func.get("module_name")
@@ -223,30 +212,28 @@ def create_trace_replay_test_code(
223212
filter_variables=filter_variables,
224213
)
225214

226-
formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ")
215+
formatted_test_body = textwrap.indent(test_body, " ")
227216

228-
test_template += " " if test_framework == "unittest" else ""
229217
unique_test_name = get_unique_test_name(module_name, function_name, benchmark_function_name, class_name)
230-
test_template += f"def test_{unique_test_name}({self}):\n{formatted_test_body}\n"
218+
test_template += f"def test_{unique_test_name}():\n{formatted_test_body}\n"
231219

232220
return imports + "\n" + metadata + "\n" + test_template
233221

234222

235223
def generate_replay_test(
236-
trace_file_path: Path, output_dir: Path, test_framework: str = "pytest", max_run_count: int = 100
224+
trace_file_path: Path, output_dir: Path, max_run_count: int = 100
237225
) -> int:
238226
"""Generate multiple replay tests from the traced function calls, grouped by benchmark.
239227
240228
Args:
241229
----
242230
trace_file_path: Path to the SQLite database file
243231
output_dir: Directory to write the generated tests (if None, only returns the code)
244-
test_framework: 'pytest' or 'unittest'
245232
max_run_count: Maximum number of runs to include per function
246233
247234
Returns:
248235
-------
249-
Dictionary mapping benchmark names to generated test code
236+
The number of replay tests generated
250237
251238
"""
252239
count = 0
@@ -295,7 +282,6 @@ def generate_replay_test(
295282
test_code = create_trace_replay_test_code(
296283
trace_file=trace_file_path.as_posix(),
297284
functions_data=functions_data,
298-
test_framework=test_framework,
299285
max_run_count=max_run_count,
300286
)
301287
test_code = sort_imports(code=test_code)

codeflash/tracing/replay_test.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,9 @@ def get_function_alias(module: str, function_name: str) -> str:
4646
def create_trace_replay_test(
4747
trace_file: str,
4848
functions: list[FunctionModules],
49-
test_framework: str = "pytest",
50-
max_run_count=100, # noqa: ANN001
49+
max_run_count: int = 100,
5150
) -> str:
52-
assert test_framework in {"pytest", "unittest"}
53-
54-
imports = f"""import dill as pickle
55-
{"import unittest" if test_framework == "unittest" else ""}
51+
imports = """import dill as pickle
5652
from codeflash.tracing.replay_test import get_next_arg_and_return
5753
"""
5854

@@ -112,12 +108,7 @@ def create_trace_replay_test(
112108
ret = {class_name_alias}{method_name}(**args)
113109
"""
114110
)
115-
if test_framework == "unittest":
116-
self = "self"
117-
test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
118-
else:
119-
test_template = ""
120-
self = ""
111+
test_template = ""
121112
for func, func_property in zip(functions, function_properties):
122113
if func_property is None:
123114
continue
@@ -167,9 +158,8 @@ def create_trace_replay_test(
167158
max_run_count=max_run_count,
168159
filter_variables=filter_variables,
169160
)
170-
formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ")
161+
formatted_test_body = textwrap.indent(test_body, " ")
171162

172-
test_template += " " if test_framework == "unittest" else ""
173-
test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n"
163+
test_template += f"def test_{alias}():\n{formatted_test_body}\n"
174164

175165
return imports + "\n" + metadata + "\n" + test_template

codeflash/tracing/tracing_new_process.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ def __init__(
110110
self._db_lock = threading.Lock()
111111

112112
self.con = None
113-
self.output_file = Path(output).resolve()
114113
self.functions = functions
115114
self.function_modules: list[FunctionModules] = []
116115
self.function_count = defaultdict(int)
@@ -126,6 +125,14 @@ def __init__(
126125
self.ignored_functions = {"<listcomp>", "<genexpr>", "<dictcomp>", "<setcomp>", "<lambda>", "<module>"}
127126

128127
self.sanitized_filename = self.sanitize_to_filename(command)
128+
# Place trace file next to replay tests in the tests directory
129+
from codeflash.verification.verification_utils import get_test_file_path
130+
function_path = "_".join(functions) if functions else self.sanitized_filename
131+
test_file_path = get_test_file_path(
132+
test_dir=Path(config["tests_root"]), function_name=function_path, test_type="replay"
133+
)
134+
trace_filename = test_file_path.stem + ".trace"
135+
self.output_file = test_file_path.parent / trace_filename
129136
self.result_pickle_file_path = result_pickle_file_path
130137

131138
assert timeout is None or timeout > 0, "Timeout should be greater than 0"
@@ -142,7 +149,6 @@ def __init__(
142149
self.timer = time.process_time_ns
143150
self.total_tt = 0
144151
self.simulate_call("profiler")
145-
assert "test_framework" in self.config, "Please specify 'test-framework' in pyproject.toml config file"
146152
self.t = self.timer()
147153

148154
# Store command information for metadata table
@@ -275,7 +281,6 @@ def __exit__(
275281
replay_test = create_trace_replay_test(
276282
trace_file=self.output_file,
277283
functions=self.function_modules,
278-
test_framework=self.config["test_framework"],
279284
max_run_count=self.max_function_count,
280285
)
281286
function_path = "_".join(self.functions) if self.functions else self.sanitized_filename
@@ -770,11 +775,11 @@ def make_pstats_compatible(self) -> None:
770775
self.files = []
771776
self.top_level = []
772777
new_stats = {}
773-
for func, (cc, ns, tt, ct, callers) in self.stats.items():
778+
for func, (cc, ns, tt, ct, callers) in list(self.stats.items()):
774779
new_callers = {(k[0], k[1], k[2]): v for k, v in callers.items()}
775780
new_stats[(func[0], func[1], func[2])] = (cc, ns, tt, ct, new_callers)
776781
new_timings = {}
777-
for func, (cc, ns, tt, ct, callers) in self.timings.items():
782+
for func, (cc, ns, tt, ct, callers) in list(self.timings.items()):
778783
new_callers = {(k[0], k[1], k[2]): v for k, v in callers.items()}
779784
new_timings[(func[0], func[1], func[2])] = (cc, ns, tt, ct, new_callers)
780785
self.stats = new_stats

0 commit comments

Comments (0)