11from  __future__ import  annotations 
22
3+ import  hashlib 
34import  os 
45from  collections  import  defaultdict 
56from  itertools  import  chain 
6- from  typing  import  TYPE_CHECKING 
7+ from  typing  import  TYPE_CHECKING ,  cast 
78
89import  libcst  as  cst 
910
3132def  get_code_optimization_context (
3233    function_to_optimize : FunctionToOptimize ,
3334    project_root_path : Path ,
34-     optim_token_limit : int  =  8000 ,
35-     testgen_token_limit : int  =  8000 ,
35+     optim_token_limit : int  =  16000 ,
36+     testgen_token_limit : int  =  16000 ,
3637) ->  CodeOptimizationContext :
3738    # Get FunctionSource representation of helpers of FTO 
3839    helpers_of_fto_dict , helpers_of_fto_list  =  get_function_sources_from_jedi (
@@ -73,6 +74,13 @@ def get_code_optimization_context(
7374        remove_docstrings = False ,
7475        code_context_type = CodeContextType .READ_ONLY ,
7576    )
77+     hashing_code_context  =  extract_code_markdown_context_from_files (
78+         helpers_of_fto_dict ,
79+         helpers_of_helpers_dict ,
80+         project_root_path ,
81+         remove_docstrings = True ,
82+         code_context_type = CodeContextType .HASHING ,
83+     )
7684
7785    # Handle token limits 
7886    final_read_writable_tokens  =  encoded_tokens_len (final_read_writable_code )
@@ -125,11 +133,15 @@ def get_code_optimization_context(
125133        testgen_context_code_tokens  =  encoded_tokens_len (testgen_context_code )
126134        if  testgen_context_code_tokens  >  testgen_token_limit :
127135            raise  ValueError ("Testgen code context has exceeded token limit, cannot proceed" )
136+     code_hash_context  =  hashing_code_context .markdown 
137+     code_hash  =  hashlib .sha256 (code_hash_context .encode ("utf-8" )).hexdigest ()
128138
129139    return  CodeOptimizationContext (
130140        testgen_context_code = testgen_context_code ,
131141        read_writable_code = final_read_writable_code ,
132142        read_only_context_code = read_only_context_code ,
143+         hashing_code_context = code_hash_context ,
144+         hashing_code_context_hash = code_hash ,
133145        helper_functions = helpers_of_fto_list ,
134146        preexisting_objects = preexisting_objects ,
135147    )
@@ -309,8 +321,8 @@ def extract_code_markdown_context_from_files(
309321            logger .debug (f"Error while getting read-only code: { e }  " )
310322            continue 
311323        if  code_context .strip ():
312-             code_context_with_imports   =   CodeString ( 
313-                 code = add_needed_imports_from_module (
324+             if   code_context_type   !=   CodeContextType . HASHING : 
325+                 code_context   =   add_needed_imports_from_module (
314326                    src_module_code = original_code ,
315327                    dst_module_code = code_context ,
316328                    src_path = file_path ,
@@ -319,10 +331,9 @@ def extract_code_markdown_context_from_files(
319331                    helper_functions = list (
320332                        helpers_of_fto .get (file_path , set ()) |  helpers_of_helpers .get (file_path , set ())
321333                    ),
322-                 ),
323-                 file_path = file_path .relative_to (project_root_path ),
324-             )
325-             code_context_markdown .code_strings .append (code_context_with_imports )
334+                 )
335+             code_string_context  =  CodeString (code = code_context , file_path = file_path .relative_to (project_root_path ))
336+             code_context_markdown .code_strings .append (code_string_context )
326337    # Extract code from file paths containing helpers of helpers 
327338    for  file_path , helper_function_sources  in  helpers_of_helpers_no_overlap .items ():
328339        try :
@@ -343,18 +354,17 @@ def extract_code_markdown_context_from_files(
343354            continue 
344355
345356        if  code_context .strip ():
346-             code_context_with_imports   =   CodeString ( 
347-                 code = add_needed_imports_from_module (
357+             if   code_context_type   !=   CodeContextType . HASHING : 
358+                 code_context   =   add_needed_imports_from_module (
348359                    src_module_code = original_code ,
349360                    dst_module_code = code_context ,
350361                    src_path = file_path ,
351362                    dst_path = file_path ,
352363                    project_root = project_root_path ,
353364                    helper_functions = list (helpers_of_helpers_no_overlap .get (file_path , set ())),
354-                 ),
355-                 file_path = file_path .relative_to (project_root_path ),
356-             )
357-             code_context_markdown .code_strings .append (code_context_with_imports )
365+                 )
366+             code_string_context  =  CodeString (code = code_context , file_path = file_path .relative_to (project_root_path ))
367+             code_context_markdown .code_strings .append (code_string_context )
358368    return  code_context_markdown 
359369
360370
@@ -492,6 +502,8 @@ def parse_code_and_prune_cst(
492502        filtered_node , found_target  =  prune_cst_for_testgen_code (
493503            module , target_functions , helpers_of_helper_functions , remove_docstrings = remove_docstrings 
494504        )
505+     elif  code_context_type  ==  CodeContextType .HASHING :
506+         filtered_node , found_target  =  prune_cst_for_code_hashing (module , target_functions )
495507    else :
496508        raise  ValueError (f"Unknown code_context_type: { code_context_type }  " )  # noqa: EM102 
497509
@@ -583,6 +595,90 @@ def prune_cst_for_read_writable_code(  # noqa: PLR0911
583595    return  (node .with_changes (** updates ) if  updates  else  node ), True 
584596
585597
598+ def  prune_cst_for_code_hashing (  # noqa: PLR0911 
599+     node : cst .CSTNode , target_functions : set [str ], prefix : str  =  "" 
600+ ) ->  tuple [cst .CSTNode  |  None , bool ]:
601+     """Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions. 
602+ 
603+     Returns 
604+     ------- 
605+         (filtered_node, found_target): 
606+           filtered_node: The modified CST node or None if it should be removed. 
607+           found_target: True if a target function was found in this node's subtree. 
608+ 
609+     """ 
610+     if  isinstance (node , (cst .Import , cst .ImportFrom )):
611+         return  None , False 
612+ 
613+     if  isinstance (node , cst .FunctionDef ):
614+         qualified_name  =  f"{ prefix }  .{ node .name .value }  "  if  prefix  else  node .name .value 
615+         if  qualified_name  in  target_functions :
616+             new_body  =  remove_docstring_from_body (node .body ) if  isinstance (node .body , cst .IndentedBlock ) else  node .body 
617+             return  node .with_changes (body = new_body ), True 
618+         return  None , False 
619+ 
620+     if  isinstance (node , cst .ClassDef ):
621+         # Do not recurse into nested classes 
622+         if  prefix :
623+             return  None , False 
624+         # Assuming always an IndentedBlock 
625+         if  not  isinstance (node .body , cst .IndentedBlock ):
626+             raise  ValueError ("ClassDef body is not an IndentedBlock" )  # noqa: TRY004 
627+         class_prefix  =  f"{ prefix }  .{ node .name .value }  "  if  prefix  else  node .name .value 
628+         new_class_body : list [cst .CSTNode ] =  []
629+         found_target  =  False 
630+ 
631+         for  stmt  in  node .body .body :
632+             if  isinstance (stmt , cst .FunctionDef ):
633+                 qualified_name  =  f"{ class_prefix }  .{ stmt .name .value }  " 
634+                 if  qualified_name  in  target_functions :
635+                     stmt_with_changes  =  stmt .with_changes (
636+                         body = remove_docstring_from_body (cast ("cst.IndentedBlock" , stmt .body ))
637+                     )
638+                     new_class_body .append (stmt_with_changes )
639+                     found_target  =  True 
640+         # If no target functions found, remove the class entirely 
641+         if  not  new_class_body  or  not  found_target :
642+             return  None , False 
643+         return  node .with_changes (
644+             body = cst .IndentedBlock (cast ("list[cst.BaseStatement]" , new_class_body ))
645+         ) if  new_class_body  else  None , found_target 
646+ 
647+     # For other nodes, we preserve them only if they contain target functions in their children. 
648+     section_names  =  get_section_names (node )
649+     if  not  section_names :
650+         return  node , False 
651+ 
652+     updates : dict [str , list [cst .CSTNode ] |  cst .CSTNode ] =  {}
653+     found_any_target  =  False 
654+ 
655+     for  section  in  section_names :
656+         original_content  =  getattr (node , section , None )
657+         if  isinstance (original_content , (list , tuple )):
658+             new_children  =  []
659+             section_found_target  =  False 
660+             for  child  in  original_content :
661+                 filtered , found_target  =  prune_cst_for_code_hashing (child , target_functions , prefix )
662+                 if  filtered :
663+                     new_children .append (filtered )
664+                 section_found_target  |=  found_target 
665+ 
666+             if  section_found_target :
667+                 found_any_target  =  True 
668+                 updates [section ] =  new_children 
669+         elif  original_content  is  not   None :
670+             filtered , found_target  =  prune_cst_for_code_hashing (original_content , target_functions , prefix )
671+             if  found_target :
672+                 found_any_target  =  True 
673+                 if  filtered :
674+                     updates [section ] =  filtered 
675+ 
676+     if  not  found_any_target :
677+         return  None , False 
678+ 
679+     return  (node .with_changes (** updates ) if  updates  else  node ), True 
680+ 
681+ 
586682def  prune_cst_for_read_only_code (  # noqa: PLR0911 
587683    node : cst .CSTNode ,
588684    target_functions : set [str ],
0 commit comments