Skip to content

Commit f0826bb

Browse files
committed
Update code_context_extractor.py
1 parent 6641f3d commit f0826bb

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

codeflash/context/code_context_extractor.py

Lines changed: 25 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -142,6 +142,7 @@ def extract_code_string_context_from_files(
142142
code_context_type: CodeContextType = CodeContextType.READ_ONLY,
143143
) -> CodeString:
144144
"""Extract code context from files containing target functions and their helpers.
145+
145146
This function processes two sets of files:
146147
1. Files containing the function to optimize (fto) and their first-degree helpers
147148
2. Files containing only helpers of helpers (with no overlap with the first set)
@@ -162,12 +163,12 @@ def extract_code_string_context_from_files(
162163
"""
163164
# Rearrange to remove overlaps, so we only access each file path once
164165
helpers_of_helpers_no_overlap = defaultdict(set)
165-
for file_path in helpers_of_helpers:
166+
for file_path, helper_functions in helpers_of_helpers.items():
166167
if file_path in helpers_of_fto:
167168
# Remove duplicates within the same file path, in case a helper of helper is also a helper of fto
168169
helpers_of_helpers[file_path] -= helpers_of_fto[file_path]
169170
else:
170-
helpers_of_helpers_no_overlap[file_path] = helpers_of_helpers[file_path]
171+
helpers_of_helpers_no_overlap[file_path] = helper_functions
171172

172173
final_code_string_context = ""
173174

@@ -372,21 +373,19 @@ def get_function_to_optimize_as_function_source(
372373
and name.full_name.startswith(name.module_name)
373374
and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name
374375
):
375-
function_source = FunctionSource(
376+
return FunctionSource(
376377
file_path=function_to_optimize.file_path,
377378
qualified_name=function_to_optimize.qualified_name,
378379
fully_qualified_name=name.full_name,
379380
only_function_name=name.name,
380381
source_code=name.get_line_code(),
381382
jedi_definition=name,
382383
)
383-
return function_source
384-
except Exception as e:
384+
except Exception as e: # noqa: PERF203
385385
logger.exception(f"Error while getting function source: {e}")
386386
continue
387-
raise ValueError(
388-
f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}"
389-
)
387+
msg = f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}"
388+
raise ValueError(msg)
390389

391390

392391
def get_function_sources_from_jedi(
@@ -450,13 +449,13 @@ def is_dunder_method(name: str) -> bool:
450449

451450

452451
def get_section_names(node: cst.CSTNode) -> list[str]:
453-
"""Returns the section attribute names (e.g., body, orelse) for a given node if they exist."""
452+
"""Return the section attribute names (e.g., body, orelse) for a given node if they exist."""
454453
possible_sections = ["body", "orelse", "finalbody", "handlers"]
455454
return [sec for sec in possible_sections if hasattr(node, sec)]
456455

457456

458457
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
459-
"""Removes the docstring from an indented block if it exists"""
458+
"""Remove the docstring from an indented block if it exists."""
460459
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
461460
return indented_block
462461
first_stmt = indented_block.body[0].body[0]
@@ -469,10 +468,12 @@ def parse_code_and_prune_cst(
469468
code: str,
470469
code_context_type: CodeContextType,
471470
target_functions: set[str],
472-
helpers_of_helper_functions: set[str] = set(),
473-
remove_docstrings: bool = False,
471+
helpers_of_helper_functions: set[str] | None = None,
472+
remove_docstrings: bool = False, # noqa: FBT001, FBT002
474473
) -> str:
475474
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information and other module-scoped variables."""
475+
if helpers_of_helper_functions is None:
476+
helpers_of_helper_functions = set()
476477
module = cst.parse_module(code)
477478
if code_context_type == CodeContextType.READ_WRITABLE:
478479
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
@@ -485,7 +486,8 @@ def parse_code_and_prune_cst(
485486
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
486487
)
487488
else:
488-
raise ValueError(f"Unknown code_context_type: {code_context_type}")
489+
msg = f"Unknown code_context_type: {code_context_type}"
490+
raise TypeError(msg)
489491

490492
if not found_target:
491493
raise ValueError("No target functions found in the provided code")
@@ -494,7 +496,7 @@ def parse_code_and_prune_cst(
494496
return ""
495497

496498

497-
def prune_cst_for_read_writable_code(
499+
def prune_cst_for_read_writable_code( # noqa: PLR0911
498500
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
499501
) -> tuple[cst.CSTNode | None, bool]:
500502
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
@@ -520,7 +522,7 @@ def prune_cst_for_read_writable_code(
520522
return None, False
521523
# Assuming always an IndentedBlock
522524
if not isinstance(node.body, cst.IndentedBlock):
523-
raise ValueError("ClassDef body is not an IndentedBlock")
525+
raise TypeError("ClassDef body is not an IndentedBlock")
524526
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
525527
new_body = []
526528
found_target = False
@@ -574,14 +576,14 @@ def prune_cst_for_read_writable_code(
574576
return (node.with_changes(**updates) if updates else node), True
575577

576578

577-
def prune_cst_for_read_only_code(
579+
def prune_cst_for_read_only_code( # noqa: PLR0911
578580
node: cst.CSTNode,
579581
target_functions: set[str],
580582
helpers_of_helper_functions: set[str],
581583
prefix: str = "",
582-
remove_docstrings: bool = False,
584+
remove_docstrings: bool = False, # noqa: FBT001, FBT002
583585
) -> tuple[cst.CSTNode | None, bool]:
584-
"""Recursively filter the node for read-only context:
586+
"""Recursively filter the node for read-only context.
585587
586588
Returns:
587589
(filtered_node, found_target):
@@ -613,7 +615,7 @@ def prune_cst_for_read_only_code(
613615
return None, False
614616
# Assuming always an IndentedBlock
615617
if not isinstance(node.body, cst.IndentedBlock):
616-
raise ValueError("ClassDef body is not an IndentedBlock")
618+
raise TypeError("ClassDef body is not an IndentedBlock")
617619

618620
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
619621

@@ -678,14 +680,14 @@ def prune_cst_for_read_only_code(
678680
return None, False
679681

680682

681-
def prune_cst_for_testgen_code(
683+
def prune_cst_for_testgen_code( # noqa: PLR0911
682684
node: cst.CSTNode,
683685
target_functions: set[str],
684686
helpers_of_helper_functions: set[str],
685687
prefix: str = "",
686-
remove_docstrings: bool = False,
688+
remove_docstrings: bool = False, # noqa: FBT001, FBT002
687689
) -> tuple[cst.CSTNode | None, bool]:
688-
"""Recursively filter the node for testgen context:
690+
"""Recursively filter the node for testgen context.
689691
690692
Returns:
691693
(filtered_node, found_target):
@@ -718,7 +720,7 @@ def prune_cst_for_testgen_code(
718720
return None, False
719721
# Assuming always an IndentedBlock
720722
if not isinstance(node.body, cst.IndentedBlock):
721-
raise ValueError("ClassDef body is not an IndentedBlock")
723+
raise TypeError("ClassDef body is not an IndentedBlock")
722724

723725
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
724726

0 commit comments

Comments
 (0)