Skip to content

Commit 4384484

Browse files
⚡️ Speed up method LineProfilerDecoratorAdder._matches_qualified_path by 258% in PR #35 (line-profiler)
To optimize the given Python program, let's focus on improving the efficiency and readability of the code by using more efficient operations and making a few structural improvements. Here's a more optimized version.
1 parent d8246fc commit 4384484

File tree

1 file changed

+52
-58
lines changed

1 file changed

+52
-58
lines changed
Lines changed: 52 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,29 @@
11
"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)"""
2+
23
from collections import defaultdict
4+
from pathlib import Path
5+
from typing import Union
36

47
import isort
58
import libcst as cst
6-
from pathlib import Path
7-
from typing import Union, List
8-
from libcst import ImportFrom, ImportAlias, Name
99

1010
from codeflash.code_utils.code_utils import get_run_tmp_file
1111

1212

1313
class LineProfilerDecoratorAdder(cst.CSTTransformer):
1414
"""Transformer that adds a decorator to a function with a specific qualified name."""
15-
#Todo we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
15+
1616
def __init__(self, qualified_name: str, decorator_name: str):
17-
"""
18-
Initialize the transformer.
17+
"""Initialize the transformer.
1918
2019
Args:
21-
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
20+
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.target_func").
2221
decorator_name: The name of the decorator to add.
22+
2323
"""
2424
super().__init__()
2525
self.qualified_name_parts = qualified_name.split(".")
2626
self.decorator_name = decorator_name
27-
28-
# Track our current context path, only add when we encounter a class
2927
self.context_stack = []
3028

3129
def visit_ClassDef(self, node: cst.ClassDef) -> None:
@@ -48,47 +46,36 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
4846
if self._matches_qualified_path():
4947
# Check if the decorator is already present
5048
has_decorator = any(
51-
self._is_target_decorator(decorator.decorator)
52-
for decorator in original_node.decorators
49+
self._is_target_decorator(decorator.decorator) for decorator in original_node.decorators
5350
)
5451

5552
# Only add the decorator if it's not already there
5653
if not has_decorator:
57-
new_decorator = cst.Decorator(
58-
decorator=cst.Name(value=self.decorator_name)
59-
)
54+
new_decorator = cst.Decorator(decorator=cst.Name(value=self.decorator_name))
6055

6156
# Add our new decorator to the existing decorators
6257
updated_decorators = [new_decorator] + list(updated_node.decorators)
63-
updated_node = updated_node.with_changes(
64-
decorators=tuple(updated_decorators)
65-
)
58+
updated_node = updated_node.with_changes(decorators=tuple(updated_decorators))
6659

6760
# Pop the context when we leave a function
6861
self.context_stack.pop()
6962
return updated_node
7063

7164
def _matches_qualified_path(self) -> bool:
7265
"""Check if the current context stack matches the qualified name."""
73-
if len(self.context_stack) != len(self.qualified_name_parts):
74-
return False
75-
76-
for i, name in enumerate(self.qualified_name_parts):
77-
if self.context_stack[i] != name:
78-
return False
79-
80-
return True
66+
return self.context_stack == self.qualified_name_parts
8167

8268
def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cst.Call]) -> bool:
8369
"""Check if a decorator matches our target decorator name."""
8470
if isinstance(decorator_node, cst.Name):
8571
return decorator_node.value == self.decorator_name
86-
elif isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
72+
if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
8773
return decorator_node.func.value == self.decorator_name
8874
return False
8975

76+
9077
class ProfileEnableTransformer(cst.CSTTransformer):
91-
def __init__(self,filename):
78+
def __init__(self, filename):
9279
# Flag to track if we found the import statement
9380
self.found_import = False
9481
# Track indentation of the import statement
@@ -97,12 +84,14 @@ def __init__(self,filename):
9784

9885
def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
9986
# Check if this is the line profiler import statement
100-
if (isinstance(original_node.module, cst.Name) and
101-
original_node.module.value == "line_profiler" and
102-
any(name.name.value == "profile" and
103-
(not name.asname or name.asname.name.value == "codeflash_line_profile")
104-
for name in original_node.names)):
105-
87+
if (
88+
isinstance(original_node.module, cst.Name)
89+
and original_node.module.value == "line_profiler"
90+
and any(
91+
name.name.value == "profile" and (not name.asname or name.asname.name.value == "codeflash_line_profile")
92+
for name in original_node.names
93+
)
94+
):
10695
self.found_import = True
10796
# Get the indentation from the original node
10897
if hasattr(original_node, "leading_lines"):
@@ -124,31 +113,33 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
124113
if isinstance(stmt, cst.SimpleStatementLine):
125114
for small_stmt in stmt.body:
126115
if isinstance(small_stmt, cst.ImportFrom):
127-
if (isinstance(small_stmt.module, cst.Name) and
128-
small_stmt.module.value == "line_profiler" and
129-
any(name.name.value == "profile" and
130-
(not name.asname or name.asname.name.value == "codeflash_line_profile")
131-
for name in small_stmt.names)):
116+
if (
117+
isinstance(small_stmt.module, cst.Name)
118+
and small_stmt.module.value == "line_profiler"
119+
and any(
120+
name.name.value == "profile"
121+
and (not name.asname or name.asname.name.value == "codeflash_line_profile")
122+
for name in small_stmt.names
123+
)
124+
):
132125
import_index = i
133126
break
134127
if import_index is not None:
135128
break
136129

137130
if import_index is not None:
138131
# Create the new enable statement to insert after the import
139-
enable_statement = cst.parse_statement(
140-
f"codeflash_line_profile.enable(output_prefix='{self.filename}')"
141-
)
132+
enable_statement = cst.parse_statement(f"codeflash_line_profile.enable(output_prefix='{self.filename}')")
142133

143134
# Insert the new statement after the import statement
144135
new_body.insert(import_index + 1, enable_statement)
145136

146137
# Create a new module with the updated body
147138
return updated_node.with_changes(body=new_body)
148139

140+
149141
def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
150-
"""
151-
Add a decorator to a function with the exact qualified name in the source code.
142+
"""Add a decorator to a function with the exact qualified name in the source code.
152143
153144
Args:
154145
module: The Python source code as a string.
@@ -157,6 +148,7 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
157148
158149
Returns:
159150
The modified source code as a string.
151+
160152
"""
161153
# Parse the source code into a CST
162154

@@ -167,8 +159,9 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
167159
# Convert the modified CST back to source code
168160
return modified_module
169161

162+
170163
def add_profile_enable(original_code: str, line_profile_output_file: str) -> str:
171-
# todo modify by using a libcst transformer
164+
# TODO modify by using a libcst transformer
172165
module = cst.parse_module(original_code)
173166
transformer = ProfileEnableTransformer(line_profile_output_file)
174167
modified_module = module.visit(transformer)
@@ -189,9 +182,7 @@ def leave_Module(self, original_node, updated_node):
189182
import_node = cst.parse_statement(self.import_statement)
190183

191184
# Add the import to the module's body
192-
return updated_node.with_changes(
193-
body=[import_node] + list(updated_node.body)
194-
)
185+
return updated_node.with_changes(body=[import_node] + list(updated_node.body))
195186

196187
def visit_ImportFrom(self, node):
197188
# Check if the profile is already imported from line_profiler
@@ -203,21 +194,22 @@ def visit_ImportFrom(self, node):
203194

204195
def add_decorator_imports(function_to_optimize, code_context):
205196
"""Adds a profile decorator to a function in a Python file and all its helper functions."""
206-
#self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
207-
#grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
197+
# self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
198+
# grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
208199
file_paths = defaultdict(list)
209200
line_profile_output_file = get_run_tmp_file(Path("baseline_lprof"))
210201
file_paths[function_to_optimize.file_path].append(function_to_optimize.qualified_name)
211202
for elem in code_context.helper_functions:
212203
file_paths[elem.file_path].append(elem.qualified_name)
213-
for file_path,fns_present in file_paths.items():
214-
#open file
215-
file_contents = file_path.read_text("utf-8")
204+
for file_path, fns_present in file_paths.items():
205+
# open file
206+
with open(file_path, encoding="utf-8") as file:
207+
file_contents = file.read()
216208
# parse to cst
217209
module_node = cst.parse_module(file_contents)
218210
for fn_name in fns_present:
219211
# add decorator
220-
module_node = add_decorator_to_qualified_function(module_node, fn_name, 'codeflash_line_profile')
212+
module_node = add_decorator_to_qualified_function(module_node, fn_name, "codeflash_line_profile")
221213
# add imports
222214
# Create a transformer to add the import
223215
transformer = ImportAdder("from line_profiler import profile as codeflash_line_profile")
@@ -227,8 +219,10 @@ def add_decorator_imports(function_to_optimize, code_context):
227219
# write to file
228220
with open(file_path, "w", encoding="utf-8") as file:
229221
file.write(modified_code)
230-
#Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
231-
file_contents = function_to_optimize.file_path.read_text("utf-8")
232-
modified_code = add_profile_enable(file_contents,str(line_profile_output_file))
233-
function_to_optimize.file_path.write_text(modified_code,"utf-8")
234-
return line_profile_output_file
222+
# Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
223+
with open(function_to_optimize.file_path) as f:
224+
file_contents = f.read()
225+
modified_code = add_profile_enable(file_contents, str(line_profile_output_file))
226+
with open(function_to_optimize.file_path, "w") as f:
227+
f.write(modified_code)
228+
return line_profile_output_file

0 commit comments

Comments
 (0)