Skip to content

Commit 8debda7

Browse files
committed
Update code_context_extractor.py
1 parent dd7cbb4 commit 8debda7

File tree

1 file changed

+36
-43
lines changed

1 file changed

+36
-43
lines changed

codeflash/context/code_context_extractor.py

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@
33
import os
44
from collections import defaultdict
55
from itertools import chain
6-
from typing import TYPE_CHECKING, Optional
6+
from pathlib import Path
77

88
import jedi
99
import libcst as cst
1010
import tiktoken
11+
from jedi.api.classes import Name
12+
from libcst import CSTNode
1113

1214
from codeflash.cli_cmds.console import logger
1315
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
1416
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
1517
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
18+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1619
from codeflash.models.models import (
1720
CodeContextType,
1821
CodeOptimizationContext,
@@ -22,14 +25,6 @@
2225
)
2326
from codeflash.optimization.function_context import belongs_to_function_qualified
2427

25-
if TYPE_CHECKING:
26-
from pathlib import Path
27-
28-
from jedi.api.classes import Name
29-
from libcst import CSTNode
30-
31-
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
32-
3328

3429
def get_code_optimization_context(
3530
function_to_optimize: FunctionToOptimize,
@@ -143,14 +138,13 @@ def extract_code_string_context_from_files(
143138
helpers_of_fto: dict[Path, set[FunctionSource]],
144139
helpers_of_helpers: dict[Path, set[FunctionSource]],
145140
project_root_path: Path,
146-
remove_docstrings: bool = False, # noqa: FBT001, FBT002
141+
remove_docstrings: bool = False,
147142
code_context_type: CodeContextType = CodeContextType.READ_ONLY,
148143
) -> CodeString:
149144
"""Extract code context from files containing target functions and their helpers.
150-
151145
This function processes two sets of files:
152146
1. Files containing the function to optimize (fto) and their first-degree helpers
153-
2. Files containing only helpers of helpers (with no overlap with the first set).
147+
2. Files containing only helpers of helpers (with no overlap with the first set)
154148
155149
For each file, it extracts relevant code based on the specified context type, adds necessary
156150
imports, and combines them.
@@ -168,12 +162,12 @@ def extract_code_string_context_from_files(
168162
"""
169163
# Rearrange to remove overlaps, so we only access each file path once
170164
helpers_of_helpers_no_overlap = defaultdict(set)
171-
for file_path, helper_set in helpers_of_helpers.items():
165+
for file_path in helpers_of_helpers:
172166
if file_path in helpers_of_fto:
173167
# Remove duplicates within the same file path, in case a helper of helper is also a helper of fto
174-
helpers_of_helpers_no_overlap[file_path] = helper_set - helpers_of_fto[file_path]
168+
helpers_of_helpers[file_path] -= helpers_of_fto[file_path]
175169
else:
176-
helpers_of_helpers_no_overlap[file_path] = helper_set
170+
helpers_of_helpers_no_overlap[file_path] = helpers_of_helpers[file_path]
177171

178172
final_code_string_context = ""
179173

@@ -250,7 +244,7 @@ def extract_code_markdown_context_from_files(
250244
helpers_of_fto: dict[Path, set[FunctionSource]],
251245
helpers_of_helpers: dict[Path, set[FunctionSource]],
252246
project_root_path: Path,
253-
remove_docstrings: bool = False, # noqa: FBT001, FBT002
247+
remove_docstrings: bool = False,
254248
code_context_type: CodeContextType = CodeContextType.READ_ONLY,
255249
) -> CodeStringsMarkdown:
256250
"""Extract code context from files containing target functions and their helpers, formatting them as markdown.
@@ -276,12 +270,12 @@ def extract_code_markdown_context_from_files(
276270
"""
277271
# Rearrange to remove overlaps, so we only access each file path once
278272
helpers_of_helpers_no_overlap = defaultdict(set)
279-
for file_path, helper_set in helpers_of_helpers.items():
273+
for file_path in helpers_of_helpers:
280274
if file_path in helpers_of_fto:
281275
# Remove duplicates within the same file path, in case a helper of helper is also a helper of fto
282-
helpers_of_helpers_no_overlap[file_path] = helper_set - helpers_of_fto[file_path]
276+
helpers_of_helpers[file_path] -= helpers_of_fto[file_path]
283277
else:
284-
helpers_of_helpers_no_overlap[file_path] = helper_set
278+
helpers_of_helpers_no_overlap[file_path] = helpers_of_helpers[file_path]
285279
code_context_markdown = CodeStringsMarkdown()
286280
# Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files
287281
for file_path, function_sources in helpers_of_fto.items():
@@ -378,19 +372,21 @@ def get_function_to_optimize_as_function_source(
378372
and name.full_name.startswith(name.module_name)
379373
and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name
380374
):
381-
return FunctionSource(
375+
function_source = FunctionSource(
382376
file_path=function_to_optimize.file_path,
383377
qualified_name=function_to_optimize.qualified_name,
384378
fully_qualified_name=name.full_name,
385379
only_function_name=name.name,
386380
source_code=name.get_line_code(),
387381
jedi_definition=name,
388382
)
389-
except Exception as e: # noqa: PERF203
383+
return function_source
384+
except Exception as e:
390385
logger.exception(f"Error while getting function source: {e}")
391386
continue
392-
msg = f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}"
393-
raise ValueError(msg)
387+
raise ValueError(
388+
f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}"
389+
)
394390

395391

396392
def get_function_sources_from_jedi(
@@ -411,7 +407,7 @@ def get_function_sources_from_jedi(
411407
for name in names:
412408
try:
413409
definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
414-
except Exception:
410+
except Exception: # noqa: BLE001
415411
logger.debug(f"Error while getting definitions for {qualified_function_name}")
416412
definitions = []
417413
if definitions:
@@ -454,13 +450,13 @@ def is_dunder_method(name: str) -> bool:
454450

455451

456452
def get_section_names(node: cst.CSTNode) -> list[str]:
457-
"""Returns the section attribute names (e.g., body, orelse) for a given node if they exist.""" # noqa: D401
453+
"""Returns the section attribute names (e.g., body, orelse) for a given node if they exist."""
458454
possible_sections = ["body", "orelse", "finalbody", "handlers"]
459455
return [sec for sec in possible_sections if hasattr(node, sec)]
460456

461457

462458
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
463-
"""Remove the docstring from an indented block if it exists."""
459+
"""Removes the docstring from an indented block if it exists"""
464460
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
465461
return indented_block
466462
first_stmt = indented_block.body[0].body[0]
@@ -473,12 +469,10 @@ def parse_code_and_prune_cst(
473469
code: str,
474470
code_context_type: CodeContextType,
475471
target_functions: set[str],
476-
helpers_of_helper_functions: Optional[set[str]] = None,
477-
remove_docstrings: bool = False, # noqa: FBT001, FBT002
472+
helpers_of_helper_functions: set[str] = set(),
473+
remove_docstrings: bool = False,
478474
) -> str:
479475
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
480-
if helpers_of_helper_functions is None:
481-
helpers_of_helper_functions = set()
482476
module = cst.parse_module(code)
483477
if code_context_type == CodeContextType.READ_WRITABLE:
484478
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
@@ -491,8 +485,7 @@ def parse_code_and_prune_cst(
491485
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
492486
)
493487
else:
494-
msg = f"Unknown code_context_type: {code_context_type}"
495-
raise ValueError(msg)
488+
raise ValueError(f"Unknown code_context_type: {code_context_type}")
496489

497490
if not found_target:
498491
raise ValueError("No target functions found in the provided code")
@@ -501,7 +494,7 @@ def parse_code_and_prune_cst(
501494
return ""
502495

503496

504-
def prune_cst_for_read_writable_code( # noqa: PLR0911
497+
def prune_cst_for_read_writable_code(
505498
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
506499
) -> tuple[cst.CSTNode | None, bool]:
507500
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
@@ -527,7 +520,7 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
527520
return None, False
528521
# Assuming always an IndentedBlock
529522
if not isinstance(node.body, cst.IndentedBlock):
530-
raise TypeError("ClassDef body is not an IndentedBlock")
523+
raise ValueError("ClassDef body is not an IndentedBlock")
531524
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
532525
new_body = []
533526
found_target = False
@@ -581,14 +574,14 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
581574
return (node.with_changes(**updates) if updates else node), True
582575

583576

584-
def prune_cst_for_read_only_code( # noqa: PLR0911
577+
def prune_cst_for_read_only_code(
585578
node: cst.CSTNode,
586579
target_functions: set[str],
587580
helpers_of_helper_functions: set[str],
588581
prefix: str = "",
589-
remove_docstrings: bool = False, # noqa: FBT001, FBT002
582+
remove_docstrings: bool = False,
590583
) -> tuple[cst.CSTNode | None, bool]:
591-
"""Recursively filter the node for read-only context.
584+
"""Recursively filter the node for read-only context:
592585
593586
Returns:
594587
(filtered_node, found_target):
@@ -620,7 +613,7 @@ def prune_cst_for_read_only_code( # noqa: PLR0911
620613
return None, False
621614
# Assuming always an IndentedBlock
622615
if not isinstance(node.body, cst.IndentedBlock):
623-
raise TypeError("ClassDef body is not an IndentedBlock")
616+
raise ValueError("ClassDef body is not an IndentedBlock")
624617

625618
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
626619

@@ -685,14 +678,14 @@ def prune_cst_for_read_only_code( # noqa: PLR0911
685678
return None, False
686679

687680

688-
def prune_cst_for_testgen_code( # noqa: PLR0911
681+
def prune_cst_for_testgen_code(
689682
node: cst.CSTNode,
690683
target_functions: set[str],
691684
helpers_of_helper_functions: set[str],
692685
prefix: str = "",
693-
remove_docstrings: bool = False, # noqa: FBT001, FBT002
686+
remove_docstrings: bool = False,
694687
) -> tuple[cst.CSTNode | None, bool]:
695-
"""Recursively filter the node for testgen context.
688+
"""Recursively filter the node for testgen context:
696689
697690
Returns:
698691
(filtered_node, found_target):
@@ -725,7 +718,7 @@ def prune_cst_for_testgen_code( # noqa: PLR0911
725718
return None, False
726719
# Assuming always an IndentedBlock
727720
if not isinstance(node.body, cst.IndentedBlock):
728-
raise TypeError("ClassDef body is not an IndentedBlock")
721+
raise ValueError("ClassDef body is not an IndentedBlock")
729722

730723
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
731724

@@ -787,4 +780,4 @@ def prune_cst_for_testgen_code( # noqa: PLR0911
787780
if updates:
788781
return (node.with_changes(**updates), found_any_target)
789782

790-
return None, False
783+
return None, False

0 commit comments

Comments
 (0)