Skip to content

Commit 9fb36ae

Browse files
committed
squash
1 parent b977e3c commit 9fb36ae

File tree

1 file changed

+12
-19
lines changed

1 file changed

+12
-19
lines changed

codeflash/context/code_context_extractor.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -455,39 +455,32 @@ def parse_code_and_prune_cst(
455455
code: str,
456456
code_context_type: CodeContextType,
457457
target_functions: set[str],
458-
helpers_of_helper_functions: Optional[set[str]] = None,
458+
helpers_of_helper_functions: set[str] = set(),
459459
remove_docstrings: bool = False,
460460
) -> str:
461461
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
462-
if helpers_of_helper_functions is None:
463-
helpers_of_helper_functions = set()
464462
module = cst.parse_module(code)
465-
466-
dispatch: dict[
467-
CodeContextType, Callable[[cst.CSTNode, set[str], set[str], str, bool], tuple[Optional[cst.CSTNode], bool]]
468-
] = {
469-
CodeContextType.READ_WRITABLE: prune_cst_for_read_writable_code,
470-
CodeContextType.READ_ONLY: prune_cst_for_read_only_code,
471-
CodeContextType.TESTGEN: prune_cst_for_testgen_code,
472-
}
473-
474-
prune_func = dispatch.get(code_context_type)
475-
if prune_func is None:
463+
if code_context_type == CodeContextType.READ_WRITABLE:
464+
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
465+
elif code_context_type == CodeContextType.READ_ONLY:
466+
filtered_node, found_target = prune_cst_for_read_only_code(
467+
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
468+
)
469+
elif code_context_type == CodeContextType.TESTGEN:
470+
filtered_node, found_target = prune_cst_for_testgen_code(
471+
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
472+
)
473+
else:
476474
msg = f"Unknown code_context_type: {code_context_type}"
477475
raise ValueError(msg)
478476

479-
filtered_node, found_target = prune_func(
480-
module, target_functions, helpers_of_helper_functions, "", remove_docstrings
481-
)
482-
483477
if not found_target:
484478
msg = "No target functions found in the provided code"
485479
raise ValueError(msg)
486480
if filtered_node and isinstance(filtered_node, cst.Module):
487481
return str(filtered_node.code)
488482
return ""
489483

490-
491484
def prune_cst_for_read_writable_code(
492485
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
493486
) -> tuple[cst.CSTNode | None, bool]:

0 commit comments

Comments
 (0)