@@ -455,39 +455,32 @@ def parse_code_and_prune_cst(
455455 code : str ,
456456 code_context_type : CodeContextType ,
457457 target_functions : set [str ],
458- helpers_of_helper_functions : Optional [ set [str ]] = None ,
458+ helpers_of_helper_functions : set [str ] = set () ,
459459 remove_docstrings : bool = False ,
460460) -> str :
461461 """Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
462- if helpers_of_helper_functions is None :
463- helpers_of_helper_functions = set ()
464462 module = cst .parse_module (code )
465-
466- dispatch : dict [
467- CodeContextType , Callable [[ cst . CSTNode , set [ str ], set [ str ], str , bool ], tuple [ Optional [ cst . CSTNode ], bool ]]
468- ] = {
469- CodeContextType . READ_WRITABLE : prune_cst_for_read_writable_code ,
470- CodeContextType . READ_ONLY : prune_cst_for_read_only_code ,
471- CodeContextType .TESTGEN : prune_cst_for_testgen_code ,
472- }
473-
474- prune_func = dispatch . get ( code_context_type )
475- if prune_func is None :
463+ if code_context_type == CodeContextType . READ_WRITABLE :
464+ filtered_node , found_target = prune_cst_for_read_writable_code ( module , target_functions )
465+ elif code_context_type == CodeContextType . READ_ONLY :
466+ filtered_node , found_target = prune_cst_for_read_only_code (
467+ module , target_functions , helpers_of_helper_functions , remove_docstrings = remove_docstrings
468+ )
469+ elif code_context_type == CodeContextType .TESTGEN :
470+ filtered_node , found_target = prune_cst_for_testgen_code (
471+ module , target_functions , helpers_of_helper_functions , remove_docstrings = remove_docstrings
472+ )
473+ else :
476474 msg = f"Unknown code_context_type: { code_context_type } "
477475 raise ValueError (msg )
478476
479- filtered_node , found_target = prune_func (
480- module , target_functions , helpers_of_helper_functions , "" , remove_docstrings
481- )
482-
483477 if not found_target :
484478 msg = "No target functions found in the provided code"
485479 raise ValueError (msg )
486480 if filtered_node and isinstance (filtered_node , cst .Module ):
487481 return str (filtered_node .code )
488482 return ""
489483
490-
491484def prune_cst_for_read_writable_code (
492485 node : cst .CSTNode , target_functions : set [str ], prefix : str = ""
493486) -> tuple [cst .CSTNode | None , bool ]:
0 commit comments