Skip to content

Commit bc71618

Browse files
committed
cleanup fixes
1 parent 62e10b1 commit bc71618

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+495
-375
lines changed

codeflash/api/aiservice.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def make_ai_service_request(
7373
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
7474
return response
7575

76-
def optimize_python_code(
76+
def optimize_python_code( # noqa: D417
7777
self,
7878
source_code: str,
7979
dependency_code: str,
@@ -139,7 +139,7 @@ def optimize_python_code(
139139
console.rule()
140140
return []
141141

142-
def optimize_python_code_line_profiler(
142+
def optimize_python_code_line_profiler( # noqa: D417
143143
self,
144144
source_code: str,
145145
dependency_code: str,
@@ -208,7 +208,7 @@ def optimize_python_code_line_profiler(
208208
console.rule()
209209
return []
210210

211-
def log_results(
211+
def log_results( # noqa: D417
212212
self,
213213
function_trace_id: str,
214214
speedup_ratio: dict[str, float | None] | None,
@@ -240,7 +240,7 @@ def log_results(
240240
except requests.exceptions.RequestException as e:
241241
logger.exception(f"Error logging features: {e}")
242242

243-
def generate_regression_tests(
243+
def generate_regression_tests( # noqa: D417
244244
self,
245245
source_code_being_tested: str,
246246
function_to_optimize: FunctionToOptimize,
@@ -307,7 +307,7 @@ def generate_regression_tests(
307307
error = response.json()["error"]
308308
logger.error(f"Error generating tests: {response.status_code} - {error}")
309309
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": error})
310-
return None
310+
return None # noqa: TRY300
311311
except Exception:
312312
logger.error(f"Error generating tests: {response.status_code} - {response.text}")
313313
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})

codeflash/benchmarking/codeflash_trace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import sqlite3
55
import threading
66
import time
7-
from typing import Callable
7+
from typing import Any, Callable
88

99
from codeflash.picklepatch.pickle_patcher import PicklePatcher
1010

@@ -103,7 +103,7 @@ def __call__(self, func: Callable) -> Callable:
103103
func_id = (func.__module__, func.__name__)
104104

105105
@functools.wraps(func)
106-
def wrapper(*args, **kwargs):
106+
def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
107107
# Initialize thread-local active functions set if it doesn't exist
108108
if not hasattr(self._thread_local, "active_functions"):
109109
self._thread_local.active_functions = set()

codeflash/benchmarking/instrument_codeflash_trace.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1-
from pathlib import Path
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Optional, Union
24

35
import isort
46
import libcst as cst
57

6-
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
8+
if TYPE_CHECKING:
9+
from pathlib import Path
10+
11+
from libcst import BaseStatement, ClassDef, FlattenSentinel, FunctionDef, RemovalSentinel
12+
13+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
714

815

916
class AddDecoratorTransformer(cst.CSTTransformer):
@@ -15,33 +22,35 @@ def __init__(self, target_functions: set[tuple[str, str]]) -> None:
1522
self.function_name = ""
1623
self.decorator = cst.Decorator(decorator=cst.Name(value="codeflash_trace"))
1724

18-
def leave_ClassDef(self, original_node, updated_node):
25+
def leave_ClassDef(
26+
self, original_node: ClassDef, updated_node: ClassDef
27+
) -> Union[BaseStatement, FlattenSentinel[BaseStatement], RemovalSentinel]:
1928
if self.class_name == original_node.name.value:
2029
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
2130
return updated_node
2231

23-
def visit_ClassDef(self, node):
32+
def visit_ClassDef(self, node: ClassDef) -> Optional[bool]:
2433
if self.class_name: # Don't go into nested class
2534
return False
26-
self.class_name = node.name.value
35+
self.class_name = node.name.value # noqa: RET503
2736

28-
def visit_FunctionDef(self, node):
37+
def visit_FunctionDef(self, node: FunctionDef) -> Optional[bool]:
2938
if self.function_name: # Don't go into nested function
3039
return False
31-
self.function_name = node.name.value
40+
self.function_name = node.name.value # noqa: RET503
3241

33-
def leave_FunctionDef(self, original_node, updated_node):
42+
def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> FunctionDef:
3443
if self.function_name == original_node.name.value:
3544
self.function_name = ""
3645
if (self.class_name, original_node.name.value) in self.target_functions:
3746
# Add the new decorator after any existing decorators, so it gets executed first
38-
updated_decorators = list(updated_node.decorators) + [self.decorator]
47+
updated_decorators = [*list(updated_node.decorators), self.decorator]
3948
self.added_codeflash_trace = True
4049
return updated_node.with_changes(decorators=updated_decorators)
4150

4251
return updated_node
4352

44-
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
53+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
4554
# Create import statement for codeflash_trace
4655
if not self.added_codeflash_trace:
4756
return updated_node
@@ -68,7 +77,7 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
6877
6978
Args:
7079
code: The source code as a string
71-
function_to_optimize: The FunctionToOptimize instance containing function details
80+
functions_to_optimize: List of FunctionToOptimize instances containing function details
7281
7382
Returns:
7483
The modified source code as a string

codeflash/benchmarking/plugin/plugin.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
186186

187187
# Pytest hooks
188188
@pytest.hookimpl
189-
def pytest_sessionfinish(self, session, exitstatus):
189+
def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, ARG002
190190
"""Execute after whole test run is completed."""
191191
# Write any remaining benchmark timings to the database
192192
codeflash_trace.close()
@@ -196,24 +196,24 @@ def pytest_sessionfinish(self, session, exitstatus):
196196
self.close()
197197

198198
@staticmethod
199-
def pytest_addoption(parser):
199+
def pytest_addoption(parser: pytest.Parser) -> None:
200200
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
201201

202202
@staticmethod
203-
def pytest_plugin_registered(plugin, manager):
203+
def pytest_plugin_registered(plugin, manager) -> None: # noqa: ANN001
204204
# Not necessary since run with -p no:benchmark, but just in case
205205
if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
206206
manager.unregister(plugin)
207207

208208
@staticmethod
209-
def pytest_configure(config):
209+
def pytest_configure(config: pytest.Config) -> None:
210210
"""Register the benchmark marker."""
211211
config.addinivalue_line(
212212
"markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing"
213213
)
214214

215215
@staticmethod
216-
def pytest_collection_modifyitems(config, items):
216+
def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None:
217217
# Skip tests that don't have the benchmark fixture
218218
if not config.getoption("--codeflash-trace"):
219219
return
@@ -235,30 +235,30 @@ def pytest_collection_modifyitems(config, items):
235235
item.add_marker(skip_no_benchmark)
236236

237237
# Benchmark fixture
238-
class Benchmark:
239-
def __init__(self, request):
238+
class Benchmark: # noqa: D106
239+
def __init__(self, request: pytest.FixtureRequest) -> None:
240240
self.request = request
241241

242-
def __call__(self, func, *args, **kwargs):
242+
def __call__(self, func, *args, **kwargs): # type: ignore # noqa: ANN001, ANN002, ANN003, ANN204, PGH003
243243
"""Handle both direct function calls and decorator usage."""
244244
if args or kwargs:
245245
# Used as benchmark(func, *args, **kwargs)
246246
return self._run_benchmark(func, *args, **kwargs)
247247

248248
# Used as @benchmark decorator
249-
def wrapped_func(*args, **kwargs):
249+
def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003, ANN202
250250
return func(*args, **kwargs)
251251

252-
result = self._run_benchmark(func)
252+
self._run_benchmark(func)
253253
return wrapped_func
254254

255-
def _run_benchmark(self, func, *args, **kwargs):
255+
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN202
256256
"""Actual benchmark implementation."""
257257
benchmark_module_path = module_name_from_file_path(
258258
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
259259
)
260260
benchmark_function_name = self.request.node.name
261-
line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack
261+
line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack # noqa: SLF001
262262
# Set env vars
263263
os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name
264264
os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path
@@ -284,7 +284,7 @@ def _run_benchmark(self, func, *args, **kwargs):
284284

285285
@staticmethod
286286
@pytest.fixture
287-
def benchmark(request):
287+
def benchmark(request: pytest.FixtureRequest) -> object:
288288
if not request.config.getoption("--codeflash-trace"):
289289
return None
290290

codeflash/benchmarking/replay_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ def get_function_alias(module: str, function_name: str) -> str:
4747

4848

4949
def create_trace_replay_test_code(
50-
trace_file: str, functions_data: list[dict[str, Any]], test_framework: str = "pytest", max_run_count=256
50+
trace_file: str,
51+
functions_data: list[dict[str, Any]],
52+
test_framework: str = "pytest",
53+
max_run_count=256, # noqa: ANN001
5154
) -> str:
5255
"""Create a replay test for functions based on trace data.
5356

codeflash/cli_cmds/cli_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def inquirer_wrapper(func: Callable[..., str | bool], *args: str | bool, **kwarg
4343
return func(*new_args, **new_kwargs)
4444

4545

46-
def split_string_to_cli_width(string: str, is_confirm: bool = False) -> list[str]:
46+
def split_string_to_cli_width(string: str, is_confirm: bool = False) -> list[str]: # noqa: FBT001, FBT002
4747
cli_width, _ = shutil.get_terminal_size()
4848
# split string to lines that accommodate "[?] " prefix
4949
cli_width -= len("[?] ")

codeflash/cli_cmds/cmd_init.py

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from argparse import Namespace
3535

3636
CODEFLASH_LOGO: str = (
37-
f"{LF}"
37+
f"{LF}" # noqa: ISC003
3838
r" _ ___ _ _ " + f"{LF}"
3939
r" | | / __)| | | | " + f"{LF}"
4040
r" ____ ___ _ | | ____ | |__ | | ____ ___ | | _ " + f"{LF}"
@@ -126,7 +126,8 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
126126

127127

128128
def should_modify_pyproject_toml() -> bool:
129-
"""Check if the current directory contains a valid pyproject.toml file with codeflash config
129+
"""Check if the current directory contains a valid pyproject.toml file with codeflash config.
130+
130131
If it does, ask the user if they want to re-configure it.
131132
"""
132133
from rich.prompt import Confirm
@@ -144,12 +145,11 @@ def should_modify_pyproject_toml() -> bool:
144145
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
145146
return True
146147

147-
create_toml = Confirm.ask(
148+
return Confirm.ask(
148149
"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?",
149150
default=False,
150151
show_default=True,
151152
)
152-
return create_toml
153153

154154

155155
def collect_setup_info() -> SetupInfo:
@@ -469,7 +469,7 @@ def check_for_toml_or_setup_file() -> str | None:
469469
return cast("str", project_name)
470470

471471

472-
def install_github_actions(override_formatter_check: bool = False) -> None:
472+
def install_github_actions(override_formatter_check: bool = False) -> None: # noqa: FBT001, FBT002
473473
try:
474474
config, config_file_path = parse_config_file(override_formatter_check=override_formatter_check)
475475

@@ -566,28 +566,22 @@ def install_github_actions(override_formatter_check: bool = False) -> None:
566566

567567
def determine_dependency_manager(pyproject_data: dict[str, Any]) -> DependencyManager:
568568
"""Determine which dependency manager is being used based on pyproject.toml contents."""
569+
result = DependencyManager.UNKNOWN
569570
if (Path.cwd() / "poetry.lock").exists():
570-
return DependencyManager.POETRY
571-
if (Path.cwd() / "uv.lock").exists():
572-
return DependencyManager.UV
573-
if "tool" not in pyproject_data:
574-
return DependencyManager.PIP
575-
576-
tool_section = pyproject_data["tool"]
577-
578-
# Check for poetry
579-
if "poetry" in tool_section:
580-
return DependencyManager.POETRY
581-
582-
# Check for uv
583-
if any(key.startswith("uv") for key in tool_section):
584-
return DependencyManager.UV
585-
586-
# Look for pip-specific markers
587-
if "pip" in tool_section or "setuptools" in tool_section:
588-
return DependencyManager.PIP
589-
590-
return DependencyManager.UNKNOWN
571+
result = DependencyManager.POETRY
572+
elif (Path.cwd() / "uv.lock").exists():
573+
result = DependencyManager.UV
574+
elif "tool" not in pyproject_data:
575+
result = DependencyManager.PIP
576+
else:
577+
tool_section = pyproject_data["tool"]
578+
if "poetry" in tool_section:
579+
result = DependencyManager.POETRY
580+
elif any(key.startswith("uv") for key in tool_section):
581+
result = DependencyManager.UV
582+
elif "pip" in tool_section or "setuptools" in tool_section:
583+
result = DependencyManager.PIP
584+
return result
591585

592586

593587
def get_codeflash_github_action_command(dep_manager: DependencyManager) -> str:
@@ -642,7 +636,10 @@ def get_github_action_working_directory(toml_path: Path, git_root: Path) -> str:
642636

643637

644638
def customize_codeflash_yaml_content(
645-
optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path, benchmark_mode: bool = False
639+
optimize_yml_content: str,
640+
config: tuple[dict[str, Any], Path],
641+
git_root: Path,
642+
benchmark_mode: bool = False, # noqa: FBT001, FBT002
646643
) -> str:
647644
module_path = str(Path(config["module_root"]).relative_to(git_root) / "**")
648645
optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path)
@@ -878,7 +875,7 @@ def test_sort(self):
878875
input = list(reversed(range(100)))
879876
output = sorter(input)
880877
self.assertEqual(output, list(range(100)))
881-
"""
878+
""" # noqa: PTH119
882879
elif args.test_framework == "pytest":
883880
bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter
884881
@@ -959,10 +956,8 @@ def ask_for_telemetry() -> bool:
959956
"""Prompt the user to enable or disable telemetry."""
960957
from rich.prompt import Confirm
961958

962-
enable_telemetry = Confirm.ask(
959+
return Confirm.ask(
963960
"⚡️ Would you like to enable telemetry to help us improve the Codeflash experience?",
964961
default=True,
965962
show_default=True,
966963
)
967-
968-
return enable_telemetry

codeflash/cli_cmds/logging_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def set_level(level: int, *, echo_setting: bool = True) -> None:
2727
],
2828
force=True,
2929
)
30-
logging.info("Verbose DEBUG logging enabled")
30+
logging.info("Verbose DEBUG logging enabled") # noqa: LOG015
3131
else:
32-
logging.info("Logging level set to INFO")
32+
logging.info("Logging level set to INFO") # noqa: LOG015
3333
console.rule()

0 commit comments

Comments
 (0)