Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 80 additions & 41 deletions codeflash/code_utils/line_profile_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)"""

from collections import defaultdict
from pathlib import Path
from typing import Union
Expand All @@ -12,7 +13,7 @@
class LineProfilerDecoratorAdder(cst.CSTTransformer):
"""Transformer that adds a decorator to a function with a specific qualified name."""

#TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
# TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
def __init__(self, qualified_name: str, decorator_name: str):
"""Initialize the transformer.

Expand Down Expand Up @@ -45,24 +46,19 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
function_name = original_node.name.value

# Check if the current context path matches our target qualified name
if self.context_stack==self.qualified_name_parts:
if self.context_stack == self.qualified_name_parts:
# Check if the decorator is already present
has_decorator = any(
self._is_target_decorator(decorator.decorator)
for decorator in original_node.decorators
self._is_target_decorator(decorator.decorator) for decorator in original_node.decorators
)

# Only add the decorator if it's not already there
if not has_decorator:
new_decorator = cst.Decorator(
decorator=cst.Name(value=self.decorator_name)
)
new_decorator = cst.Decorator(decorator=cst.Name(value=self.decorator_name))

# Add our new decorator to the existing decorators
updated_decorators = [new_decorator] + list(updated_node.decorators)
updated_node = updated_node.with_changes(
decorators=tuple(updated_decorators)
)
updated_node = updated_node.with_changes(decorators=tuple(updated_decorators))

# Pop the context when we leave a function
self.context_stack.pop()
Expand All @@ -76,22 +72,21 @@ def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cs
return decorator_node.func.value == self.decorator_name
return False


class ProfileEnableTransformer(cst.CSTTransformer):
def __init__(self,filename):
# Flag to track if we found the import statement
self.found_import = False
# Track indentation of the import statement
self.import_indentation = None
self.filename = filename
def __init__(self, line_profile_output_file: str):
self.line_profile_output_file = line_profile_output_file

def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
# Check if this is the line profiler import statement
if (isinstance(original_node.module, cst.Name) and
original_node.module.value == "line_profiler" and
any(name.name.value == "profile" and
(not name.asname or name.asname.name.value == "codeflash_line_profile")
for name in original_node.names)):

if (
isinstance(original_node.module, cst.Name)
and original_node.module.value == "line_profiler"
and any(
name.name.value == "profile" and (not name.asname or name.asname.name.value == "codeflash_line_profile")
for name in original_node.names
)
):
self.found_import = True
# Get the indentation from the original node
if hasattr(original_node, "leading_lines"):
Expand All @@ -113,28 +108,39 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
if isinstance(stmt, cst.SimpleStatementLine):
for small_stmt in stmt.body:
if isinstance(small_stmt, cst.ImportFrom):
if (isinstance(small_stmt.module, cst.Name) and
small_stmt.module.value == "line_profiler" and
any(name.name.value == "profile" and
(not name.asname or name.asname.name.value == "codeflash_line_profile")
for name in small_stmt.names)):
if (
isinstance(small_stmt.module, cst.Name)
and small_stmt.module.value == "line_profiler"
and any(
name.name.value == "profile"
and (not name.asname or name.asname.name.value == "codeflash_line_profile")
for name in small_stmt.names
)
):
import_index = i
break
if import_index is not None:
break

if import_index is not None:
# Create the new enable statement to insert after the import
enable_statement = cst.parse_statement(
f"codeflash_line_profile.enable(output_prefix='{self.filename}')"
)
enable_statement = cst.parse_statement(f"codeflash_line_profile.enable(output_prefix='{self.filename}')")

# Insert the new statement after the import statement
new_body.insert(import_index + 1, enable_statement)

# Create a new module with the updated body
return updated_node.with_changes(body=new_body)

def __init__(self, line_profile_output_file: str):
self.line_profile_output_file = line_profile_output_file

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
# This is a simplified example of the transformation logic
new_decorator = cst.Decorator(decorator=cst.Name(value="codeflash_line_profile"))
return updated_node.with_changes(decorators=[*updated_node.decorators, new_decorator])


def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
"""Add a decorator to a function with the exact qualified name in the source code.

Expand All @@ -156,9 +162,20 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
# Convert the modified CST back to source code
return modified_module


def add_profile_enable(original_code: str, line_profile_output_file: str) -> str:
# TODO modify by using a libcst transformer
# Avoid unnecessary transformations
if not original_code.strip():
return original_code

# Parse the module only once
module = cst.parse_module(original_code)

# If we can determine whether the transformer needs to be applied, we can shortcut
if not has_transformable_content(module):
return original_code

# Apply transformer optimally
transformer = ProfileEnableTransformer(line_profile_output_file)
modified_module = module.visit(transformer)
return modified_module.code
Expand All @@ -178,9 +195,7 @@ def leave_Module(self, original_node, updated_node):
import_node = cst.parse_statement(self.import_statement)

# Add the import to the module's body
return updated_node.with_changes(
body=[import_node] + list(updated_node.body)
)
return updated_node.with_changes(body=[import_node] + list(updated_node.body))

def visit_ImportFrom(self, node):
# Check if the profile is already imported from line_profiler
Expand All @@ -192,15 +207,15 @@ def visit_ImportFrom(self, node):

def add_decorator_imports(function_to_optimize, code_context):
"""Adds a profile decorator to a function in a Python file and all its helper functions."""
#self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
#grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
# self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
# grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
file_paths = defaultdict(list)
line_profile_output_file = get_run_tmp_file(Path("baseline_lprof"))
file_paths[function_to_optimize.file_path].append(function_to_optimize.qualified_name)
for elem in code_context.helper_functions:
file_paths[elem.file_path].append(elem.qualified_name)
for file_path,fns_present in file_paths.items():
#open file
for file_path, fns_present in file_paths.items():
# open file
file_contents = file_path.read_text("utf-8")
# parse to cst
module_node = cst.parse_module(file_contents)
Expand All @@ -216,8 +231,32 @@ def add_decorator_imports(function_to_optimize, code_context):
# write to file
with open(file_path, "w", encoding="utf-8") as file:
file.write(modified_code)
#Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
# Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
file_contents = function_to_optimize.file_path.read_text("utf-8")
modified_code = add_profile_enable(file_contents,str(line_profile_output_file))
function_to_optimize.file_path.write_text(modified_code,"utf-8")
modified_code = add_profile_enable(file_contents, str(line_profile_output_file))
function_to_optimize.file_path.write_text(modified_code, "utf-8")
return line_profile_output_file


def has_transformable_content(module) -> bool:
"""Function to quickly check if the module has content that needs to be transformed.
This can help in reducing unnecessary transformations.
"""

# A simple check to see if the transformer is needed (can be more complex as required)
# For example, checking if the profile decorators are already present
class CheckVisitor(cst.CSTVisitor):
def __init__(self):
self.has_target = False

def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
for deco in node.decorators:
if deco.decorator.value == "codeflash_line_profile":
self.has_target = True
return False # Stop visiting further

return True

visitor = CheckVisitor()
module.visit(visitor)
return visitor.has_target
Loading