Skip to content

Commit e6a9066

Browse files
committed
better test, improve testing
1 parent 3b9bee0 commit e6a9066

File tree

2 files changed

+72
-34
lines changed

2 files changed

+72
-34
lines changed

codeflash/code_utils/lprof_utils.py

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,38 +23,38 @@ def leave_FunctionDef(self, original_node, updated_node):
2323
updated_module = module_node.visit(transformer)
2424
return updated_module
2525

26-
def add_decorator_imports(file_paths, fn_list, db_file):
27-
"""Adds a decorator to a function in a Python file."""
28-
for file_path, fn_name in zip(file_paths, fn_list):
29-
#open file
30-
with open(file_path, "r", encoding="utf-8") as file:
31-
file_contents = file.read()
26+
def add_profile_enable(original_code: str, db_file: str) -> str:
27+
module = cst.parse_module(original_code)
28+
found_index = -1
29+
30+
for idx, statement in enumerate(module.body):
31+
if isinstance(statement, cst.SimpleStatementLine):
32+
for stmt in statement.body:
33+
if isinstance(stmt, cst.ImportFrom):
34+
if stmt.module and stmt.module.value == 'line_profiler':
35+
for name in stmt.names:
36+
if isinstance(name, cst.ImportAlias):
37+
if name.name.value == 'profile' and name.asname is None:
38+
found_index = idx
39+
break
40+
if found_index != -1:
41+
break
42+
if found_index != -1:
43+
break
3244

33-
# parse to cst
34-
module_node = cst.parse_module(file_contents)
35-
# add decorator
36-
module_node = add_decorator_cst(module_node, fn_name, 'profile')
37-
# add imports
38-
# Create a transformer to add the import
39-
transformer = ImportAdder("from line_profiler import profile")
45+
if found_index == -1:
46+
return original_code # or raise an exception if the import is not found
4047

41-
# Apply the transformer to add the import
42-
module_node = module_node.visit(transformer)
43-
modified_code = isort.code(module_node.code, float_to_top=True)
44-
# write to file
45-
with open(file_path, "w", encoding="utf-8") as file:
46-
file.write(modified_code)
47-
#Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files, can use libcst seems like an overkill, will go just with some simple string manipulation
48-
with open(file_paths[0],'r') as f:
49-
file_contents = f.readlines()
50-
for idx, line in enumerate(file_contents):
51-
if 'from line_profiler import profile' in line:
52-
file_contents.insert(idx+1, f"profile.enable(output_prefix='{db_file}')\n")
53-
break
54-
with open(file_paths[0],'w') as f:
55-
f.writelines(file_contents)
48+
# Create the new line to insert
49+
new_line = f"profile.enable(output_prefix='{db_file}')\n"
50+
new_statement = cst.parse_statement(new_line)
5651

52+
# Insert the new statement into the module's body
53+
new_body = list(module.body)
54+
new_body.insert(found_index + 1, new_statement)
55+
modified_module = module.with_changes(body=new_body)
5756

57+
return modified_module.code
5858

5959

6060
class ImportAdder(cst.CSTTransformer):
@@ -82,6 +82,36 @@ def visit_ImportFrom(self, node):
8282
if import_alias.name.value == "profile":
8383
self.has_import = True
8484

85+
86+
def add_decorator_imports(file_paths, fn_list, db_file):
87+
"""Adds a decorator to a function in a Python file."""
88+
for file_path, fn_name in zip(file_paths, fn_list):
89+
#open file
90+
with open(file_path, "r", encoding="utf-8") as file:
91+
file_contents = file.read()
92+
93+
# parse to cst
94+
module_node = cst.parse_module(file_contents)
95+
# add decorator
96+
module_node = add_decorator_cst(module_node, fn_name, 'profile')
97+
# add imports
98+
# Create a transformer to add the import
99+
transformer = ImportAdder("from line_profiler import profile")
100+
101+
# Apply the transformer to add the import
102+
module_node = module_node.visit(transformer)
103+
modified_code = isort.code(module_node.code, float_to_top=True)
104+
# write to file
105+
with open(file_path, "w", encoding="utf-8") as file:
106+
file.write(modified_code)
107+
#Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files, can use libcst seems like an overkill, will go just with some simple string manipulation
108+
with open(file_paths[0],'r') as f:
109+
file_contents = f.read()
110+
modified_code = add_profile_enable(file_contents,db_file)
111+
with open(file_paths[0],'w') as f:
112+
f.write(modified_code)
113+
114+
85115
def prepare_lprofiler_files(prefix: str = "") -> tuple[Path]:
86116
"""Prepare line profiler output file."""
87117
lprofiler_database_file = get_run_tmp_file(Path(prefix))

tests/test_test_runner.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from codeflash.verification.test_results import TestType
99
from codeflash.verification.test_runner import run_behavioral_tests
1010
from codeflash.verification.verification_utils import TestConfig
11+
import dill as pickle
1112

1213

1314
def test_unittest_runner():
@@ -144,11 +145,18 @@ def test_sort():
144145
pytest_target_runtime_seconds=1,
145146
enable_lprofiler=True,
146147
)
147-
with open(tmpdir.name+os.sep+"/baseline.txt", "r") as f:
148-
output = f.read()
149-
expected_output = "Timer unit: 1e-09 s\n\nTotal time: 0 s\nFile: {}\nFunction: sorter at line 4\n\nLine # Hits Time Per Hit % Time Line Contents\n==============================================================\n 4 @profile\n 5 def sorter(arr):\n 6 1 0.0 0.0 arr.sort()\n 7 1 0.0 0.0 return arr\n\n".format(fp.name)
150-
151-
152-
assert output == expected_output, "Test passed"
148+
try:
149+
#todo write a test for the pickle parsing code, right now it's a simplistic test
150+
with open(tmpdir.name+os.sep+"baseline.lprof", "rb") as f:
151+
output = pickle.load(f)
152+
#get lprof instead and compare the hits
153+
output_set = set(list(output.timings.values())[0])
154+
output_set = {(x,y) for x,y,z in output_set}
155+
expected_set = {(6, 1), (7, 1)}
156+
except Exception as e:
157+
print(e)
158+
assert False, "Test failed"
159+
160+
assert expected_set == output_set, "Test passed"
153161
result_file.unlink(missing_ok=True)
154162

0 commit comments

Comments
 (0)