@@ -73,6 +73,13 @@ def get_code_optimization_context(
7373 remove_docstrings = False ,
7474 code_context_type = CodeContextType .READ_ONLY ,
7575 )
76+ hashing_code_context = extract_code_markdown_context_from_files (
77+ helpers_of_fto_dict ,
78+ helpers_of_helpers_dict ,
79+ project_root_path ,
80+ remove_docstrings = True ,
81+ code_context_type = CodeContextType .HASHING ,
82+ )
7683
7784 # Handle token limits
7885 final_read_writable_tokens = encoded_tokens_len (final_read_writable_code )
@@ -130,6 +137,7 @@ def get_code_optimization_context(
130137 testgen_context_code = testgen_context_code ,
131138 read_writable_code = final_read_writable_code ,
132139 read_only_context_code = read_only_context_code ,
140+ hashing_code_context = hashing_code_context .markdown ,
133141 helper_functions = helpers_of_fto_list ,
134142 preexisting_objects = preexisting_objects ,
135143 )
@@ -309,20 +317,21 @@ def extract_code_markdown_context_from_files(
309317 logger .debug (f"Error while getting read-only code: { e } " )
310318 continue
311319 if code_context .strip ():
312- code_context_with_imports = CodeString (
313- code = add_needed_imports_from_module (
314- src_module_code = original_code ,
315- dst_module_code = code_context ,
316- src_path = file_path ,
317- dst_path = file_path ,
318- project_root = project_root_path ,
319- helper_functions = list (
320- helpers_of_fto .get (file_path , set ()) | helpers_of_helpers .get (file_path , set ())
320+ if code_context_type != CodeContextType .HASHING :
321+ code_context = (
322+ add_needed_imports_from_module (
323+ src_module_code = original_code ,
324+ dst_module_code = code_context ,
325+ src_path = file_path ,
326+ dst_path = file_path ,
327+ project_root = project_root_path ,
328+ helper_functions = list (
329+ helpers_of_fto .get (file_path , set ()) | helpers_of_helpers .get (file_path , set ())
330+ ),
321331 ),
322- ),
323- file_path = file_path .relative_to (project_root_path ),
324- )
325- code_context_markdown .code_strings .append (code_context_with_imports )
332+ )
333+ code_string_context = CodeString (code = code_context , file_path = file_path .relative_to (project_root_path ))
334+ code_context_markdown .code_strings .append (code_string_context )
326335 # Extract code from file paths containing helpers of helpers
327336 for file_path , helper_function_sources in helpers_of_helpers_no_overlap .items ():
328337 try :
@@ -343,18 +352,19 @@ def extract_code_markdown_context_from_files(
343352 continue
344353
345354 if code_context .strip ():
346- code_context_with_imports = CodeString (
347- code = add_needed_imports_from_module (
348- src_module_code = original_code ,
349- dst_module_code = code_context ,
350- src_path = file_path ,
351- dst_path = file_path ,
352- project_root = project_root_path ,
353- helper_functions = list (helpers_of_helpers_no_overlap .get (file_path , set ())),
354- ),
355- file_path = file_path .relative_to (project_root_path ),
356- )
357- code_context_markdown .code_strings .append (code_context_with_imports )
355+ if code_context_type != CodeContextType .HASHING :
356+ code_context = (
357+ add_needed_imports_from_module (
358+ src_module_code = original_code ,
359+ dst_module_code = code_context ,
360+ src_path = file_path ,
361+ dst_path = file_path ,
362+ project_root = project_root_path ,
363+ helper_functions = list (helpers_of_helpers_no_overlap .get (file_path , set ())),
364+ ),
365+ )
366+ code_string_context = CodeString (code = code_context , file_path = file_path .relative_to (project_root_path ))
367+ code_context_markdown .code_strings .append (code_string_context )
358368 return code_context_markdown
359369
360370
@@ -492,6 +502,8 @@ def parse_code_and_prune_cst(
492502 filtered_node , found_target = prune_cst_for_testgen_code (
493503 module , target_functions , helpers_of_helper_functions , remove_docstrings = remove_docstrings
494504 )
505+ elif code_context_type == CodeContextType .HASHING :
506+ filtered_node , found_target = prune_cst_for_code_hashing (module , target_functions )
495507 else :
496508 raise ValueError (f"Unknown code_context_type: { code_context_type } " ) # noqa: EM102
497509
@@ -583,6 +595,87 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
583595 return (node .with_changes (** updates ) if updates else node ), True
584596
585597
def prune_cst_for_code_hashing(  # noqa: PLR0911
    node: cst.CSTNode, target_functions: set[str], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
    """Recursively prune a CST down to the target functions for code hashing.

    Import statements are dropped, functions and methods whose qualified name
    is not in ``target_functions`` are dropped, and the docstring of every kept
    function *and* method is removed (presumably so that documentation-only
    edits do not alter the resulting hash).

    Parameters
    ----------
    node:
        The CST node to filter.
    target_functions:
        Qualified names to keep: ``"func"`` for module-level functions and
        ``"Class.method"`` for methods of top-level classes.
    prefix:
        Qualified-name prefix of the enclosing scope; empty at module level.

    Returns
    -------
    (filtered_node, found_target):
        filtered_node: The modified CST node or None if it should be removed.
        found_target: True if a target function was found in this node's subtree.

    Raises
    ------
    ValueError
        If a class body is not a ``cst.IndentedBlock``.

    """
    if isinstance(node, (cst.Import, cst.ImportFrom)):
        return None, False

    if isinstance(node, cst.FunctionDef):
        qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
        if qualified_name not in target_functions:
            return None, False
        return node.with_changes(body=remove_docstring_from_body(node.body)), True

    if isinstance(node, cst.ClassDef):
        # Do not recurse into nested classes
        if prefix:
            return None, False
        # Assuming always an IndentedBlock
        if not isinstance(node.body, cst.IndentedBlock):
            raise ValueError("ClassDef body is not an IndentedBlock")  # noqa: TRY004
        # ``prefix`` is necessarily empty here, so the class name alone qualifies its methods.
        class_prefix = node.name.value
        kept_methods: list[cst.CSTNode] = []

        for stmt in node.body.body:
            if not isinstance(stmt, cst.FunctionDef):
                # Only methods are kept; attributes, nested classes and the class docstring are dropped.
                continue
            # Reuse the FunctionDef branch so methods get the same docstring
            # stripping as module-level functions.
            pruned_method, is_target = prune_cst_for_code_hashing(stmt, target_functions, class_prefix)
            if is_target and pruned_method is not None:
                kept_methods.append(pruned_method)

        # If no target functions found, remove the class entirely
        if not kept_methods:
            return None, False

        pruned_class_body = node.body.with_changes(body=kept_methods)
        # NOTE(review): the rebuilt body holds only methods, so this call is expected
        # to be a no-op; it is kept to preserve the original call sequence.
        return node.with_changes(body=remove_docstring_from_body(pruned_class_body)), True

    # For other nodes, we preserve them only if they contain target functions in their children.
    section_names = get_section_names(node)
    if not section_names:
        return node, False

    updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
    found_any_target = False

    for section in section_names:
        original_content = getattr(node, section, None)
        if isinstance(original_content, (list, tuple)):
            new_children = []
            section_found_target = False
            for child in original_content:
                filtered, found_target = prune_cst_for_code_hashing(child, target_functions, prefix)
                if filtered is not None:
                    new_children.append(filtered)
                section_found_target |= found_target

            # Only rewrite a section that actually leads to a target; other sections stay untouched.
            if section_found_target:
                found_any_target = True
                updates[section] = new_children
        elif original_content is not None:
            filtered, found_target = prune_cst_for_code_hashing(original_content, target_functions, prefix)
            if found_target:
                found_any_target = True
                if filtered is not None:
                    updates[section] = filtered

    if not found_any_target:
        return None, False

    return (node.with_changes(**updates) if updates else node), True
677+
678+
586679def prune_cst_for_read_only_code ( # noqa: PLR0911
587680 node : cst .CSTNode ,
588681 target_functions : set [str ],
0 commit comments