33import os
44from collections import defaultdict
55from itertools import chain
6- from pathlib import Path
6+ from typing import TYPE_CHECKING , Optional
77
88import jedi
99import libcst as cst
1010import tiktoken
11- from jedi .api .classes import Name
12- from libcst import CSTNode
1311
1412from codeflash .cli_cmds .console import logger
1513from codeflash .code_utils .code_extractor import add_needed_imports_from_module , find_preexisting_objects
1614from codeflash .code_utils .code_utils import get_qualified_name , path_belongs_to_site_packages
17- from codeflash .discovery .functions_to_optimize import FunctionToOptimize
1815from codeflash .models .models import (
1916 CodeContextType ,
2017 CodeOptimizationContext ,
2421)
2522from codeflash .optimization .function_context import belongs_to_function_qualified
2623
24+ if TYPE_CHECKING :
25+ from pathlib import Path
26+
27+ from jedi .api .classes import Name
28+ from libcst import CSTNode
29+
30+ from codeflash .discovery .functions_to_optimize import FunctionToOptimize
31+
2732
2833def get_code_optimization_context (
2934 function_to_optimize : FunctionToOptimize ,
@@ -75,7 +80,8 @@ def get_code_optimization_context(
7580 tokenizer = tiktoken .encoding_for_model ("gpt-4o" )
7681 final_read_writable_tokens = len (tokenizer .encode (final_read_writable_code ))
7782 if final_read_writable_tokens > optim_token_limit :
78- raise ValueError ("Read-writable code has exceeded token limit, cannot proceed" )
83+ msg = "Read-writable code has exceeded token limit, cannot proceed"
84+ raise ValueError (msg )
7985
8086 # Setup preexisting objects for code replacer
8187 preexisting_objects = set (
@@ -122,7 +128,8 @@ def get_code_optimization_context(
122128 testgen_context_code = testgen_code_markdown .code
123129 testgen_context_code_tokens = len (tokenizer .encode (testgen_context_code ))
124130 if testgen_context_code_tokens > testgen_token_limit :
125- raise ValueError ("Testgen code context has exceeded token limit, cannot proceed" )
131+ msg = "Testgen code context has exceeded token limit, cannot proceed"
132+ raise ValueError (msg )
126133
127134 return CodeOptimizationContext (
128135 testgen_context_code = testgen_context_code ,
@@ -143,7 +150,7 @@ def extract_code_string_context_from_files(
143150 """Extract code context from files containing target functions and their helpers.
144151 This function processes two sets of files:
145152 1. Files containing the function to optimize (fto) and their first-degree helpers
146- 2. Files containing only helpers of helpers (with no overlap with the first set)
153+ 2. Files containing only helpers of helpers (with no overlap with the first set).
147154
148155 For each file, it extracts relevant code based on the specified context type, adds necessary
149156 imports, and combines them.
@@ -358,18 +365,18 @@ def get_function_to_optimize_as_function_source(
358365 and name .name == function_to_optimize .function_name
359366 and get_qualified_name (name .module_name , name .full_name ) == function_to_optimize .qualified_name
360367 ):
361- function_source = FunctionSource (
368+ return FunctionSource (
362369 file_path = function_to_optimize .file_path ,
363370 qualified_name = function_to_optimize .qualified_name ,
364371 fully_qualified_name = name .full_name ,
365372 only_function_name = name .name ,
366373 source_code = name .get_line_code (),
367374 jedi_definition = name ,
368375 )
369- return function_source
370376
377+ msg = f"Could not find function { function_to_optimize .function_name } in { function_to_optimize .file_path } "
371378 raise ValueError (
372- f"Could not find function { function_to_optimize . function_name } in { function_to_optimize . file_path } "
379+ msg
373380 )
374381
375382
@@ -436,7 +443,7 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
436443
437444
438445def remove_docstring_from_body (indented_block : cst .IndentedBlock ) -> cst .CSTNode :
439- """Removes the docstring from an indented block if it exists"""
446+ """Removes the docstring from an indented block if it exists. """
440447 if not isinstance (indented_block .body [0 ], cst .SimpleStatementLine ):
441448 return indented_block
442449 first_stmt = indented_block .body [0 ].body [0 ]
@@ -449,10 +456,12 @@ def parse_code_and_prune_cst(
449456 code : str ,
450457 code_context_type : CodeContextType ,
451458 target_functions : set [str ],
452- helpers_of_helper_functions : set [str ] = set () ,
459+ helpers_of_helper_functions : Optional [ set [str ]] = None ,
453460 remove_docstrings : bool = False ,
454461) -> str :
455462 """Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
463+ if helpers_of_helper_functions is None :
464+ helpers_of_helper_functions = set ()
456465 module = cst .parse_module (code )
457466 if code_context_type == CodeContextType .READ_WRITABLE :
458467 filtered_node , found_target = prune_cst_for_read_writable_code (module , target_functions )
@@ -465,10 +474,12 @@ def parse_code_and_prune_cst(
465474 module , target_functions , helpers_of_helper_functions , remove_docstrings = remove_docstrings
466475 )
467476 else :
468- raise ValueError (f"Unknown code_context_type: { code_context_type } " )
477+ msg = f"Unknown code_context_type: { code_context_type } "
478+ raise ValueError (msg )
469479
470480 if not found_target :
471- raise ValueError ("No target functions found in the provided code" )
481+ msg = "No target functions found in the provided code"
482+ raise ValueError (msg )
472483 if filtered_node and isinstance (filtered_node , cst .Module ):
473484 return str (filtered_node .code )
474485 return ""
@@ -500,7 +511,8 @@ def prune_cst_for_read_writable_code(
500511 return None , False
501512 # Assuming always an IndentedBlock
502513 if not isinstance (node .body , cst .IndentedBlock ):
503- raise ValueError ("ClassDef body is not an IndentedBlock" )
514+ msg = "ClassDef body is not an IndentedBlock"
515+ raise ValueError (msg )
504516 class_prefix = f"{ prefix } .{ node .name .value } " if prefix else node .name .value
505517 new_body = []
506518 found_target = False
@@ -593,7 +605,8 @@ def prune_cst_for_read_only_code(
593605 return None , False
594606 # Assuming always an IndentedBlock
595607 if not isinstance (node .body , cst .IndentedBlock ):
596- raise ValueError ("ClassDef body is not an IndentedBlock" )
608+ msg = "ClassDef body is not an IndentedBlock"
609+ raise ValueError (msg )
597610
598611 class_prefix = f"{ prefix } .{ node .name .value } " if prefix else node .name .value
599612
@@ -698,7 +711,8 @@ def prune_cst_for_testgen_code(
698711 return None , False
699712 # Assuming always an IndentedBlock
700713 if not isinstance (node .body , cst .IndentedBlock ):
701- raise ValueError ("ClassDef body is not an IndentedBlock" )
714+ msg = "ClassDef body is not an IndentedBlock"
715+ raise ValueError (msg )
702716
703717 class_prefix = f"{ prefix } .{ node .name .value } " if prefix else node .name .value
704718
0 commit comments