|  | 
|  | 1 | +"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)""" | 
|  | 2 | +from collections import defaultdict | 
|  | 3 | +from pathlib import Path | 
|  | 4 | +from typing import Union | 
|  | 5 | + | 
|  | 6 | +import isort | 
|  | 7 | +import libcst as cst | 
|  | 8 | + | 
|  | 9 | +from codeflash.code_utils.code_utils import get_run_tmp_file | 
|  | 10 | + | 
|  | 11 | + | 
|  | 12 | +class LineProfilerDecoratorAdder(cst.CSTTransformer): | 
|  | 13 | +    """Transformer that adds a decorator to a function with a specific qualified name.""" | 
|  | 14 | + | 
|  | 15 | +    #TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure | 
|  | 16 | +    def __init__(self, qualified_name: str, decorator_name: str): | 
|  | 17 | +        """Initialize the transformer. | 
|  | 18 | +
 | 
|  | 19 | +        Args: | 
|  | 20 | +            qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func"). | 
|  | 21 | +            decorator_name: The name of the decorator to add. | 
|  | 22 | +
 | 
|  | 23 | +        """ | 
|  | 24 | +        super().__init__() | 
|  | 25 | +        self.qualified_name_parts = qualified_name.split(".") | 
|  | 26 | +        self.decorator_name = decorator_name | 
|  | 27 | + | 
|  | 28 | +        # Track our current context path, only add when we encounter a class | 
|  | 29 | +        self.context_stack = [] | 
|  | 30 | + | 
|  | 31 | +    def visit_ClassDef(self, node: cst.ClassDef) -> None: | 
|  | 32 | +        # Track when we enter a class | 
|  | 33 | +        self.context_stack.append(node.name.value) | 
|  | 34 | + | 
|  | 35 | +    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: | 
|  | 36 | +        # Pop the context when we leave a class | 
|  | 37 | +        self.context_stack.pop() | 
|  | 38 | +        return updated_node | 
|  | 39 | + | 
|  | 40 | +    def visit_FunctionDef(self, node: cst.FunctionDef) -> None: | 
|  | 41 | +        # Track when we enter a function | 
|  | 42 | +        self.context_stack.append(node.name.value) | 
|  | 43 | + | 
|  | 44 | +    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: | 
|  | 45 | +        function_name = original_node.name.value | 
|  | 46 | + | 
|  | 47 | +        # Check if the current context path matches our target qualified name | 
|  | 48 | +        if self.context_stack==self.qualified_name_parts: | 
|  | 49 | +            # Check if the decorator is already present | 
|  | 50 | +            has_decorator = any( | 
|  | 51 | +                self._is_target_decorator(decorator.decorator) | 
|  | 52 | +                for decorator in original_node.decorators | 
|  | 53 | +            ) | 
|  | 54 | + | 
|  | 55 | +            # Only add the decorator if it's not already there | 
|  | 56 | +            if not has_decorator: | 
|  | 57 | +                new_decorator = cst.Decorator( | 
|  | 58 | +                    decorator=cst.Name(value=self.decorator_name) | 
|  | 59 | +                ) | 
|  | 60 | + | 
|  | 61 | +                # Add our new decorator to the existing decorators | 
|  | 62 | +                updated_decorators = [new_decorator] + list(updated_node.decorators) | 
|  | 63 | +                updated_node = updated_node.with_changes( | 
|  | 64 | +                    decorators=tuple(updated_decorators) | 
|  | 65 | +                ) | 
|  | 66 | + | 
|  | 67 | +        # Pop the context when we leave a function | 
|  | 68 | +        self.context_stack.pop() | 
|  | 69 | +        return updated_node | 
|  | 70 | + | 
|  | 71 | +    def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cst.Call]) -> bool: | 
|  | 72 | +        """Check if a decorator matches our target decorator name.""" | 
|  | 73 | +        if isinstance(decorator_node, cst.Name): | 
|  | 74 | +            return decorator_node.value == self.decorator_name | 
|  | 75 | +        if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name): | 
|  | 76 | +            return decorator_node.func.value == self.decorator_name | 
|  | 77 | +        return False | 
|  | 78 | + | 
|  | 79 | +class ProfileEnableTransformer(cst.CSTTransformer): | 
|  | 80 | +    def __init__(self,filename): | 
|  | 81 | +        # Flag to track if we found the import statement | 
|  | 82 | +        self.found_import = False | 
|  | 83 | +        # Track indentation of the import statement | 
|  | 84 | +        self.import_indentation = None | 
|  | 85 | +        self.filename = filename | 
|  | 86 | + | 
|  | 87 | +    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom: | 
|  | 88 | +        # Check if this is the line profiler import statement | 
|  | 89 | +        if (isinstance(original_node.module, cst.Name) and | 
|  | 90 | +                original_node.module.value == "line_profiler" and | 
|  | 91 | +                any(name.name.value == "profile" and | 
|  | 92 | +                    (not name.asname or name.asname.name.value == "codeflash_line_profile") | 
|  | 93 | +                    for name in original_node.names)): | 
|  | 94 | + | 
|  | 95 | +            self.found_import = True | 
|  | 96 | +            # Get the indentation from the original node | 
|  | 97 | +            if hasattr(original_node, "leading_lines"): | 
|  | 98 | +                leading_whitespace = original_node.leading_lines[-1].whitespace if original_node.leading_lines else "" | 
|  | 99 | +                self.import_indentation = leading_whitespace | 
|  | 100 | + | 
|  | 101 | +        return updated_node | 
|  | 102 | + | 
|  | 103 | +    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: | 
|  | 104 | +        if not self.found_import: | 
|  | 105 | +            return updated_node | 
|  | 106 | + | 
|  | 107 | +        # Create a list of statements from the original module | 
|  | 108 | +        new_body = list(updated_node.body) | 
|  | 109 | + | 
|  | 110 | +        # Find the index of the import statement | 
|  | 111 | +        import_index = None | 
|  | 112 | +        for i, stmt in enumerate(new_body): | 
|  | 113 | +            if isinstance(stmt, cst.SimpleStatementLine): | 
|  | 114 | +                for small_stmt in stmt.body: | 
|  | 115 | +                    if isinstance(small_stmt, cst.ImportFrom): | 
|  | 116 | +                        if (isinstance(small_stmt.module, cst.Name) and | 
|  | 117 | +                                small_stmt.module.value == "line_profiler" and | 
|  | 118 | +                                any(name.name.value == "profile" and | 
|  | 119 | +                                    (not name.asname or name.asname.name.value == "codeflash_line_profile") | 
|  | 120 | +                                    for name in small_stmt.names)): | 
|  | 121 | +                            import_index = i | 
|  | 122 | +                            break | 
|  | 123 | +                if import_index is not None: | 
|  | 124 | +                    break | 
|  | 125 | + | 
|  | 126 | +        if import_index is not None: | 
|  | 127 | +            # Create the new enable statement to insert after the import | 
|  | 128 | +            enable_statement = cst.parse_statement( | 
|  | 129 | +                f"codeflash_line_profile.enable(output_prefix='{self.filename}')" | 
|  | 130 | +            ) | 
|  | 131 | + | 
|  | 132 | +            # Insert the new statement after the import statement | 
|  | 133 | +            new_body.insert(import_index + 1, enable_statement) | 
|  | 134 | + | 
|  | 135 | +        # Create a new module with the updated body | 
|  | 136 | +        return updated_node.with_changes(body=new_body) | 
|  | 137 | + | 
|  | 138 | +def add_decorator_to_qualified_function(module, qualified_name, decorator_name): | 
|  | 139 | +    """Add a decorator to a function with the exact qualified name in the source code. | 
|  | 140 | +
 | 
|  | 141 | +    Args: | 
|  | 142 | +        module: The Python source code as a string. | 
|  | 143 | +        qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func"). | 
|  | 144 | +        decorator_name: The name of the decorator to add. | 
|  | 145 | +
 | 
|  | 146 | +    Returns: | 
|  | 147 | +        The modified source code as a string. | 
|  | 148 | +
 | 
|  | 149 | +    """ | 
|  | 150 | +    # Parse the source code into a CST | 
|  | 151 | + | 
|  | 152 | +    # Apply our transformer | 
|  | 153 | +    transformer = LineProfilerDecoratorAdder(qualified_name, decorator_name) | 
|  | 154 | +    modified_module = module.visit(transformer) | 
|  | 155 | + | 
|  | 156 | +    # Convert the modified CST back to source code | 
|  | 157 | +    return modified_module | 
|  | 158 | + | 
|  | 159 | +def add_profile_enable(original_code: str, line_profile_output_file: str) -> str: | 
|  | 160 | +    # TODO modify by using a libcst transformer | 
|  | 161 | +    module = cst.parse_module(original_code) | 
|  | 162 | +    transformer = ProfileEnableTransformer(line_profile_output_file) | 
|  | 163 | +    modified_module = module.visit(transformer) | 
|  | 164 | +    return modified_module.code | 
|  | 165 | + | 
|  | 166 | + | 
|  | 167 | +class ImportAdder(cst.CSTTransformer): | 
|  | 168 | +    def __init__(self, import_statement): | 
|  | 169 | +        self.import_statement = import_statement | 
|  | 170 | +        self.has_import = False | 
|  | 171 | + | 
|  | 172 | +    def leave_Module(self, original_node, updated_node): | 
|  | 173 | +        # If the import is already there, don't add it again | 
|  | 174 | +        if self.has_import: | 
|  | 175 | +            return updated_node | 
|  | 176 | + | 
|  | 177 | +        # Parse the import statement into a CST node | 
|  | 178 | +        import_node = cst.parse_statement(self.import_statement) | 
|  | 179 | + | 
|  | 180 | +        # Add the import to the module's body | 
|  | 181 | +        return updated_node.with_changes( | 
|  | 182 | +            body=[import_node] + list(updated_node.body) | 
|  | 183 | +        ) | 
|  | 184 | + | 
|  | 185 | +    def visit_ImportFrom(self, node): | 
|  | 186 | +        # Check if the profile is already imported from line_profiler | 
|  | 187 | +        if node.module and node.module.value == "line_profiler": | 
|  | 188 | +            for import_alias in node.names: | 
|  | 189 | +                if import_alias.name.value == "profile": | 
|  | 190 | +                    self.has_import = True | 
|  | 191 | + | 
|  | 192 | + | 
|  | 193 | +def add_decorator_imports(function_to_optimize, code_context): | 
|  | 194 | +    """Adds a profile decorator to a function in a Python file and all its helper functions.""" | 
|  | 195 | +    #self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root | 
|  | 196 | +    #grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile | 
|  | 197 | +    file_paths = defaultdict(list) | 
|  | 198 | +    line_profile_output_file = get_run_tmp_file(Path("baseline_lprof")) | 
|  | 199 | +    file_paths[function_to_optimize.file_path].append(function_to_optimize.qualified_name) | 
|  | 200 | +    for elem in code_context.helper_functions: | 
|  | 201 | +        file_paths[elem.file_path].append(elem.qualified_name) | 
|  | 202 | +    for file_path,fns_present in file_paths.items(): | 
|  | 203 | +        #open file | 
|  | 204 | +        file_contents = file_path.read_text("utf-8") | 
|  | 205 | +        # parse to cst | 
|  | 206 | +        module_node = cst.parse_module(file_contents) | 
|  | 207 | +        for fn_name in fns_present: | 
|  | 208 | +            # add decorator | 
|  | 209 | +            module_node = add_decorator_to_qualified_function(module_node, fn_name, "codeflash_line_profile") | 
|  | 210 | +        # add imports | 
|  | 211 | +        # Create a transformer to add the import | 
|  | 212 | +        transformer = ImportAdder("from line_profiler import profile as codeflash_line_profile") | 
|  | 213 | +        # Apply the transformer to add the import | 
|  | 214 | +        module_node = module_node.visit(transformer) | 
|  | 215 | +        modified_code = isort.code(module_node.code, float_to_top=True) | 
|  | 216 | +        # write to file | 
|  | 217 | +        with open(file_path, "w", encoding="utf-8") as file: | 
|  | 218 | +            file.write(modified_code) | 
|  | 219 | +    #Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files | 
|  | 220 | +    file_contents = function_to_optimize.file_path.read_text("utf-8") | 
|  | 221 | +    modified_code = add_profile_enable(file_contents,str(line_profile_output_file)) | 
|  | 222 | +    function_to_optimize.file_path.write_text(modified_code,"utf-8") | 
|  | 223 | +    return line_profile_output_file | 
0 commit comments