Skip to content

Commit 5862067

Browse files
committed
Update code_context_extractor.py
1 parent 8be54d1 commit 5862067

File tree

1 file changed

+26
-22
lines changed

1 file changed

+26
-22
lines changed

codeflash/context/code_context_extractor.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -375,9 +375,7 @@ def get_function_to_optimize_as_function_source(
375375
)
376376

377377
msg = f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}"
378-
raise ValueError(
379-
msg
380-
)
378+
raise ValueError(msg)
381379

382380

383381
def get_function_sources_from_jedi(
@@ -463,20 +461,22 @@ def parse_code_and_prune_cst(
463461
if helpers_of_helper_functions is None:
464462
helpers_of_helper_functions = set()
465463
module = cst.parse_module(code)
466-
if code_context_type == CodeContextType.READ_WRITABLE:
467-
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
468-
elif code_context_type == CodeContextType.READ_ONLY:
469-
filtered_node, found_target = prune_cst_for_read_only_code(
470-
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
471-
)
472-
elif code_context_type == CodeContextType.TESTGEN:
473-
filtered_node, found_target = prune_cst_for_testgen_code(
474-
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
475-
)
476-
else:
464+
465+
dispatch = {
466+
CodeContextType.READ_WRITABLE: prune_cst_for_read_writable_code,
467+
CodeContextType.READ_ONLY: prune_cst_for_read_only_code,
468+
CodeContextType.TESTGEN: prune_cst_for_testgen_code,
469+
}
470+
471+
prune_func = dispatch.get(code_context_type)
472+
if prune_func is None:
477473
msg = f"Unknown code_context_type: {code_context_type}"
478474
raise ValueError(msg)
479475

476+
filtered_node, found_target = prune_func(
477+
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
478+
)
479+
480480
if not found_target:
481481
msg = "No target functions found in the provided code"
482482
raise ValueError(msg)
@@ -625,10 +625,12 @@ def prune_cst_for_read_only_code(
625625
return None, False
626626

627627
if remove_docstrings:
628-
return node.with_changes(
629-
body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))
630-
) if new_class_body else None, True
631-
return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True
628+
return (
629+
node.with_changes(body=remove_docstring_from_body(node.body.with_changes(body=new_class_body)))
630+
if new_class_body
631+
else None
632+
), True
633+
return (node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None), True
632634

633635
# For other nodes, keep the node and recursively filter children
634636
section_names = get_section_names(node)
@@ -731,10 +733,12 @@ def prune_cst_for_testgen_code(
731733
return None, False
732734

733735
if remove_docstrings:
734-
return node.with_changes(
735-
body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))
736-
) if new_class_body else None, True
737-
return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True
736+
return (
737+
node.with_changes(body=remove_docstring_from_body(node.body.with_changes(body=new_class_body)))
738+
if new_class_body
739+
else None
740+
), True
741+
return (node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None), True
738742

739743
# For other nodes, keep the node and recursively filter children
740744
section_names = get_section_names(node)

0 commit comments

Comments
 (0)