diff --git a/codeflash/benchmarking/replay_test.py b/codeflash/benchmarking/replay_test.py index f925f19d8..121416352 100644 --- a/codeflash/benchmarking/replay_test.py +++ b/codeflash/benchmarking/replay_test.py @@ -3,6 +3,7 @@ import re import sqlite3 import textwrap +from collections.abc import Generator from pathlib import Path from typing import TYPE_CHECKING, Any @@ -79,99 +80,72 @@ def create_trace_replay_test_code( A string containing the test code """ - assert test_framework in ["pytest", "unittest"] - - # Create Imports - imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle -{"import unittest" if test_framework == "unittest" else ""} -from codeflash.benchmarking.replay_test import get_next_arg_and_return -""" + assert test_framework in ("pytest", "unittest") + # Build imports + imports = [ + "from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle", + "import unittest" if test_framework == "unittest" else "", + "from codeflash.benchmarking.replay_test import get_next_arg_and_return", + ] function_imports = [] + get_alias = get_function_alias # avoid attribute lookup in loop + + append_func_import = function_imports.append + + # BUILD function imports (string join at the end!) for func in functions_data: - module_name = func.get("module_name") - function_name = func.get("function_name") + module_name = func["module_name"] + function_name = func["function_name"] class_name = func.get("class_name", "") if class_name: - function_imports.append( - f"from {module_name} import {class_name} as {get_function_alias(module_name, class_name)}" - ) + # Only alias imports once per unique combo (rely on LRU cache) + append_func_import(f"from {module_name} import {class_name} as {get_alias(module_name, class_name)}") else: - function_imports.append( - f"from {module_name} import {function_name} as {get_function_alias(module_name, function_name)}" - ) + append_func_import(f"from {module_name} import {function_name} as {get_alias(module_name, function_name)}") - imports += "\n".join(function_imports) + imports.append("\n".join(function_imports)) + # Build sorted functions_to_optimize (skip __init__) functions_to_optimize = sorted( - {func.get("function_name") for func in functions_data if func.get("function_name") != "__init__"} - ) - metadata = f"""functions = {functions_to_optimize} -trace_file_path = r"{trace_file}" -""" - # Templates for different types of tests - test_function_body = textwrap.dedent( - """\ - for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}): - args = pickle.loads(args_pkl) - kwargs = pickle.loads(kwargs_pkl) - ret = {function_name}(*args, **kwargs) - """ - ) - - test_method_body = textwrap.dedent( - """\ - for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): - args = pickle.loads(args_pkl) - kwargs = pickle.loads(kwargs_pkl){filter_variables} - function_name = "{orig_function_name}" - if not args: - raise ValueError("No arguments provided for the method.") - if function_name == "__init__": - ret = {class_name_alias}(*args[1:], **kwargs) - else: - ret = {class_name_alias}{method_name}(*args, **kwargs) - """ + {func["function_name"] for func in functions_data if func["function_name"] != "__init__"} ) + metadata = f'functions = {functions_to_optimize}\ntrace_file_path = r"{trace_file}"\n' - test_class_method_body = textwrap.dedent( - """\ - for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): - args = pickle.loads(args_pkl) - kwargs = pickle.loads(kwargs_pkl){filter_variables} - if not args: - raise ValueError("No arguments provided for the method.") - ret = {class_name_alias}{method_name}(*args[1:], **kwargs) - """ - ) - test_static_method_body = textwrap.dedent( - """\ - for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): - args = pickle.loads(args_pkl) - kwargs = pickle.loads(kwargs_pkl){filter_variables} - ret = {class_name_alias}{method_name}(*args, **kwargs) - """ - ) + # Pointer to templates + templates = { + "function": _test_function_body, + "method": _test_method_body, + "classmethod": _test_class_method_body, + "staticmethod": _test_static_method_body, + } - # Create main body + # Setup for main test generation + tests = [] if test_framework == "unittest": - self = "self" - test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n" + test_class_header = "\nclass TestTracedFunctions(unittest.TestCase):\n" + tests.append(test_class_header) + indent = " " + self_arg = "self" else: - test_template = "" - self = "" + indent = " " + self_arg = "" + + # Inline access for performance + get_unique_name = get_unique_test_name for func in functions_data: - module_name = func.get("module_name") - function_name = func.get("function_name") + module_name = func["module_name"] + function_name = func["function_name"] class_name = func.get("class_name") - file_path = func.get("file_path") - benchmark_function_name = func.get("benchmark_function_name") - function_properties = func.get("function_properties") + file_path = func["file_path"] + benchmark_function_name = func["benchmark_function_name"] + function_properties = func["function_properties"] + if not class_name: - alias = get_function_alias(module_name, function_name) - test_body = test_function_body.format( + alias = get_alias(module_name, function_name) + test_body = templates["function"].format( benchmark_function_name=benchmark_function_name, orig_function_name=function_name, function_name=alias, @@ -179,53 +153,54 @@ def create_trace_replay_test_code( max_run_count=max_run_count, ) else: - class_name_alias = get_function_alias(module_name, class_name) - alias = get_function_alias(module_name, class_name + "_" + function_name) - + class_name_alias = get_alias(module_name, class_name) + alias = get_alias(module_name, class_name + "_" + function_name) filter_variables = "" - # filter_variables = '\n args.pop("cls", None)' method_name = "." + function_name if function_name != "__init__" else "" - if function_properties.is_classmethod: - test_body = test_class_method_body.format( + if getattr(function_properties, "is_classmethod", False): + test_body = templates["classmethod"].format( benchmark_function_name=benchmark_function_name, orig_function_name=function_name, file_path=file_path, - class_name_alias=class_name_alias, class_name=class_name, + class_name_alias=class_name_alias, method_name=method_name, max_run_count=max_run_count, filter_variables=filter_variables, ) - elif function_properties.is_staticmethod: - test_body = test_static_method_body.format( + elif getattr(function_properties, "is_staticmethod", False): + test_body = templates["staticmethod"].format( benchmark_function_name=benchmark_function_name, orig_function_name=function_name, file_path=file_path, - class_name_alias=class_name_alias, class_name=class_name, + class_name_alias=class_name_alias, method_name=method_name, max_run_count=max_run_count, filter_variables=filter_variables, ) else: - test_body = test_method_body.format( + test_body = templates["method"].format( benchmark_function_name=benchmark_function_name, orig_function_name=function_name, file_path=file_path, - class_name_alias=class_name_alias, class_name=class_name, + class_name_alias=class_name_alias, method_name=method_name, max_run_count=max_run_count, filter_variables=filter_variables, ) + # Indent the block only once + formatted_test_body = textwrap.indent(test_body, indent) + unique_test_name = get_unique_name(module_name, function_name, benchmark_function_name, class_name) + # Compose function definition + if test_framework == "unittest": + tests.append(f" def test_{unique_test_name}({self_arg}):\n{formatted_test_body}\n") + else: + tests.append(f"def test_{unique_test_name}({self_arg}):\n{formatted_test_body}\n") - formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ") - - test_template += " " if test_framework == "unittest" else "" - unique_test_name = get_unique_test_name(module_name, function_name, benchmark_function_name, class_name) - test_template += f"def test_{unique_test_name}({self}):\n{formatted_test_body}\n" - - return imports + "\n" + metadata + "\n" + test_template + # Final string build (list join for speed) + return "\n".join(imports) + "\n" + metadata + "\n" + "".join(tests) def generate_replay_test( @@ -308,3 +283,48 @@ def generate_replay_test( logger.info(f"Error generating replay tests: {e}") return count + + +_test_function_body = textwrap.dedent( + """\ + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl) + ret = {function_name}(*args, **kwargs) + """ +) + +_test_method_body = textwrap.dedent( + """\ + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl){filter_variables} + function_name = "{orig_function_name}" + if not args: + raise ValueError("No arguments provided for the method.") + if function_name == "__init__": + ret = {class_name_alias}(*args[1:], **kwargs) + else: + ret = {class_name_alias}{method_name}(*args, **kwargs) + """ +) + +_test_class_method_body = textwrap.dedent( + """\ + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl){filter_variables} + if not args: + raise ValueError("No arguments provided for the method.") + ret = {class_name_alias}{method_name}(*args[1:], **kwargs) + """ +) + +_test_static_method_body = textwrap.dedent( + """\ + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl){filter_variables} + ret = {class_name_alias}{method_name}(*args, **kwargs) + """ +)