Skip to content

Commit 291ac2c

Browse files
committed
fix test
ruff reformat and fix linting
1 parent d0de40b commit 291ac2c

16 files changed

+71
-54
lines changed

codeflash/benchmarking/codeflash_trace.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def setup(self, trace_path: str) -> None:
2525
"""Set up the database connection for direct writing.
2626
2727
Args:
28+
----
2829
trace_path: Path to the trace database file
2930
3031
"""
@@ -52,6 +53,7 @@ def write_function_timings(self) -> None:
5253
"""Write function call data directly to the database.
5354
5455
Args:
56+
----
5557
data: List of function call data tuples to write
5658
5759
"""
@@ -94,9 +96,11 @@ def __call__(self, func: Callable) -> Callable:
9496
"""Use as a decorator to trace function execution.
9597
9698
Args:
99+
----
97100
func: The function to be decorated
98101
99102
Returns:
103+
-------
100104
The wrapped function
101105
102106
"""

codeflash/benchmarking/instrument_codeflash_trace.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,12 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
7676
"""Add codeflash_trace to a function.
7777
7878
Args:
79+
----
7980
code: The source code as a string
8081
functions_to_optimize: List of FunctionToOptimize instances containing function details
8182
8283
Returns:
84+
-------
8385
The modified source code as a string
8486
8587
"""

codeflash/benchmarking/plugin/plugin.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,11 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
7474
"""Process the trace file and extract timing data for all functions.
7575
7676
Args:
77+
----
7778
trace_path: Path to the trace file
7879
7980
Returns:
81+
-------
8082
A nested dictionary where:
8183
- Outer keys are module_name.qualified_name (module.class.function)
8284
- Inner keys are of type BenchmarkKey
@@ -132,9 +134,11 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
132134
"""Extract total benchmark timings from trace files.
133135
134136
Args:
137+
----
135138
trace_path: Path to the trace file
136139
137140
Returns:
141+
-------
138142
A dictionary mapping where:
139143
- Keys are of type BenchmarkKey
140144
- Values are total benchmark timing in milliseconds (with overhead subtracted)

codeflash/benchmarking/replay_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@ def create_trace_replay_test_code(
5555
"""Create a replay test for functions based on trace data.
5656
5757
Args:
58+
----
5859
trace_file: Path to the SQLite database file
5960
functions_data: List of dictionaries with function info extracted from DB
6061
test_framework: 'pytest' or 'unittest'
6162
max_run_count: Maximum number of runs to include in the test
6263
6364
Returns:
65+
-------
6466
A string containing the test code
6567
6668
"""
@@ -218,12 +220,14 @@ def generate_replay_test(
218220
"""Generate multiple replay tests from the traced function calls, grouped by benchmark.
219221
220222
Args:
223+
----
221224
trace_file_path: Path to the SQLite database file
222225
output_dir: Directory to write the generated tests (if None, only returns the code)
223226
test_framework: 'pytest' or 'unittest'
224227
max_run_count: Maximum number of runs to include per function
225228
226229
Returns:
230+
-------
227231
Dictionary mapping benchmark names to generated test code
228232
229233
"""

codeflash/benchmarking/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,13 @@ def process_benchmark_data(
8383
"""Process benchmark data and generate detailed benchmark information.
8484
8585
Args:
86+
----
8687
replay_performance_gain: The performance gain from replay
8788
fto_benchmark_timings: Function to optimize benchmark timings
8889
total_benchmark_timings: Total benchmark timings
8990
9091
Returns:
92+
-------
9193
ProcessedBenchmarkInfo containing processed benchmark details
9294
9395
"""

codeflash/cli_cmds/cmd_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from argparse import Namespace
3535

3636
CODEFLASH_LOGO: str = (
37-
f"{LF}" # noqa: ISC003
37+
f"{LF}"
3838
r" _ ___ _ _ " + f"{LF}"
3939
r" | | / __)| | | | " + f"{LF}"
4040
r" ____ ___ _ | | ____ | |__ | | ____ ___ | | _ " + f"{LF}"

codeflash/cli_cmds/logging_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def set_level(level: int, *, echo_setting: bool = True) -> None:
2727
],
2828
force=True,
2929
)
30-
logging.info("Verbose DEBUG logging enabled") # noqa: LOG015
30+
logging.info("Verbose DEBUG logging enabled")
3131
else:
32-
logging.info("Logging level set to INFO") # noqa: LOG015
32+
logging.info("Logging level set to INFO")
3333
console.rule()

codeflash/code_utils/checkpoint.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def add_function_to_checkpoint(
4747
"""Add a function to the checkpoint after it has been processed.
4848
4949
Args:
50+
----
5051
function_fully_qualified_name: The fully qualified name of the function
5152
status: Status of optimization (e.g., "optimized", "failed", "skipped")
5253
additional_info: Any additional information to store about the function
@@ -104,7 +105,8 @@ def cleanup(self) -> None:
104105
def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dict[str, dict[str, str]]:
105106
"""Get information about all processed functions, regardless of status.
106107
107-
Returns:
108+
Returns
109+
-------
108110
Dictionary mapping function names to their processing information
109111
110112
"""

codeflash/code_utils/line_profile_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(self, qualified_name: str, decorator_name: str) -> None:
2424
"""Initialize the transformer.
2525
2626
Args:
27+
----
2728
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
2829
decorator_name: The name of the decorator to add.
2930
@@ -144,11 +145,13 @@ def add_decorator_to_qualified_function(module: cst.Module, qualified_name: str,
144145
"""Add a decorator to a function with the exact qualified name in the source code.
145146
146147
Args:
148+
----
147149
module: The Python source code as a CST module.
148150
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
149151
decorator_name: The name of the decorator to add.
150152
151153
Returns:
154+
-------
152155
The modified CST module.
153156
154157
"""

codeflash/context/code_context_extractor.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,17 @@
33
import os
44
from collections import defaultdict
55
from itertools import chain
6-
from pathlib import Path # noqa: TC003
6+
from pathlib import Path
77
from typing import TYPE_CHECKING
88

99
import libcst as cst
10-
from libcst import CSTNode # noqa: TC002
10+
from libcst import CSTNode
1111

1212
from codeflash.cli_cmds.console import logger
1313
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
1414
from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages
1515
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
16-
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
16+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1717
from codeflash.models.models import (
1818
CodeContextType,
1919
CodeOptimizationContext,
@@ -150,13 +150,15 @@ def extract_code_string_context_from_files(
150150
imports, and combines them.
151151
152152
Args:
153+
----
153154
helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers
154155
helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions
155156
project_root_path: Root path of the project
156157
remove_docstrings: Whether to remove docstrings from the extracted code
157158
code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN)
158159
159160
Returns:
161+
-------
160162
CodeString containing the extracted code context with necessary imports
161163
162164
""" # noqa: D205
@@ -257,13 +259,15 @@ def extract_code_markdown_context_from_files(
257259
imports, and combines them into a structured markdown format.
258260
259261
Args:
262+
----
260263
helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers
261264
helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions
262265
project_root_path: Root path of the project
263266
remove_docstrings: Whether to remove docstrings from the extracted code
264267
code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN)
265268
266269
Returns:
270+
-------
267271
CodeStringsMarkdown containing the extracted code context with necessary imports,
268272
formatted for inclusion in markdown
269273
@@ -502,7 +506,8 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
502506
) -> tuple[cst.CSTNode | None, bool]:
503507
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
504508
505-
Returns:
509+
Returns
510+
-------
506511
(filtered_node, found_target):
507512
filtered_node: The modified CST node or None if it should be removed.
508513
found_target: True if a target function was found in this node's subtree.
@@ -586,7 +591,8 @@ def prune_cst_for_read_only_code( # noqa: PLR0911
586591
) -> tuple[cst.CSTNode | None, bool]:
587592
"""Recursively filter the node for read-only context.
588593
589-
Returns:
594+
Returns
595+
-------
590596
(filtered_node, found_target):
591597
filtered_node: The modified CST node or None if it should be removed.
592598
found_target: True if a target function was found in this node's subtree.
@@ -690,7 +696,8 @@ def prune_cst_for_testgen_code( # noqa: PLR0911
690696
) -> tuple[cst.CSTNode | None, bool]:
691697
"""Recursively filter the node for testgen context.
692698
693-
Returns:
699+
Returns
700+
-------
694701
(filtered_node, found_target):
695702
filtered_node: The modified CST node or None if it should be removed.
696703
found_target: True if a target function was found in this node's subtree.

0 commit comments

Comments
 (0)