11from __future__ import annotations
22
3+ import ast
4+ import hashlib
35import os
46from collections import defaultdict
57from itertools import chain
6- from typing import TYPE_CHECKING
8+ from typing import TYPE_CHECKING , cast
79
810import libcst as cst
911
3133def get_code_optimization_context (
3234 function_to_optimize : FunctionToOptimize ,
3335 project_root_path : Path ,
34- optim_token_limit : int = 8000 ,
35- testgen_token_limit : int = 8000 ,
36+ optim_token_limit : int = 16000 ,
37+ testgen_token_limit : int = 16000 ,
3638) -> CodeOptimizationContext :
3739 # Get FunctionSource representation of helpers of FTO
3840 helpers_of_fto_dict , helpers_of_fto_list = get_function_sources_from_jedi (
@@ -73,6 +75,13 @@ def get_code_optimization_context(
7375 remove_docstrings = False ,
7476 code_context_type = CodeContextType .READ_ONLY ,
7577 )
78+ hashing_code_context = extract_code_markdown_context_from_files (
79+ helpers_of_fto_dict ,
80+ helpers_of_helpers_dict ,
81+ project_root_path ,
82+ remove_docstrings = True ,
83+ code_context_type = CodeContextType .HASHING ,
84+ )
7685
7786 # Handle token limits
7887 final_read_writable_tokens = encoded_tokens_len (final_read_writable_code )
@@ -125,11 +134,15 @@ def get_code_optimization_context(
125134 testgen_context_code_tokens = encoded_tokens_len (testgen_context_code )
126135 if testgen_context_code_tokens > testgen_token_limit :
127136 raise ValueError ("Testgen code context has exceeded token limit, cannot proceed" )
137+ code_hash_context = hashing_code_context .markdown
138+ code_hash = hashlib .sha256 (code_hash_context .encode ("utf-8" )).hexdigest ()
128139
129140 return CodeOptimizationContext (
130141 testgen_context_code = testgen_context_code ,
131142 read_writable_code = final_read_writable_code ,
132143 read_only_context_code = read_only_context_code ,
144+ hashing_code_context = code_hash_context ,
145+ hashing_code_context_hash = code_hash ,
133146 helper_functions = helpers_of_fto_list ,
134147 preexisting_objects = preexisting_objects ,
135148 )
@@ -309,8 +322,8 @@ def extract_code_markdown_context_from_files(
309322 logger .debug (f"Error while getting read-only code: { e } " )
310323 continue
311324 if code_context .strip ():
312- code_context_with_imports = CodeString (
313- code = add_needed_imports_from_module (
325+ if code_context_type != CodeContextType . HASHING :
326+ code_context = add_needed_imports_from_module (
314327 src_module_code = original_code ,
315328 dst_module_code = code_context ,
316329 src_path = file_path ,
@@ -319,10 +332,9 @@ def extract_code_markdown_context_from_files(
319332 helper_functions = list (
320333 helpers_of_fto .get (file_path , set ()) | helpers_of_helpers .get (file_path , set ())
321334 ),
322- ),
323- file_path = file_path .relative_to (project_root_path ),
324- )
325- code_context_markdown .code_strings .append (code_context_with_imports )
335+ )
336+ code_string_context = CodeString (code = code_context , file_path = file_path .relative_to (project_root_path ))
337+ code_context_markdown .code_strings .append (code_string_context )
326338 # Extract code from file paths containing helpers of helpers
327339 for file_path , helper_function_sources in helpers_of_helpers_no_overlap .items ():
328340 try :
@@ -343,18 +355,17 @@ def extract_code_markdown_context_from_files(
343355 continue
344356
345357 if code_context .strip ():
346- code_context_with_imports = CodeString (
347- code = add_needed_imports_from_module (
358+ if code_context_type != CodeContextType . HASHING :
359+ code_context = add_needed_imports_from_module (
348360 src_module_code = original_code ,
349361 dst_module_code = code_context ,
350362 src_path = file_path ,
351363 dst_path = file_path ,
352364 project_root = project_root_path ,
353365 helper_functions = list (helpers_of_helpers_no_overlap .get (file_path , set ())),
354- ),
355- file_path = file_path .relative_to (project_root_path ),
356- )
357- code_context_markdown .code_strings .append (code_context_with_imports )
366+ )
367+ code_string_context = CodeString (code = code_context , file_path = file_path .relative_to (project_root_path ))
368+ code_context_markdown .code_strings .append (code_string_context )
358369 return code_context_markdown
359370
360371
@@ -492,13 +503,18 @@ def parse_code_and_prune_cst(
492503 filtered_node , found_target = prune_cst_for_testgen_code (
493504 module , target_functions , helpers_of_helper_functions , remove_docstrings = remove_docstrings
494505 )
506+ elif code_context_type == CodeContextType .HASHING :
507+ filtered_node , found_target = prune_cst_for_code_hashing (module , target_functions )
495508 else :
496509 raise ValueError (f"Unknown code_context_type: { code_context_type } " ) # noqa: EM102
497510
498511 if not found_target :
499512 raise ValueError ("No target functions found in the provided code" )
500513 if filtered_node and isinstance (filtered_node , cst .Module ):
501- return str (filtered_node .code )
514+ code = str (filtered_node .code )
515+ if code_context_type == CodeContextType .HASHING :
516+ code = ast .unparse (ast .parse (code )) # Makes it standard
517+ return code
502518 return ""
503519
504520
@@ -583,6 +599,90 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
583599 return (node .with_changes (** updates ) if updates else node ), True
584600
585601
def prune_cst_for_code_hashing(  # noqa: PLR0911
    node: cst.CSTNode, target_functions: set[str], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
    """Recursively filter the node and its children to build the code block used for hashing.

    Only the target functions are kept: import statements and non-target functions are dropped, and a
    class is reduced to just its target methods. Kept functions are passed through
    ``remove_docstring_from_body`` when their body is an indented block.

    Parameters
    ----------
    node:
        The CST node to filter.
    target_functions:
        Qualified names to keep (``"func"`` for module-level functions, ``"Class.method"`` for methods).
    prefix:
        Qualified-name prefix of the enclosing scope; empty at module level.

    Returns
    -------
    (filtered_node, found_target):
        filtered_node: The modified CST node or None if it should be removed.
        found_target: True if a target function was found in this node's subtree.

    Raises
    ------
    ValueError
        If a class definition's body is not an ``IndentedBlock``.

    """
    # Imports never take part in the hashed code.
    if isinstance(node, (cst.Import, cst.ImportFrom)):
        return None, False

    if isinstance(node, cst.FunctionDef):
        qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
        if qualified_name not in target_functions:
            return None, False
        # Single-line bodies (not an IndentedBlock) are kept untouched.
        new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body
        return node.with_changes(body=new_body), True

    if isinstance(node, cst.ClassDef):
        # Do not recurse into nested classes: a non-empty prefix means we are already inside a scope.
        if prefix:
            return None, False
        # Assuming always an IndentedBlock
        if not isinstance(node.body, cst.IndentedBlock):
            raise ValueError("ClassDef body is not an IndentedBlock")  # noqa: TRY004
        # prefix is guaranteed empty here, so the class name alone is the qualifier.
        class_prefix = node.name.value
        kept_methods: list[cst.CSTNode] = []

        # Only direct methods are inspected; every other class-level statement is dropped.
        for stmt in node.body.body:
            if not isinstance(stmt, cst.FunctionDef):
                continue
            if f"{class_prefix}.{stmt.name.value}" not in target_functions:
                continue
            # Same guard as the module-level function branch above.
            method_body = (
                remove_docstring_from_body(stmt.body) if isinstance(stmt.body, cst.IndentedBlock) else stmt.body
            )
            kept_methods.append(stmt.with_changes(body=method_body))

        # If no target functions found, remove the class entirely
        if not kept_methods:
            return None, False
        return node.with_changes(body=cst.IndentedBlock(cast("list[cst.BaseStatement]", kept_methods))), True

    # For other nodes, we preserve them only if they contain target functions in their children.
    section_names = get_section_names(node)
    if not section_names:
        return node, False

    updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
    found_any_target = False

    for section in section_names:
        original_content = getattr(node, section, None)
        if isinstance(original_content, (list, tuple)):
            new_children = []
            section_found_target = False
            for child in original_content:
                filtered, found_target = prune_cst_for_code_hashing(child, target_functions, prefix)
                if filtered:
                    new_children.append(filtered)
                section_found_target |= found_target

            # A section is rewritten only when at least one child led to a target.
            if section_found_target:
                found_any_target = True
                updates[section] = new_children
        elif original_content is not None:
            filtered, found_target = prune_cst_for_code_hashing(original_content, target_functions, prefix)
            if found_target:
                found_any_target = True
                if filtered:
                    updates[section] = filtered

    if not found_any_target:
        return None, False

    return (node.with_changes(**updates) if updates else node), True
684+
685+
586686def prune_cst_for_read_only_code ( # noqa: PLR0911
587687 node : cst .CSTNode ,
588688 target_functions : set [str ],
0 commit comments