Skip to content

Commit 0ba52ea

Browse files
committed
go
1 parent 3a3e4db commit 0ba52ea

File tree

3 files changed

+28
-35
lines changed

3 files changed

+28
-35
lines changed

codeflash/api/aiservice.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,9 @@ def generate_regression_tests( # noqa: D417
270270
- Dict[str, str] | None: The generated regression tests and instrumented tests, or None if an error occurred.
271271
272272
"""
273-
assert test_framework in [
274-
"pytest",
275-
"unittest",
276-
], f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
273+
assert test_framework in ["pytest", "unittest"], (
274+
f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
275+
)
277276
payload = {
278277
"source_code_being_tested": source_code_being_tested,
279278
"function_to_optimize": function_to_optimize,

codeflash/code_utils/code_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212

1313

1414
def encoded_tokens_len(s: str) -> int:
15-
'''Function for returning the approximate length of the encoded tokens
16-
It's an approximation of BPE encoding (https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)'''
17-
return int(len(s)*0.25)
15+
"""Return the approximate length of the encoded tokens.
16+
17+
It's an approximation of BPE encoding (https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf).
18+
"""
19+
return int(len(s) * 0.25)
20+
1821

1922
def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
2023
if not full_qualified_name:

codeflash/context/code_context_extractor.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,18 @@
33
import os
44
from collections import defaultdict
55
from itertools import chain
6-
from typing import TYPE_CHECKING, Optional
6+
from pathlib import Path # noqa: TC003
77

88
import jedi
99
import libcst as cst
10-
from jedi.api.classes import Name
11-
from libcst import CSTNode
10+
from jedi.api.classes import Name # noqa: TC002
11+
from libcst import CSTNode # noqa: TC002
1212

1313
from codeflash.cli_cmds.console import logger
1414
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
1515
from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages
1616
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
17+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
1718
from codeflash.models.models import (
1819
CodeContextType,
1920
CodeOptimizationContext,
@@ -23,14 +24,6 @@
2324
)
2425
from codeflash.optimization.function_context import belongs_to_function_qualified
2526

26-
if TYPE_CHECKING:
27-
from pathlib import Path
28-
29-
from jedi.api.classes import Name
30-
from libcst import CSTNode
31-
32-
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
33-
3427

3528
def get_code_optimization_context(
3629
function_to_optimize: FunctionToOptimize,
@@ -147,7 +140,6 @@ def extract_code_string_context_from_files(
147140
code_context_type: CodeContextType = CodeContextType.READ_ONLY,
148141
) -> CodeString:
149142
"""Extract code context from files containing target functions and their helpers.
150-
151143
This function processes two sets of files:
152144
1. Files containing the function to optimize (fto) and their first-degree helpers
153145
2. Files containing only helpers of helpers (with no overlap with the first set).
@@ -165,15 +157,15 @@ def extract_code_string_context_from_files(
165157
Returns:
166158
CodeString containing the extracted code context with necessary imports
167159
168-
"""
160+
""" # noqa: D205
169161
# Rearrange to remove overlaps, so we only access each file path once
170162
helpers_of_helpers_no_overlap = defaultdict(set)
171-
for file_path, helper_set in helpers_of_helpers.items():
163+
for file_path, function_sources in helpers_of_helpers.items():
172164
if file_path in helpers_of_fto:
173165
# Remove duplicates within the same file path, in case a helper of helper is also a helper of fto
174-
helpers_of_helpers_no_overlap[file_path] = helper_set - helpers_of_fto[file_path]
166+
helpers_of_helpers[file_path] -= helpers_of_fto[file_path]
175167
else:
176-
helpers_of_helpers_no_overlap[file_path] = helper_set
168+
helpers_of_helpers_no_overlap[file_path] = function_sources
177169

178170
final_code_string_context = ""
179171

@@ -276,12 +268,12 @@ def extract_code_markdown_context_from_files(
276268
"""
277269
# Rearrange to remove overlaps, so we only access each file path once
278270
helpers_of_helpers_no_overlap = defaultdict(set)
279-
for file_path, helper_set in helpers_of_helpers.items():
271+
for file_path, function_sources in helpers_of_helpers.items():
280272
if file_path in helpers_of_fto:
281273
# Remove duplicates within the same file path, in case a helper of helper is also a helper of fto
282-
helpers_of_helpers_no_overlap[file_path] = helper_set - helpers_of_fto[file_path]
274+
helpers_of_helpers[file_path] -= helpers_of_fto[file_path]
283275
else:
284-
helpers_of_helpers_no_overlap[file_path] = helper_set
276+
helpers_of_helpers_no_overlap[file_path] = function_sources
285277
code_context_markdown = CodeStringsMarkdown()
286278
# Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files
287279
for file_path, function_sources in helpers_of_fto.items():
@@ -389,8 +381,9 @@ def get_function_to_optimize_as_function_source(
389381
except Exception as e: # noqa: PERF203
390382
logger.exception(f"Error while getting function source: {e}")
391383
continue
392-
msg = f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}"
393-
raise ValueError(msg)
384+
raise ValueError(
385+
f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}" # noqa: EM102
386+
)
394387

395388

396389
def get_function_sources_from_jedi(
@@ -460,7 +453,7 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
460453

461454

462455
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
463-
"""Remove the docstring from an indented block if it exists."""
456+
"""Removes the docstring from an indented block if it exists.""" # noqa: D401
464457
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
465458
return indented_block
466459
first_stmt = indented_block.body[0].body[0]
@@ -473,12 +466,10 @@ def parse_code_and_prune_cst(
473466
code: str,
474467
code_context_type: CodeContextType,
475468
target_functions: set[str],
476-
helpers_of_helper_functions: Optional[set[str]] = None,
469+
helpers_of_helper_functions: set[str] = set(), # noqa: B006
477470
remove_docstrings: bool = False, # noqa: FBT001, FBT002
478471
) -> str:
479472
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
480-
if helpers_of_helper_functions is None:
481-
helpers_of_helper_functions = set()
482473
module = cst.parse_module(code)
483474
if code_context_type == CodeContextType.READ_WRITABLE:
484475
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
@@ -526,7 +517,7 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
526517
return None, False
527518
# Assuming always an IndentedBlock
528519
if not isinstance(node.body, cst.IndentedBlock):
529-
raise TypeError("ClassDef body is not an IndentedBlock")
520+
raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004
530521
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
531522
new_body = []
532523
found_target = False
@@ -619,7 +610,7 @@ def prune_cst_for_read_only_code( # noqa: PLR0911
619610
return None, False
620611
# Assuming always an IndentedBlock
621612
if not isinstance(node.body, cst.IndentedBlock):
622-
raise TypeError("ClassDef body is not an IndentedBlock")
613+
raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004
623614

624615
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
625616

@@ -724,7 +715,7 @@ def prune_cst_for_testgen_code( # noqa: PLR0911
724715
return None, False
725716
# Assuming always an IndentedBlock
726717
if not isinstance(node.body, cst.IndentedBlock):
727-
raise TypeError("ClassDef body is not an IndentedBlock")
718+
raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004
728719

729720
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
730721

0 commit comments

Comments
 (0)