Skip to content

Commit 6b7c435

Browse files
committed
cleanup code
1 parent 9d005b1 commit 6b7c435

File tree

4 files changed

+9
-21
lines changed

4 files changed

+9
-21
lines changed

codeflash/benchmarking/function_ranker.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import TYPE_CHECKING
44

5-
from codeflash.cli_cmds.console import console, logger
5+
from codeflash.cli_cmds.console import logger
66
from codeflash.code_utils.config_consts import DEFAULT_IMPORTANCE_THRESHOLD
77
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
88
from codeflash.tracing.profile_stats import ProfileStats
@@ -128,7 +128,8 @@ def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> lis
128128
total_program_time = sum(
129129
s["own_time_ns"]
130130
for s in self._function_stats.values()
131-
if s.get("own_time_ns", 0) > 0 and any(
131+
if s.get("own_time_ns", 0) > 0
132+
and any(
132133
str(s.get("filename", "")).endswith("/" + target_file) or s.get("filename") == target_file
133134
for target_file in target_files
134135
)

codeflash/benchmarking/replay_test.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ def get_unique_test_name(module: str, function_name: str, benchmark_name: str, c
6464

6565

6666
def create_trace_replay_test_code(
67-
trace_file: str,
68-
functions_data: list[dict[str, Any]],
69-
max_run_count: int = 256,
67+
trace_file: str, functions_data: list[dict[str, Any]], max_run_count: int = 256
7068
) -> str:
7169
"""Create a replay test for functions based on trace data.
7270
@@ -220,9 +218,7 @@ def create_trace_replay_test_code(
220218
return imports + "\n" + metadata + "\n" + test_template
221219

222220

223-
def generate_replay_test(
224-
trace_file_path: Path, output_dir: Path, max_run_count: int = 100
225-
) -> int:
221+
def generate_replay_test(trace_file_path: Path, output_dir: Path, max_run_count: int = 100) -> int:
226222
"""Generate multiple replay tests from the traced function calls, grouped by benchmark.
227223
228224
Args:
@@ -280,9 +276,7 @@ def generate_replay_test(
280276
continue
281277
# Generate the test code for this benchmark
282278
test_code = create_trace_replay_test_code(
283-
trace_file=trace_file_path.as_posix(),
284-
functions_data=functions_data,
285-
max_run_count=max_run_count,
279+
trace_file=trace_file_path.as_posix(), functions_data=functions_data, max_run_count=max_run_count
286280
)
287281
test_code = sort_imports(code=test_code)
288282
output_file = get_test_file_path(

codeflash/tracing/replay_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,7 @@ def get_function_alias(module: str, function_name: str) -> str:
4343
return "_".join(module.split(".")) + "_" + function_name
4444

4545

46-
def create_trace_replay_test(
47-
trace_file: str,
48-
functions: list[FunctionModules],
49-
max_run_count: int = 100,
50-
) -> str:
46+
def create_trace_replay_test(trace_file: str, functions: list[FunctionModules], max_run_count: int = 100) -> str:
5147
imports = """import dill as pickle
5248
from codeflash.tracing.replay_test import get_next_arg_and_return
5349
"""

codeflash/tracing/tracing_new_process.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ def __init__(
7070
self,
7171
config: dict,
7272
result_pickle_file_path: Path,
73-
output: str = "codeflash.trace",
7473
functions: list[str] | None = None,
7574
disable: bool = False, # noqa: FBT001, FBT002
7675
project_root: Path | None = None,
@@ -80,7 +79,6 @@ def __init__(
8079
) -> None:
8180
"""Use this class to trace function calls.
8281
83-
:param output: The path to the output trace file
8482
:param functions: List of functions to trace. If None, trace all functions
8583
:param disable: Disable the tracer if True
8684
:param max_function_count: Maximum number of times to trace one function
@@ -127,6 +125,7 @@ def __init__(
127125
self.sanitized_filename = self.sanitize_to_filename(command)
128126
# Place trace file next to replay tests in the tests directory
129127
from codeflash.verification.verification_utils import get_test_file_path
128+
130129
function_path = "_".join(functions) if functions else self.sanitized_filename
131130
test_file_path = get_test_file_path(
132131
test_dir=Path(config["tests_root"]), function_name=function_path, test_type="replay"
@@ -279,9 +278,7 @@ def __exit__(
279278
from codeflash.verification.verification_utils import get_test_file_path
280279

281280
replay_test = create_trace_replay_test(
282-
trace_file=self.output_file,
283-
functions=self.function_modules,
284-
max_run_count=self.max_function_count,
281+
trace_file=self.output_file, functions=self.function_modules, max_run_count=self.max_function_count
285282
)
286283
function_path = "_".join(self.functions) if self.functions else self.sanitized_filename
287284
test_file_path = get_test_file_path(

0 commit comments

Comments
 (0)