11from __future__ import annotations
22
33import sqlite3
4- import textwrap
54from pathlib import Path
65from typing import TYPE_CHECKING , Any
76
@@ -43,6 +42,7 @@ def get_next_arg_and_return(
4342
4443
4544def get_function_alias (module : str , function_name : str ) -> str :
45+ # This is already pretty optimal.
4646 return "_" .join (module .split ("." )) + "_" + function_name
4747
4848
@@ -66,152 +66,144 @@ def create_trace_replay_test_code(
6666 A string containing the test code
6767
6868 """
69- assert test_framework in [ "pytest" , "unittest" ]
69+ assert test_framework in ( "pytest" , "unittest" )
7070
71- # Create Imports
72- imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
73- { "import unittest" if test_framework == "unittest" else "" }
74- from codeflash.benchmarking.replay_test import get_next_arg_and_return
75- """
71+ # Precompute aliases and filepaths
72+ func_aliases , class_aliases , classfunc_aliases , file_paths = _get_aliases_and_paths (functions_data )
7673
74+ # Build function imports in one pass
7775 function_imports = []
7876 for func in functions_data :
7977 module_name = func .get ("module_name" )
8078 function_name = func .get ("function_name" )
8179 class_name = func .get ("class_name" , "" )
8280 if class_name :
83- function_imports .append (
84- f"from { module_name } import { class_name } as { get_function_alias (module_name , class_name )} "
85- )
81+ cname_alias = class_aliases [class_name ]
82+ function_imports .append (f"from { module_name } import { class_name } as { cname_alias } " )
8683 else :
87- function_imports .append (
88- f"from { module_name } import { function_name } as { get_function_alias (module_name , function_name )} "
89- )
90-
91- imports += "\n " .join (function_imports )
92-
93- functions_to_optimize = sorted (
94- {func .get ("function_name" ) for func in functions_data if func .get ("function_name" ) != "__init__" }
84+ alias = func_aliases [(module_name , function_name )]
85+ function_imports .append (f"from { module_name } import { function_name } as { alias } " )
86+ imports = (
87+ "from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle\n "
88+ f"{ 'import unittest' if test_framework == 'unittest' else '' } \n "
89+ "from codeflash.benchmarking.replay_test import get_next_arg_and_return\n " + "\n " .join (function_imports )
9590 )
91+
92+ # Precompute functions_to_optimize efficiently using set and list since sorted(set(...))
93+ functions_set = {func ["function_name" ] for func in functions_data if func ["function_name" ] != "__init__" }
94+ functions_to_optimize = sorted (functions_set )
9695 metadata = f"""functions = { functions_to_optimize }
9796trace_file_path = r"{ trace_file } "
9897"""
99- # Templates for different types of tests
100- test_function_body = textwrap .dedent (
101- """\
102- for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):
103- args = pickle.loads(args_pkl)
104- kwargs = pickle.loads(kwargs_pkl)
105- ret = {function_name}(*args, **kwargs)
106- """
107- )
10898
109- test_method_body = textwrap .dedent (
110- """\
111- for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
112- args = pickle.loads(args_pkl)
113- kwargs = pickle.loads(kwargs_pkl){filter_variables}
114- function_name = "{orig_function_name}"
115- if not args:
116- raise ValueError("No arguments provided for the method.")
117- if function_name == "__init__":
118- ret = {class_name_alias}(*args[1:], **kwargs)
119- else:
120- ret = {class_name_alias}{method_name}(*args, **kwargs)
121- """
99+ # Prepare templates only once
100+ test_function_body = (
101+ "for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, "
102+ 'benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", '
103+ 'file_path=r"{file_path}", num_to_get={max_run_count}):\n '
104+ " args = pickle.loads(args_pkl)\n "
105+ " kwargs = pickle.loads(kwargs_pkl)\n "
106+ " ret = {function_name}(*args, **kwargs)\n "
122107 )
123-
124- test_class_method_body = textwrap .dedent (
125- """\
126- for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
127- args = pickle.loads(args_pkl)
128- kwargs = pickle.loads(kwargs_pkl){filter_variables}
129- if not args:
130- raise ValueError("No arguments provided for the method.")
131- ret = {class_name_alias}{method_name}(*args[1:], **kwargs)
132- """
108+ test_method_body = (
109+ "for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, "
110+ 'benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", '
111+ 'file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n '
112+ " args = pickle.loads(args_pkl)\n "
113+ " kwargs = pickle.loads(kwargs_pkl){filter_variables}\n "
114+ ' function_name = "{orig_function_name}"\n '
115+ " if not args:\n "
116+ ' raise ValueError("No arguments provided for the method.")\n '
117+ ' if function_name == "__init__":\n '
118+ " ret = {class_name_alias}(*args[1:], **kwargs)\n "
119+ " else:\n "
120+ " ret = {class_name_alias}{method_name}(*args, **kwargs)\n "
133121 )
134- test_static_method_body = textwrap .dedent (
135- """\
136- for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
137- args = pickle.loads(args_pkl)
138- kwargs = pickle.loads(kwargs_pkl){filter_variables}
139- ret = {class_name_alias}{method_name}(*args, **kwargs)
140- """
122+ test_class_method_body = (
123+ "for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, "
124+ 'benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", '
125+ 'file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n '
126+ " args = pickle.loads(args_pkl)\n "
127+ " kwargs = pickle.loads(kwargs_pkl){filter_variables}\n "
128+ " if not args:\n "
129+ ' raise ValueError("No arguments provided for the method.")\n '
130+ " ret = {class_name_alias}{method_name}(*args[1:], **kwargs)\n "
141131 )
142-
143- # Create main body
144-
132+ test_static_method_body = (
133+ "for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, "
134+ 'benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", '
135+ 'file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n '
136+ " args = pickle.loads(args_pkl)\n "
137+ " kwargs = pickle.loads(kwargs_pkl){filter_variables}\n "
138+ " ret = {class_name_alias}{method_name}(*args, **kwargs)\n "
139+ )
140+ test_bodies = {
141+ "function" : test_function_body ,
142+ "method" : test_method_body ,
143+ "classmethod" : test_class_method_body ,
144+ "staticmethod" : test_static_method_body ,
145+ }
146+
147+ # Precompute the format values up-front for all functions
145148 if test_framework == "unittest" :
146- self = "self"
147- test_template = "\n class TestTracedFunctions(unittest.TestCase):\n "
149+ self_str = "self"
150+ test_template_list = ["\n class TestTracedFunctions(unittest.TestCase):\n " ]
151+ indent_level = " "
152+ def_line = " "
148153 else :
149- test_template = ""
150- self = ""
154+ self_str = ""
155+ test_template_list = []
156+ indent_level = " "
157+ def_line = ""
151158
152159 for func in functions_data :
153- module_name = func . get ( "module_name" )
154- function_name = func . get ( "function_name" )
160+ module_name = func [ "module_name" ]
161+ function_name = func [ "function_name" ]
155162 class_name = func .get ("class_name" )
156- file_path = func .get ("file_path" )
157- benchmark_function_name = func .get ("benchmark_function_name" )
158- function_properties = func .get ("function_properties" )
163+ file_path = func ["file_path" ]
164+ file_path_posix = file_paths [file_path ]
165+ benchmark_function_name = func ["benchmark_function_name" ]
166+ function_properties = func ["function_properties" ]
159167 if not class_name :
160- alias = get_function_alias (module_name , function_name )
161- test_body = test_function_body .format (
168+ alias = func_aliases [(module_name , function_name )]
169+ template = test_bodies ["function" ]
170+ test_body_filled = template .format (
162171 benchmark_function_name = benchmark_function_name ,
163172 orig_function_name = function_name ,
164173 function_name = alias ,
165- file_path = Path ( file_path ). as_posix () ,
174+ file_path = file_path_posix ,
166175 max_run_count = max_run_count ,
167176 )
168177 else :
169- class_name_alias = get_function_alias (module_name , class_name )
170- alias = get_function_alias (module_name , class_name + "_" + function_name )
171-
178+ class_name_alias = class_aliases [class_name ]
179+ alias = classfunc_aliases [(module_name , class_name , function_name )]
172180 filter_variables = ""
173- # filter_variables = '\n args.pop("cls", None)'
174181 method_name = "." + function_name if function_name != "__init__" else ""
175182 if function_properties .is_classmethod :
176- test_body = test_class_method_body .format (
177- benchmark_function_name = benchmark_function_name ,
178- orig_function_name = function_name ,
179- file_path = Path (file_path ).as_posix (),
180- class_name_alias = class_name_alias ,
181- class_name = class_name ,
182- method_name = method_name ,
183- max_run_count = max_run_count ,
184- filter_variables = filter_variables ,
185- )
183+ template = test_bodies ["classmethod" ]
186184 elif function_properties .is_staticmethod :
187- test_body = test_static_method_body .format (
188- benchmark_function_name = benchmark_function_name ,
189- orig_function_name = function_name ,
190- file_path = Path (file_path ).as_posix (),
191- class_name_alias = class_name_alias ,
192- class_name = class_name ,
193- method_name = method_name ,
194- max_run_count = max_run_count ,
195- filter_variables = filter_variables ,
196- )
185+ template = test_bodies ["staticmethod" ]
197186 else :
198- test_body = test_method_body .format (
199- benchmark_function_name = benchmark_function_name ,
200- orig_function_name = function_name ,
201- file_path = Path (file_path ).as_posix (),
202- class_name_alias = class_name_alias ,
203- class_name = class_name ,
204- method_name = method_name ,
205- max_run_count = max_run_count ,
206- filter_variables = filter_variables ,
207- )
187+ template = test_bodies ["method" ]
188+ test_body_filled = template .format (
189+ benchmark_function_name = benchmark_function_name ,
190+ orig_function_name = function_name ,
191+ file_path = file_path_posix ,
192+ class_name_alias = class_name_alias ,
193+ class_name = class_name ,
194+ method_name = method_name ,
195+ max_run_count = max_run_count ,
196+ filter_variables = filter_variables ,
197+ )
208198
209- formatted_test_body = textwrap .indent (test_body , " " if test_framework == "unittest" else " " )
199+ # No repeated indent/dedent. Do indent directly, as we know where to indent.
200+ formatted_test_body = "" .join (
201+ indent_level + line if line .strip () else line for line in test_body_filled .splitlines (True )
202+ )
210203
211- test_template += " " if test_framework == "unittest" else ""
212- test_template += f"def test_{ alias } ({ self } ):\n { formatted_test_body } \n "
204+ test_template_list .append (f"{ def_line } def test_{ alias } ({ self_str } ):\n { formatted_test_body } \n " )
213205
214- return imports + "\n " + metadata + "\n " + test_template
206+ return imports + "\n " + metadata + "\n " + "" . join ( test_template_list )
215207
216208
217209def generate_replay_test (
@@ -294,3 +286,29 @@ def generate_replay_test(
294286 logger .info (f"Error generating replay tests: { e } " )
295287
296288 return count
289+
290+
291+ def _get_aliases_and_paths (functions_data ):
292+ # Precompute all needed aliases and file posix paths up front in a single pass
293+ func_aliases = {}
294+ class_aliases = {}
295+ classfunc_aliases = {}
296+ file_paths = {}
297+ for func in functions_data :
298+ module_name = func .get ("module_name" )
299+ function_name = func .get ("function_name" )
300+ class_name = func .get ("class_name" , "" )
301+ file_path = func .get ("file_path" )
302+ # Precompute Path(file_path).as_posix() once per unique file_path
303+ if file_path not in file_paths :
304+ file_paths [file_path ] = Path (file_path ).as_posix ()
305+ if class_name :
306+ # avoid re-calculating class alias if already done
307+ if class_name not in class_aliases :
308+ class_aliases [class_name ] = get_function_alias (module_name , class_name )
309+ classfunc_key = (module_name , class_name , function_name )
310+ classfunc_aliases [classfunc_key ] = get_function_alias (module_name , class_name + "_" + function_name )
311+ else :
312+ # alias for global function
313+ func_aliases [(module_name , function_name )] = get_function_alias (module_name , function_name )
314+ return func_aliases , class_aliases , classfunc_aliases , file_paths
0 commit comments