1- from pathlib import Path
1+ from __future__ import annotations
2+
3+ from typing import TYPE_CHECKING , Optional , Union
24
35import isort
46import libcst as cst
57
6- from codeflash .discovery .functions_to_optimize import FunctionToOptimize
8+ if TYPE_CHECKING :
9+ from pathlib import Path
10+
11+ from libcst import BaseStatement , ClassDef , FlattenSentinel , FunctionDef , RemovalSentinel
12+
13+ from codeflash .discovery .functions_to_optimize import FunctionToOptimize
714
815
916class AddDecoratorTransformer (cst .CSTTransformer ):
@@ -13,57 +20,48 @@ def __init__(self, target_functions: set[tuple[str, str]]) -> None:
1320 self .added_codeflash_trace = False
1421 self .class_name = ""
1522 self .function_name = ""
16- self .decorator = cst .Decorator (
17- decorator = cst .Name (value = "codeflash_trace" )
18- )
23+ self .decorator = cst .Decorator (decorator = cst .Name (value = "codeflash_trace" ))
1924
20- def leave_ClassDef (self , original_node , updated_node ):
25+ def leave_ClassDef (
26+ self , original_node : ClassDef , updated_node : ClassDef
27+ ) -> Union [BaseStatement , FlattenSentinel [BaseStatement ], RemovalSentinel ]:
2128 if self .class_name == original_node .name .value :
22- self .class_name = "" # Even if nested classes are not visited, this function is still called on them
29+ self .class_name = "" # Even if nested classes are not visited, this function is still called on them
2330 return updated_node
2431
25- def visit_ClassDef (self , node ) :
26- if self .class_name : # Don't go into nested class
32+ def visit_ClassDef (self , node : ClassDef ) -> Optional [ bool ] :
33+ if self .class_name : # Don't go into nested class
2734 return False
28- self .class_name = node .name .value
35+ self .class_name = node .name .value # noqa: RET503
2936
30- def visit_FunctionDef (self , node ) :
31- if self .function_name : # Don't go into nested function
37+ def visit_FunctionDef (self , node : FunctionDef ) -> Optional [ bool ] :
38+ if self .function_name : # Don't go into nested function
3239 return False
33- self .function_name = node .name .value
40+ self .function_name = node .name .value # noqa: RET503
3441
35- def leave_FunctionDef (self , original_node , updated_node ) :
42+ def leave_FunctionDef (self , original_node : FunctionDef , updated_node : FunctionDef ) -> FunctionDef :
3643 if self .function_name == original_node .name .value :
3744 self .function_name = ""
3845 if (self .class_name , original_node .name .value ) in self .target_functions :
3946 # Add the new decorator after any existing decorators, so it gets executed first
40- updated_decorators = list (updated_node .decorators ) + [ self .decorator ]
47+ updated_decorators = [ * list (updated_node .decorators ), self .decorator ]
4148 self .added_codeflash_trace = True
42- return updated_node .with_changes (
43- decorators = updated_decorators
44- )
49+ return updated_node .with_changes (decorators = updated_decorators )
4550
4651 return updated_node
4752
48- def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module :
53+ def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module : # noqa: ARG002
4954 # Create import statement for codeflash_trace
5055 if not self .added_codeflash_trace :
5156 return updated_node
5257 import_stmt = cst .SimpleStatementLine (
5358 body = [
5459 cst .ImportFrom (
5560 module = cst .Attribute (
56- value = cst .Attribute (
57- value = cst .Name (value = "codeflash" ),
58- attr = cst .Name (value = "benchmarking" )
59- ),
60- attr = cst .Name (value = "codeflash_trace" )
61+ value = cst .Attribute (value = cst .Name (value = "codeflash" ), attr = cst .Name (value = "benchmarking" )),
62+ attr = cst .Name (value = "codeflash_trace" ),
6163 ),
62- names = [
63- cst .ImportAlias (
64- name = cst .Name (value = "codeflash_trace" )
65- )
66- ]
64+ names = [cst .ImportAlias (name = cst .Name (value = "codeflash_trace" ))],
6765 )
6866 ]
6967 )
@@ -73,12 +71,13 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
7371
7472 return updated_node .with_changes (body = new_body )
7573
74+
7675def add_codeflash_decorator_to_code (code : str , functions_to_optimize : list [FunctionToOptimize ]) -> str :
7776 """Add codeflash_trace to a function.
7877
7978 Args:
8079 code: The source code as a string
81- function_to_optimize: The FunctionToOptimize instance containing function details
80+ functions_to_optimize: List of FunctionToOptimize instances containing function details
8281
8382 Returns:
8483 The modified source code as a string
@@ -91,25 +90,18 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
9190 class_name = function_to_optimize .parents [0 ].name
9291 target_functions .add ((class_name , function_to_optimize .function_name ))
9392
94- transformer = AddDecoratorTransformer (
95- target_functions = target_functions ,
96- )
93+ transformer = AddDecoratorTransformer (target_functions = target_functions )
9794
9895 module = cst .parse_module (code )
9996 modified_module = module .visit (transformer )
10097 return modified_module .code
10198
10299
103- def instrument_codeflash_trace_decorator (
104- file_to_funcs_to_optimize : dict [Path , list [FunctionToOptimize ]]
105- ) -> None :
100+ def instrument_codeflash_trace_decorator (file_to_funcs_to_optimize : dict [Path , list [FunctionToOptimize ]]) -> None :
106101 """Instrument codeflash_trace decorator to functions to optimize."""
107102 for file_path , functions_to_optimize in file_to_funcs_to_optimize .items ():
108103 original_code = file_path .read_text (encoding = "utf-8" )
109- new_code = add_codeflash_decorator_to_code (
110- original_code ,
111- functions_to_optimize
112- )
104+ new_code = add_codeflash_decorator_to_code (original_code , functions_to_optimize )
113105 # Modify the code
114106 modified_code = isort .code (code = new_code , float_to_top = True )
115107
0 commit comments