Skip to content

Commit f513763

Browse files
committed
pass mypy & ruff
1 parent 72f3022 commit f513763

File tree

8 files changed

+105
-88
lines changed

8 files changed

+105
-88
lines changed

codeflash/api/aiservice.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def make_ai_service_request(
7373
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
7474
return response
7575

76-
def optimize_python_code(
76+
def optimize_python_code( # noqa: D417
7777
self,
7878
source_code: str,
7979
dependency_code: str,
@@ -139,7 +139,7 @@ def optimize_python_code(
139139
console.rule()
140140
return []
141141

142-
def optimize_python_code_line_profiler(
142+
def optimize_python_code_line_profiler( # noqa: D417
143143
self,
144144
source_code: str,
145145
dependency_code: str,
@@ -208,7 +208,7 @@ def optimize_python_code_line_profiler(
208208
console.rule()
209209
return []
210210

211-
def log_results(
211+
def log_results( # noqa: D417
212212
self,
213213
function_trace_id: str,
214214
speedup_ratio: dict[str, float | None] | None,
@@ -240,7 +240,7 @@ def log_results(
240240
except requests.exceptions.RequestException as e:
241241
logger.exception(f"Error logging features: {e}")
242242

243-
def generate_regression_tests(
243+
def generate_regression_tests( # noqa: D417
244244
self,
245245
source_code_being_tested: str,
246246
function_to_optimize: FunctionToOptimize,
@@ -307,7 +307,7 @@ def generate_regression_tests(
307307
error = response.json()["error"]
308308
logger.error(f"Error generating tests: {response.status_code} - {error}")
309309
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": error})
310-
return None
310+
return None # noqa: TRY300
311311
except Exception:
312312
logger.error(f"Error generating tests: {response.status_code} - {response.text}")
313313
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})

codeflash/cli_cmds/cmd_init.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from argparse import Namespace
3535

3636
CODEFLASH_LOGO: str = (
37-
f"{LF}"
37+
f"{LF}" # noqa: ISC003
3838
r" _ ___ _ _ " + f"{LF}"
3939
r" | | / __)| | | | " + f"{LF}"
4040
r" ____ ___ _ | | ____ | |__ | | ____ ___ | | _ " + f"{LF}"
@@ -126,7 +126,8 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
126126

127127

128128
def should_modify_pyproject_toml() -> bool:
129-
"""Check if the current directory contains a valid pyproject.toml file with codeflash config
129+
"""Check if the current directory contains a valid pyproject.toml file with codeflash config.
130+
130131
If it does, ask the user if they want to re-configure it.
131132
"""
132133
from rich.prompt import Confirm
@@ -144,12 +145,11 @@ def should_modify_pyproject_toml() -> bool:
144145
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
145146
return True
146147

147-
create_toml = Confirm.ask(
148+
return Confirm.ask(
148149
"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?",
149150
default=False,
150151
show_default=True,
151152
)
152-
return create_toml
153153

154154

155155
def collect_setup_info() -> SetupInfo:
@@ -469,7 +469,7 @@ def check_for_toml_or_setup_file() -> str | None:
469469
return cast("str", project_name)
470470

471471

472-
def install_github_actions(override_formatter_check: bool = False) -> None:
472+
def install_github_actions(override_formatter_check: bool = False) -> None: # noqa: FBT001, FBT002
473473
try:
474474
config, config_file_path = parse_config_file(override_formatter_check=override_formatter_check)
475475

@@ -564,7 +564,7 @@ def install_github_actions(override_formatter_check: bool = False) -> None:
564564
apologize_and_exit()
565565

566566

567-
def determine_dependency_manager(pyproject_data: dict[str, Any]) -> DependencyManager:
567+
def determine_dependency_manager(pyproject_data: dict[str, Any]) -> DependencyManager: # noqa: PLR0911
568568
"""Determine which dependency manager is being used based on pyproject.toml contents."""
569569
if (Path.cwd() / "poetry.lock").exists():
570570
return DependencyManager.POETRY
@@ -642,7 +642,10 @@ def get_github_action_working_directory(toml_path: Path, git_root: Path) -> str:
642642

643643

644644
def customize_codeflash_yaml_content(
645-
optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path, benchmark_mode: bool = False
645+
optimize_yml_content: str,
646+
config: tuple[dict[str, Any], Path],
647+
git_root: Path,
648+
benchmark_mode: bool = False, # noqa: FBT001, FBT002
646649
) -> str:
647650
module_path = str(Path(config["module_root"]).relative_to(git_root) / "**")
648651
optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path)
@@ -878,7 +881,7 @@ def test_sort(self):
878881
input = list(reversed(range(100)))
879882
output = sorter(input)
880883
self.assertEqual(output, list(range(100)))
881-
"""
884+
""" # noqa: PTH119
882885
elif args.test_framework == "pytest":
883886
bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter
884887
@@ -959,10 +962,8 @@ def ask_for_telemetry() -> bool:
959962
"""Prompt the user to enable or disable telemetry."""
960963
from rich.prompt import Confirm
961964

962-
enable_telemetry = Confirm.ask(
965+
return Confirm.ask(
963966
"⚡️ Would you like to enable telemetry to help us improve the Codeflash experience?",
964967
default=True,
965968
show_default=True,
966969
)
967-
968-
return enable_telemetry

codeflash/code_utils/code_extractor.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
# ruff: noqa: ARG002
12
from __future__ import annotations
23

34
import ast
4-
from pathlib import Path
5-
from typing import TYPE_CHECKING, Dict, Optional, Set
5+
from typing import TYPE_CHECKING, Optional
66

77
import libcst as cst
88
import libcst.matchers as m
@@ -11,23 +11,24 @@
1111
from libcst.helpers import calculate_module_and_package
1212

1313
from codeflash.cli_cmds.console import logger
14-
from codeflash.models.models import FunctionParent, FunctionSource
14+
from codeflash.models.models import FunctionParent
1515

1616
if TYPE_CHECKING:
17+
from pathlib import Path
18+
1719
from libcst.helpers import ModuleNameAndPackage
1820

1921
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
20-
21-
from typing import List
22+
from codeflash.models.models import FunctionSource
2223

2324

2425
class GlobalAssignmentCollector(cst.CSTVisitor):
2526
"""Collects all global assignment statements."""
2627

27-
def __init__(self):
28+
def __init__(self) -> None:
2829
super().__init__()
29-
self.assignments: Dict[str, cst.Assign] = {}
30-
self.assignment_order: List[str] = []
30+
self.assignments: dict[str, cst.Assign] = {}
31+
self.assignment_order: list[str] = []
3132
# Track scope depth to identify global assignments
3233
self.scope_depth = 0
3334
self.if_else_depth = 0
@@ -72,11 +73,11 @@ def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
7273
class GlobalAssignmentTransformer(cst.CSTTransformer):
7374
"""Transforms global assignments in the original file with those from the new file."""
7475

75-
def __init__(self, new_assignments: Dict[str, cst.Assign], new_assignment_order: List[str]):
76+
def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order: list[str]) -> None:
7677
super().__init__()
7778
self.new_assignments = new_assignments
7879
self.new_assignment_order = new_assignment_order
79-
self.processed_assignments: Set[str] = set()
80+
self.processed_assignments: set[str] = set()
8081
self.scope_depth = 0
8182
self.if_else_depth = 0
8283

@@ -124,10 +125,11 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
124125
new_statements = list(updated_node.body)
125126

126127
# Find assignments to append
127-
assignments_to_append = []
128-
for name in self.new_assignment_order:
129-
if name not in self.processed_assignments and name in self.new_assignments:
130-
assignments_to_append.append(self.new_assignments[name])
128+
assignments_to_append = [
129+
self.new_assignments[name]
130+
for name in self.new_assignment_order
131+
if name not in self.processed_assignments and name in self.new_assignments
132+
]
131133

132134
if assignments_to_append:
133135
# Add a blank line before appending new assignments if needed
@@ -136,16 +138,20 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
136138
new_statements.pop() # Remove the Pass statement but keep the empty line
137139

138140
# Add the new assignments
139-
for assignment in assignments_to_append:
140-
new_statements.append(cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]))
141+
new_statements.extend(
142+
[
143+
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
144+
for assignment in assignments_to_append
145+
]
146+
)
141147

142148
return updated_node.with_changes(body=new_statements)
143149

144150

145151
class GlobalStatementCollector(cst.CSTVisitor):
146152
"""Visitor that collects all global statements (excluding imports and functions/classes)."""
147153

148-
def __init__(self):
154+
def __init__(self) -> None:
149155
super().__init__()
150156
self.global_statements = []
151157
self.in_function_or_class = False
@@ -178,7 +184,7 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
178184
class LastImportFinder(cst.CSTVisitor):
179185
"""Finds the position of the last import statement in the module."""
180186

181-
def __init__(self):
187+
def __init__(self) -> None:
182188
super().__init__()
183189
self.last_import_line = 0
184190
self.current_line = 0
@@ -193,7 +199,7 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
193199
class ImportInserter(cst.CSTTransformer):
194200
"""Transformer that inserts global statements after the last import."""
195201

196-
def __init__(self, global_statements: List[cst.SimpleStatementLine], last_import_line: int):
202+
def __init__(self, global_statements: list[cst.SimpleStatementLine], last_import_line: int) -> None:
197203
super().__init__()
198204
self.global_statements = global_statements
199205
self.last_import_line = last_import_line
@@ -208,7 +214,7 @@ def leave_SimpleStatementLine(
208214
# If we're right after the last import and haven't inserted yet
209215
if self.current_line == self.last_import_line and not self.inserted:
210216
self.inserted = True
211-
return cst.Module(body=[updated_node] + self.global_statements)
217+
return cst.Module(body=[updated_node, *self.global_statements])
212218

213219
return cst.Module(body=[updated_node])
214220

@@ -222,7 +228,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
222228
return updated_node
223229

224230

225-
def extract_global_statements(source_code: str) -> List[cst.SimpleStatementLine]:
231+
def extract_global_statements(source_code: str) -> list[cst.SimpleStatementLine]:
226232
"""Extract global statements from source code."""
227233
module = cst.parse_module(source_code)
228234
collector = GlobalStatementCollector()
@@ -285,8 +291,7 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
285291
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
286292
transformed_module = original_module.visit(transformer)
287293

288-
dst_module_code = transformed_module.code
289-
return dst_module_code
294+
return transformed_module.code
290295

291296

292297
def add_needed_imports_from_module(
@@ -357,9 +362,10 @@ def add_needed_imports_from_module(
357362

358363

359364
def get_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | None, set[tuple[str, str]]]:
360-
"""Return the code for a function or methods in a Python module. functions_to_optimize is either a singleton
361-
FunctionToOptimize instance, which represents either a function at the module level or a method of a class at the
362-
module level, or it represents a list of methods of the same class.
365+
"""Return the code for a function or methods in a Python module.
366+
367+
functions_to_optimize is either a singleton FunctionToOptimize instance, which represents either a function at the
368+
module level or a method of a class at the module level, or it represents a list of methods of the same class.
363369
"""
364370
if (
365371
not functions_to_optimize
@@ -427,7 +433,7 @@ def find_target(node_list: list[ast.stmt], name_parts: tuple[str, str] | tuple[s
427433

428434
return find_target(target.body, name_parts[1:])
429435

430-
with open(file_path, encoding="utf8") as file:
436+
with file_path.open(encoding="utf8") as file:
431437
source_code: str = file.read()
432438
try:
433439
module_node: ast.Module = ast.parse(source_code)

codeflash/code_utils/code_replacer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> bool:
8282

8383
return True
8484

85-
def leave_ClassDef(self, node: cst.ClassDef) -> None:
85+
def leave_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
8686
if self.current_class:
8787
self.current_class = None
8888

@@ -104,7 +104,7 @@ def __init__(
104104
)
105105
self.current_class = None
106106

107-
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
107+
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # noqa: ARG002
108108
return False
109109

110110
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
@@ -133,7 +133,7 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
133133
)
134134
return updated_node
135135

136-
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
136+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
137137
node = updated_node
138138
max_function_index = None
139139
class_index = None

codeflash/code_utils/shell_utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
from __future__ import annotations
2+
13
import os
24
import re
35
from pathlib import Path
4-
from typing import Optional
6+
from typing import TYPE_CHECKING, Optional
57

68
from codeflash.code_utils.compat import LF
7-
from codeflash.either import Failure, Result, Success
9+
from codeflash.either import Failure, Success
10+
11+
if TYPE_CHECKING:
12+
from codeflash.either import Result
813

914
if os.name == "nt": # Windows
1015
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
@@ -17,7 +22,7 @@
1722
def read_api_key_from_shell_config() -> Optional[str]:
1823
try:
1924
shell_rc_path = get_shell_rc_path()
20-
with open(shell_rc_path, encoding="utf8") as shell_rc:
25+
with open(shell_rc_path, encoding="utf8") as shell_rc: # noqa: PTH123
2126
shell_contents = shell_rc.read()
2227
matches = SHELL_RC_EXPORT_PATTERN.findall(shell_contents)
2328
return matches[-1] if matches else None
@@ -40,15 +45,14 @@ def get_api_key_export_line(api_key: str) -> str:
4045
return f"{SHELL_RC_EXPORT_PREFIX}{api_key}"
4146

4247

43-
def save_api_key_to_rc(api_key) -> Result[str, str]:
48+
def save_api_key_to_rc(api_key: str) -> Result[str, str]:
4449
shell_rc_path = get_shell_rc_path()
4550
api_key_line = get_api_key_export_line(api_key)
4651
try:
47-
with open(shell_rc_path, "r+", encoding="utf8") as shell_file:
52+
with open(shell_rc_path, "r+", encoding="utf8") as shell_file: # noqa: PTH123
4853
shell_contents = shell_file.read()
49-
if os.name == "nt": # on Windows, we're writing a batch file
50-
if not shell_contents:
51-
shell_contents = "@echo off"
54+
if os.name == "nt" and not shell_contents: # on windows we're writing to a batch file
55+
shell_contents = "@echo off"
5256
existing_api_key = read_api_key_from_shell_config()
5357

5458
if existing_api_key:

0 commit comments

Comments
 (0)