33import os
44from collections import defaultdict
55from itertools import chain
6- from pathlib import Path
6+ from typing import TYPE_CHECKING , Optional
77
88import jedi
99import libcst as cst
1010import tiktoken
11- from jedi .api .classes import Name
12- from libcst import CSTNode
1311
1412from codeflash .cli_cmds .console import logger
1513from codeflash .code_utils .code_extractor import add_needed_imports_from_module , find_preexisting_objects
1614from codeflash .code_utils .code_utils import get_qualified_name , path_belongs_to_site_packages
17- from codeflash .discovery .functions_to_optimize import FunctionToOptimize
1815from codeflash .models .models import (
1916 CodeContextType ,
2017 CodeOptimizationContext ,
2421)
2522from codeflash .optimization .function_context import belongs_to_function_qualified
2623
24+ if TYPE_CHECKING :
25+ from pathlib import Path
26+
27+ from jedi .api .classes import Name
28+ from libcst import CSTNode
29+
30+ from codeflash .discovery .functions_to_optimize import FunctionToOptimize
31+ from typing import Callable
32+
2733
2834def get_code_optimization_context (
2935 function_to_optimize : FunctionToOptimize ,
@@ -75,7 +81,8 @@ def get_code_optimization_context(
7581 tokenizer = tiktoken .encoding_for_model ("gpt-4o" )
7682 final_read_writable_tokens = len (tokenizer .encode (final_read_writable_code ))
7783 if final_read_writable_tokens > optim_token_limit :
78- raise ValueError ("Read-writable code has exceeded token limit, cannot proceed" )
84+ msg = "Read-writable code has exceeded token limit, cannot proceed"
85+ raise ValueError (msg )
7986
8087 # Setup preexisting objects for code replacer
8188 preexisting_objects = set (
@@ -122,7 +129,8 @@ def get_code_optimization_context(
122129 testgen_context_code = testgen_code_markdown .code
123130 testgen_context_code_tokens = len (tokenizer .encode (testgen_context_code ))
124131 if testgen_context_code_tokens > testgen_token_limit :
125- raise ValueError ("Testgen code context has exceeded token limit, cannot proceed" )
132+ msg = "Testgen code context has exceeded token limit, cannot proceed"
133+ raise ValueError (msg )
126134
127135 return CodeOptimizationContext (
128136 testgen_context_code = testgen_context_code ,
@@ -143,7 +151,7 @@ def extract_code_string_context_from_files(
143151 """Extract code context from files containing target functions and their helpers.
144152 This function processes two sets of files:
145153 1. Files containing the function to optimize (fto) and their first-degree helpers
146- 2. Files containing only helpers of helpers (with no overlap with the first set)
154+ 2. Files containing only helpers of helpers (with no overlap with the first set).
147155
148156 For each file, it extracts relevant code based on the specified context type, adds necessary
149157 imports, and combines them.
@@ -358,19 +366,17 @@ def get_function_to_optimize_as_function_source(
358366 and name .name == function_to_optimize .function_name
359367 and get_qualified_name (name .module_name , name .full_name ) == function_to_optimize .qualified_name
360368 ):
361- function_source = FunctionSource (
369+ return FunctionSource (
362370 file_path = function_to_optimize .file_path ,
363371 qualified_name = function_to_optimize .qualified_name ,
364372 fully_qualified_name = name .full_name ,
365373 only_function_name = name .name ,
366374 source_code = name .get_line_code (),
367375 jedi_definition = name ,
368376 )
369- return function_source
370377
371- raise ValueError (
372- f"Could not find function { function_to_optimize .function_name } in { function_to_optimize .file_path } "
373- )
378+ msg = f"Could not find function { function_to_optimize .function_name } in { function_to_optimize .file_path } "
379+ raise ValueError (msg )
374380
375381
376382def get_function_sources_from_jedi (
@@ -436,7 +442,7 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
436442
437443
438444def remove_docstring_from_body (indented_block : cst .IndentedBlock ) -> cst .CSTNode :
439- """Removes the docstring from an indented block if it exists"""
445+ """Removes the docstring from an indented block if it exists. """
440446 if not isinstance (indented_block .body [0 ], cst .SimpleStatementLine ):
441447 return indented_block
442448 first_stmt = indented_block .body [0 ].body [0 ]
@@ -464,16 +470,13 @@ def parse_code_and_prune_cst(
464470 filtered_node , found_target = prune_cst_for_testgen_code (
465471 module , target_functions , helpers_of_helper_functions , remove_docstrings = remove_docstrings
466472 )
467- else :
468- raise ValueError (f"Unknown code_context_type: { code_context_type } " )
469-
470473 if not found_target :
471- raise ValueError ("No target functions found in the provided code" )
474+ msg = "No target functions found in the provided code"
475+ raise ValueError (msg )
472476 if filtered_node and isinstance (filtered_node , cst .Module ):
473477 return str (filtered_node .code )
474478 return ""
475479
476-
477480def prune_cst_for_read_writable_code (
478481 node : cst .CSTNode , target_functions : set [str ], prefix : str = ""
479482) -> tuple [cst .CSTNode | None , bool ]:
@@ -500,7 +503,8 @@ def prune_cst_for_read_writable_code(
500503 return None , False
501504 # Assuming always an IndentedBlock
502505 if not isinstance (node .body , cst .IndentedBlock ):
503- raise ValueError ("ClassDef body is not an IndentedBlock" )
506+ msg = "ClassDef body is not an IndentedBlock"
507+ raise ValueError (msg )
504508 class_prefix = f"{ prefix } .{ node .name .value } " if prefix else node .name .value
505509 new_body = []
506510 found_target = False
@@ -593,7 +597,8 @@ def prune_cst_for_read_only_code(
593597 return None , False
594598 # Assuming always an IndentedBlock
595599 if not isinstance (node .body , cst .IndentedBlock ):
596- raise ValueError ("ClassDef body is not an IndentedBlock" )
600+ msg = "ClassDef body is not an IndentedBlock"
601+ raise ValueError (msg )
597602
598603 class_prefix = f"{ prefix } .{ node .name .value } " if prefix else node .name .value
599604
@@ -612,10 +617,12 @@ def prune_cst_for_read_only_code(
612617 return None , False
613618
614619 if remove_docstrings :
615- return node .with_changes (
616- body = remove_docstring_from_body (node .body .with_changes (body = new_class_body ))
617- ) if new_class_body else None , True
618- return node .with_changes (body = node .body .with_changes (body = new_class_body )) if new_class_body else None , True
620+ return (
621+ node .with_changes (body = remove_docstring_from_body (node .body .with_changes (body = new_class_body )))
622+ if new_class_body
623+ else None
624+ ), True
625+ return (node .with_changes (body = node .body .with_changes (body = new_class_body )) if new_class_body else None ), True
619626
620627 # For other nodes, keep the node and recursively filter children
621628 section_names = get_section_names (node )
@@ -698,7 +705,8 @@ def prune_cst_for_testgen_code(
698705 return None , False
699706 # Assuming always an IndentedBlock
700707 if not isinstance (node .body , cst .IndentedBlock ):
701- raise ValueError ("ClassDef body is not an IndentedBlock" )
708+ msg = "ClassDef body is not an IndentedBlock"
709+ raise ValueError (msg )
702710
703711 class_prefix = f"{ prefix } .{ node .name .value } " if prefix else node .name .value
704712
@@ -717,10 +725,12 @@ def prune_cst_for_testgen_code(
717725 return None , False
718726
719727 if remove_docstrings :
720- return node .with_changes (
721- body = remove_docstring_from_body (node .body .with_changes (body = new_class_body ))
722- ) if new_class_body else None , True
723- return node .with_changes (body = node .body .with_changes (body = new_class_body )) if new_class_body else None , True
728+ return (
729+ node .with_changes (body = remove_docstring_from_body (node .body .with_changes (body = new_class_body )))
730+ if new_class_body
731+ else None
732+ ), True
733+ return (node .with_changes (body = node .body .with_changes (body = new_class_body )) if new_class_body else None ), True
724734
725735 # For other nodes, keep the node and recursively filter children
726736 section_names = get_section_names (node )
0 commit comments