diff --git a/codeflash/benchmarking/replay_test.py b/codeflash/benchmarking/replay_test.py index c2e1889db..3f901b8ad 100644 --- a/codeflash/benchmarking/replay_test.py +++ b/codeflash/benchmarking/replay_test.py @@ -1,7 +1,6 @@ from __future__ import annotations import sqlite3 -import textwrap from pathlib import Path from typing import TYPE_CHECKING, Any @@ -68,94 +67,99 @@ def create_trace_replay_test_code( """ assert test_framework in ["pytest", "unittest"] - # Create Imports - imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle -{"import unittest" if test_framework == "unittest" else ""} -from codeflash.benchmarking.replay_test import get_next_arg_and_return -""" + # Precompute all needed values up-front for efficiency + unittest_import = "import unittest" if test_framework == "unittest" else "" + imports = ( + "from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle\n" + f"{unittest_import}\n" + "from codeflash.benchmarking.replay_test import get_next_arg_and_return\n" + ) function_imports = [] + functions_to_optimize = set() + + # Collect imports and test function names in one pass: for func in functions_data: - module_name = func.get("module_name") - function_name = func.get("function_name") - class_name = func.get("class_name", "") + module_name = func["module_name"] + function_name = func["function_name"] + class_name = func.get("class_name") if class_name: - function_imports.append( - f"from {module_name} import {class_name} as {get_function_alias(module_name, class_name)}" - ) + alias = get_function_alias(module_name, class_name) + function_imports.append(f"from {module_name} import {class_name} as {alias}") else: - function_imports.append( - f"from {module_name} import {function_name} as {get_function_alias(module_name, function_name)}" - ) - + alias = get_function_alias(module_name, function_name) + function_imports.append(f"from {module_name} import {function_name} as {alias}") + if function_name != "__init__": + functions_to_optimize.add(function_name) imports += "\n".join(function_imports) - functions_to_optimize = sorted( - {func.get("function_name") for func in functions_data if func.get("function_name") != "__init__"} - ) - metadata = f"""functions = {functions_to_optimize} -trace_file_path = r"{trace_file}" -""" - # Templates for different types of tests - test_function_body = textwrap.dedent( - """\ - for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}): - args = pickle.loads(args_pkl) - kwargs = pickle.loads(kwargs_pkl) - ret = {function_name}(*args, **kwargs) - """ - ) + metadata = f'functions = {sorted(functions_to_optimize)}\ntrace_file_path = r"{trace_file}"\n' - test_method_body = textwrap.dedent( - """\ - for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): - args = pickle.loads(args_pkl) - kwargs = pickle.loads(kwargs_pkl){filter_variables} - function_name = "{orig_function_name}" - if not args: - raise ValueError("No arguments provided for the method.") - if function_name == "__init__": - ret = {class_name_alias}(*args[1:], **kwargs) - else: - ret = {class_name_alias}{method_name}(*args, **kwargs) - """ + # Templates, dedented once for speed + test_function_body = ( + "for args_pkl, kwargs_pkl in get_next_arg_and_return(" + 'trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", ' + 'function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):\n' + " args = pickle.loads(args_pkl)\n" + " kwargs = pickle.loads(kwargs_pkl)\n" + " ret = {function_name}(*args, **kwargs)\n" ) - - test_class_method_body = textwrap.dedent( - """\ - for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): - args = pickle.loads(args_pkl) - kwargs = pickle.loads(kwargs_pkl){filter_variables} - if not args: - raise ValueError("No arguments provided for the method.") - ret = {class_name_alias}{method_name}(*args[1:], **kwargs) - """ + test_method_body = ( + "for args_pkl, kwargs_pkl in get_next_arg_and_return(" + 'trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", ' + 'function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n' + " args = pickle.loads(args_pkl)\n" + " kwargs = pickle.loads(kwargs_pkl){filter_variables}\n" + ' function_name = "{orig_function_name}"\n' + " if not args:\n" + ' raise ValueError("No arguments provided for the method.")\n' + ' if function_name == "__init__":\n' + " ret = {class_name_alias}(*args[1:], **kwargs)\n" + " else:\n" + " ret = {class_name_alias}{method_name}(*args, **kwargs)\n" ) - test_static_method_body = textwrap.dedent( - """\ - for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): - args = pickle.loads(args_pkl) - kwargs = pickle.loads(kwargs_pkl){filter_variables} - ret = {class_name_alias}{method_name}(*args, **kwargs) - """ + test_class_method_body = ( + "for args_pkl, kwargs_pkl in get_next_arg_and_return(" + 'trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", ' + 'function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n' + " args = pickle.loads(args_pkl)\n" + " kwargs = pickle.loads(kwargs_pkl){filter_variables}\n" + " if not args:\n" + ' raise ValueError("No arguments provided for the method.")\n' + " ret = {class_name_alias}{method_name}(*args[1:], **kwargs)\n" + ) + test_static_method_body = ( + "for args_pkl, kwargs_pkl in get_next_arg_and_return(" + 'trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", ' + 'function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):\n' + " args = pickle.loads(args_pkl)\n" + " kwargs = pickle.loads(kwargs_pkl){filter_variables}\n" + " ret = {class_name_alias}{method_name}(*args, **kwargs)\n" ) - - # Create main body if test_framework == "unittest": - self = "self" - test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n" + self_arg = "self" + test_header = "\nclass TestTracedFunctions(unittest.TestCase):\n" + def_indent = " " + body_indent = " " else: - test_template = "" - self = "" + self_arg = "" + test_header = "" + def_indent = "" + body_indent = " " + + # String builder technique for fast test template construction + test_template_lines = [test_header] + append = test_template_lines.append # local variable for speed for func in functions_data: - module_name = func.get("module_name") - function_name = func.get("function_name") + module_name = func["module_name"] + function_name = func["function_name"] class_name = func.get("class_name") - file_path = func.get("file_path") - benchmark_function_name = func.get("benchmark_function_name") - function_properties = func.get("function_properties") + file_path = func["file_path"] + benchmark_function_name = func["benchmark_function_name"] + function_properties = func["function_properties"] + if not class_name: alias = get_function_alias(module_name, function_name) test_body = test_function_body.format( @@ -168,9 +172,7 @@ def create_trace_replay_test_code( else: class_name_alias = get_function_alias(module_name, class_name) alias = get_function_alias(module_name, class_name + "_" + function_name) - filter_variables = "" - # filter_variables = '\n args.pop("cls", None)' method_name = "." + function_name if function_name != "__init__" else "" if function_properties.is_classmethod: test_body = test_class_method_body.format( @@ -206,12 +208,14 @@ def create_trace_replay_test_code( filter_variables=filter_variables, ) - formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ") - - test_template += " " if test_framework == "unittest" else "" - test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n" + # Manually indent for speed (no textwrap.indent) + test_body_indented = "".join( + body_indent + ln if ln else body_indent for ln in test_body.splitlines(keepends=True) + ) + append(f"{def_indent}def test_{alias}({self_arg}):\n{test_body_indented}\n") - return imports + "\n" + metadata + "\n" + test_template + # Final string concatenation + return f"{imports}\n{metadata}\n{''.join(test_template_lines)}" def generate_replay_test(