|
3 | 3 | import re |
4 | 4 | import sqlite3 |
5 | 5 | import textwrap |
| 6 | +from collections.abc import Generator |
6 | 7 | from pathlib import Path |
7 | 8 | from typing import TYPE_CHECKING, Any |
8 | 9 |
|
@@ -79,153 +80,127 @@ def create_trace_replay_test_code( |
79 | 80 | A string containing the test code |
80 | 81 |
|
81 | 82 | """ |
82 | | - assert test_framework in ["pytest", "unittest"] |
83 | | - |
84 | | - # Create Imports |
85 | | - imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle |
86 | | -{"import unittest" if test_framework == "unittest" else ""} |
87 | | -from codeflash.benchmarking.replay_test import get_next_arg_and_return |
88 | | -""" |
| 83 | + assert test_framework in ("pytest", "unittest") |
| 84 | + # Build imports |
| 85 | + imports = [ |
| 86 | + "from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle", |
| 87 | + "import unittest" if test_framework == "unittest" else "", |
| 88 | + "from codeflash.benchmarking.replay_test import get_next_arg_and_return", |
| 89 | + ] |
89 | 90 |
|
90 | 91 | function_imports = [] |
| 92 | + get_alias = get_function_alias # avoid attribute lookup in loop |
| 93 | + |
| 94 | + append_func_import = function_imports.append |
| 95 | + |
| 96 | + # BUILD function imports (string join at the end!) |
91 | 97 | for func in functions_data: |
92 | | - module_name = func.get("module_name") |
93 | | - function_name = func.get("function_name") |
| 98 | + module_name = func["module_name"] |
| 99 | + function_name = func["function_name"] |
94 | 100 | class_name = func.get("class_name", "") |
95 | 101 | if class_name: |
96 | | - function_imports.append( |
97 | | - f"from {module_name} import {class_name} as {get_function_alias(module_name, class_name)}" |
98 | | - ) |
| 102 | + # Only alias imports once per unique combo (rely on LRU cache) |
| 103 | + append_func_import(f"from {module_name} import {class_name} as {get_alias(module_name, class_name)}") |
99 | 104 | else: |
100 | | - function_imports.append( |
101 | | - f"from {module_name} import {function_name} as {get_function_alias(module_name, function_name)}" |
102 | | - ) |
| 105 | + append_func_import(f"from {module_name} import {function_name} as {get_alias(module_name, function_name)}") |
103 | 106 |
|
104 | | - imports += "\n".join(function_imports) |
| 107 | + imports.append("\n".join(function_imports)) |
105 | 108 |
|
| 109 | + # Build sorted functions_to_optimize (skip __init__) |
106 | 110 | functions_to_optimize = sorted( |
107 | | - {func.get("function_name") for func in functions_data if func.get("function_name") != "__init__"} |
108 | | - ) |
109 | | - metadata = f"""functions = {functions_to_optimize} |
110 | | -trace_file_path = r"{trace_file}" |
111 | | -""" |
112 | | - # Templates for different types of tests |
113 | | - test_function_body = textwrap.dedent( |
114 | | - """\ |
115 | | - for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}): |
116 | | - args = pickle.loads(args_pkl) |
117 | | - kwargs = pickle.loads(kwargs_pkl) |
118 | | - ret = {function_name}(*args, **kwargs) |
119 | | - """ |
120 | | - ) |
121 | | - |
122 | | - test_method_body = textwrap.dedent( |
123 | | - """\ |
124 | | - for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): |
125 | | - args = pickle.loads(args_pkl) |
126 | | - kwargs = pickle.loads(kwargs_pkl){filter_variables} |
127 | | - function_name = "{orig_function_name}" |
128 | | - if not args: |
129 | | - raise ValueError("No arguments provided for the method.") |
130 | | - if function_name == "__init__": |
131 | | - ret = {class_name_alias}(*args[1:], **kwargs) |
132 | | - else: |
133 | | - ret = {class_name_alias}{method_name}(*args, **kwargs) |
134 | | - """ |
| 111 | + {func["function_name"] for func in functions_data if func["function_name"] != "__init__"} |
135 | 112 | ) |
| 113 | + metadata = f'functions = {functions_to_optimize}\ntrace_file_path = r"{trace_file}"\n' |
136 | 114 |
|
137 | | - test_class_method_body = textwrap.dedent( |
138 | | - """\ |
139 | | - for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): |
140 | | - args = pickle.loads(args_pkl) |
141 | | - kwargs = pickle.loads(kwargs_pkl){filter_variables} |
142 | | - if not args: |
143 | | - raise ValueError("No arguments provided for the method.") |
144 | | - ret = {class_name_alias}{method_name}(*args[1:], **kwargs) |
145 | | - """ |
146 | | - ) |
147 | | - test_static_method_body = textwrap.dedent( |
148 | | - """\ |
149 | | - for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): |
150 | | - args = pickle.loads(args_pkl) |
151 | | - kwargs = pickle.loads(kwargs_pkl){filter_variables} |
152 | | - ret = {class_name_alias}{method_name}(*args, **kwargs) |
153 | | - """ |
154 | | - ) |
| 115 | + # Pointer to templates |
| 116 | + templates = { |
| 117 | + "function": _test_function_body, |
| 118 | + "method": _test_method_body, |
| 119 | + "classmethod": _test_class_method_body, |
| 120 | + "staticmethod": _test_static_method_body, |
| 121 | + } |
155 | 122 |
|
156 | | - # Create main body |
| 123 | + # Setup for main test generation |
| 124 | + tests = [] |
157 | 125 |
|
158 | 126 | if test_framework == "unittest": |
159 | | - self = "self" |
160 | | - test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n" |
| 127 | + test_class_header = "\nclass TestTracedFunctions(unittest.TestCase):\n" |
| 128 | + tests.append(test_class_header) |
| 129 | + indent = " " |
| 130 | + self_arg = "self" |
161 | 131 | else: |
162 | | - test_template = "" |
163 | | - self = "" |
| 132 | + indent = " " |
| 133 | + self_arg = "" |
| 134 | + |
| 135 | + # Inline access for performance |
| 136 | + get_unique_name = get_unique_test_name |
164 | 137 |
|
165 | 138 | for func in functions_data: |
166 | | - module_name = func.get("module_name") |
167 | | - function_name = func.get("function_name") |
| 139 | + module_name = func["module_name"] |
| 140 | + function_name = func["function_name"] |
168 | 141 | class_name = func.get("class_name") |
169 | | - file_path = func.get("file_path") |
170 | | - benchmark_function_name = func.get("benchmark_function_name") |
171 | | - function_properties = func.get("function_properties") |
| 142 | + file_path = func["file_path"] |
| 143 | + benchmark_function_name = func["benchmark_function_name"] |
| 144 | + function_properties = func["function_properties"] |
| 145 | + |
172 | 146 | if not class_name: |
173 | | - alias = get_function_alias(module_name, function_name) |
174 | | - test_body = test_function_body.format( |
| 147 | + alias = get_alias(module_name, function_name) |
| 148 | + test_body = templates["function"].format( |
175 | 149 | benchmark_function_name=benchmark_function_name, |
176 | 150 | orig_function_name=function_name, |
177 | 151 | function_name=alias, |
178 | 152 | file_path=file_path, |
179 | 153 | max_run_count=max_run_count, |
180 | 154 | ) |
181 | 155 | else: |
182 | | - class_name_alias = get_function_alias(module_name, class_name) |
183 | | - alias = get_function_alias(module_name, class_name + "_" + function_name) |
184 | | - |
| 156 | + class_name_alias = get_alias(module_name, class_name) |
| 157 | + alias = get_alias(module_name, class_name + "_" + function_name) |
185 | 158 | filter_variables = "" |
186 | | - # filter_variables = '\n args.pop("cls", None)' |
187 | 159 | method_name = "." + function_name if function_name != "__init__" else "" |
188 | | - if function_properties.is_classmethod: |
189 | | - test_body = test_class_method_body.format( |
| 160 | + if getattr(function_properties, "is_classmethod", False): |
| 161 | + test_body = templates["classmethod"].format( |
190 | 162 | benchmark_function_name=benchmark_function_name, |
191 | 163 | orig_function_name=function_name, |
192 | 164 | file_path=file_path, |
193 | | - class_name_alias=class_name_alias, |
194 | 165 | class_name=class_name, |
| 166 | + class_name_alias=class_name_alias, |
195 | 167 | method_name=method_name, |
196 | 168 | max_run_count=max_run_count, |
197 | 169 | filter_variables=filter_variables, |
198 | 170 | ) |
199 | | - elif function_properties.is_staticmethod: |
200 | | - test_body = test_static_method_body.format( |
| 171 | + elif getattr(function_properties, "is_staticmethod", False): |
| 172 | + test_body = templates["staticmethod"].format( |
201 | 173 | benchmark_function_name=benchmark_function_name, |
202 | 174 | orig_function_name=function_name, |
203 | 175 | file_path=file_path, |
204 | | - class_name_alias=class_name_alias, |
205 | 176 | class_name=class_name, |
| 177 | + class_name_alias=class_name_alias, |
206 | 178 | method_name=method_name, |
207 | 179 | max_run_count=max_run_count, |
208 | 180 | filter_variables=filter_variables, |
209 | 181 | ) |
210 | 182 | else: |
211 | | - test_body = test_method_body.format( |
| 183 | + test_body = templates["method"].format( |
212 | 184 | benchmark_function_name=benchmark_function_name, |
213 | 185 | orig_function_name=function_name, |
214 | 186 | file_path=file_path, |
215 | | - class_name_alias=class_name_alias, |
216 | 187 | class_name=class_name, |
| 188 | + class_name_alias=class_name_alias, |
217 | 189 | method_name=method_name, |
218 | 190 | max_run_count=max_run_count, |
219 | 191 | filter_variables=filter_variables, |
220 | 192 | ) |
| 193 | + # Indent the block only once |
| 194 | + formatted_test_body = textwrap.indent(test_body, indent) |
| 195 | + unique_test_name = get_unique_name(module_name, function_name, benchmark_function_name, class_name) |
| 196 | + # Compose function definition |
| 197 | + if test_framework == "unittest": |
| 198 | + tests.append(f" def test_{unique_test_name}({self_arg}):\n{formatted_test_body}\n") |
| 199 | + else: |
| 200 | + tests.append(f"def test_{unique_test_name}({self_arg}):\n{formatted_test_body}\n") |
221 | 201 |
|
222 | | - formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ") |
223 | | - |
224 | | - test_template += " " if test_framework == "unittest" else "" |
225 | | - unique_test_name = get_unique_test_name(module_name, function_name, benchmark_function_name, class_name) |
226 | | - test_template += f"def test_{unique_test_name}({self}):\n{formatted_test_body}\n" |
227 | | - |
228 | | - return imports + "\n" + metadata + "\n" + test_template |
| 202 | + # Final string build (list join for speed) |
| 203 | + return "\n".join(imports) + "\n" + metadata + "\n" + "".join(tests) |
229 | 204 |
|
230 | 205 |
|
231 | 206 | def generate_replay_test( |
@@ -308,3 +283,48 @@ def generate_replay_test( |
308 | 283 | logger.info(f"Error generating replay tests: {e}") |
309 | 284 |
|
310 | 285 | return count |
| 286 | + |
| 287 | + |
| 288 | +_test_function_body = textwrap.dedent( |
| 289 | + """\ |
| 290 | + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}): |
| 291 | + args = pickle.loads(args_pkl) |
| 292 | + kwargs = pickle.loads(kwargs_pkl) |
| 293 | + ret = {function_name}(*args, **kwargs) |
| 294 | + """ |
| 295 | +) |
| 296 | + |
| 297 | +_test_method_body = textwrap.dedent( |
| 298 | + """\ |
| 299 | + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): |
| 300 | + args = pickle.loads(args_pkl) |
| 301 | + kwargs = pickle.loads(kwargs_pkl){filter_variables} |
| 302 | + function_name = "{orig_function_name}" |
| 303 | + if not args: |
| 304 | + raise ValueError("No arguments provided for the method.") |
| 305 | + if function_name == "__init__": |
| 306 | + ret = {class_name_alias}(*args[1:], **kwargs) |
| 307 | + else: |
| 308 | + ret = {class_name_alias}{method_name}(*args, **kwargs) |
| 309 | + """ |
| 310 | +) |
| 311 | + |
| 312 | +_test_class_method_body = textwrap.dedent( |
| 313 | + """\ |
| 314 | + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): |
| 315 | + args = pickle.loads(args_pkl) |
| 316 | + kwargs = pickle.loads(kwargs_pkl){filter_variables} |
| 317 | + if not args: |
| 318 | + raise ValueError("No arguments provided for the method.") |
| 319 | + ret = {class_name_alias}{method_name}(*args[1:], **kwargs) |
| 320 | + """ |
| 321 | +) |
| 322 | + |
| 323 | +_test_static_method_body = textwrap.dedent( |
| 324 | + """\ |
| 325 | + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): |
| 326 | + args = pickle.loads(args_pkl) |
| 327 | + kwargs = pickle.loads(kwargs_pkl){filter_variables} |
| 328 | + ret = {class_name_alias}{method_name}(*args, **kwargs) |
| 329 | + """ |
| 330 | +) |
0 commit comments