@@ -375,9 +375,7 @@ def get_function_to_optimize_as_function_source(
375375            )
376376
377377    msg  =  f"Could not find function { function_to_optimize .function_name }   in { function_to_optimize .file_path }  " 
378-     raise  ValueError (
379-         msg 
380-     )
378+     raise  ValueError (msg )
381379
382380
383381def  get_function_sources_from_jedi (
@@ -463,20 +461,22 @@ def parse_code_and_prune_cst(
463461    if  helpers_of_helper_functions  is  None :
464462        helpers_of_helper_functions  =  set ()
465463    module  =  cst .parse_module (code )
466-     if  code_context_type  ==  CodeContextType .READ_WRITABLE :
467-         filtered_node , found_target  =  prune_cst_for_read_writable_code (module , target_functions )
468-     elif  code_context_type  ==  CodeContextType .READ_ONLY :
469-         filtered_node , found_target  =  prune_cst_for_read_only_code (
470-             module , target_functions , helpers_of_helper_functions , remove_docstrings = remove_docstrings 
471-         )
472-     elif  code_context_type  ==  CodeContextType .TESTGEN :
473-         filtered_node , found_target  =  prune_cst_for_testgen_code (
474-             module , target_functions , helpers_of_helper_functions , remove_docstrings = remove_docstrings 
475-         )
476-     else :
464+ 
465+     dispatch  =  {
466+         CodeContextType .READ_WRITABLE : prune_cst_for_read_writable_code ,
467+         CodeContextType .READ_ONLY : prune_cst_for_read_only_code ,
468+         CodeContextType .TESTGEN : prune_cst_for_testgen_code ,
469+     }
470+ 
471+     prune_func  =  dispatch .get (code_context_type )
472+     if  prune_func  is  None :
477473        msg  =  f"Unknown code_context_type: { code_context_type }  " 
478474        raise  ValueError (msg )
479475
476+     filtered_node , found_target  =  prune_func (
477+         module , target_functions , helpers_of_helper_functions , remove_docstrings = remove_docstrings 
478+     )
479+ 
480480    if  not  found_target :
481481        msg  =  "No target functions found in the provided code" 
482482        raise  ValueError (msg )
@@ -625,10 +625,12 @@ def prune_cst_for_read_only_code(
625625            return  None , False 
626626
627627        if  remove_docstrings :
628-             return  node .with_changes (
629-                 body = remove_docstring_from_body (node .body .with_changes (body = new_class_body ))
630-             ) if  new_class_body  else  None , True 
631-         return  node .with_changes (body = node .body .with_changes (body = new_class_body )) if  new_class_body  else  None , True 
628+             return  (
629+                 node .with_changes (body = remove_docstring_from_body (node .body .with_changes (body = new_class_body )))
630+                 if  new_class_body 
631+                 else  None 
632+             ), True 
633+         return  (node .with_changes (body = node .body .with_changes (body = new_class_body )) if  new_class_body  else  None ), True 
632634
633635    # For other nodes, keep the node and recursively filter children 
634636    section_names  =  get_section_names (node )
@@ -731,10 +733,12 @@ def prune_cst_for_testgen_code(
731733            return  None , False 
732734
733735        if  remove_docstrings :
734-             return  node .with_changes (
735-                 body = remove_docstring_from_body (node .body .with_changes (body = new_class_body ))
736-             ) if  new_class_body  else  None , True 
737-         return  node .with_changes (body = node .body .with_changes (body = new_class_body )) if  new_class_body  else  None , True 
736+             return  (
737+                 node .with_changes (body = remove_docstring_from_body (node .body .with_changes (body = new_class_body )))
738+                 if  new_class_body 
739+                 else  None 
740+             ), True 
741+         return  (node .with_changes (body = node .body .with_changes (body = new_class_body )) if  new_class_body  else  None ), True 
738742
739743    # For other nodes, keep the node and recursively filter children 
740744    section_names  =  get_section_names (node )
0 commit comments