Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 113 additions & 93 deletions codeflash/benchmarking/replay_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
import sqlite3
import textwrap
from collections.abc import Generator
from pathlib import Path
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -79,153 +80,127 @@ def create_trace_replay_test_code(
A string containing the test code

"""
assert test_framework in ["pytest", "unittest"]

# Create Imports
imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
{"import unittest" if test_framework == "unittest" else ""}
from codeflash.benchmarking.replay_test import get_next_arg_and_return
"""
assert test_framework in ("pytest", "unittest")
# Build imports
imports = [
"from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle",
"import unittest" if test_framework == "unittest" else "",
"from codeflash.benchmarking.replay_test import get_next_arg_and_return",
]

function_imports = []
get_alias = get_function_alias # avoid attribute lookup in loop

append_func_import = function_imports.append

# BUILD function imports (string join at the end!)
for func in functions_data:
module_name = func.get("module_name")
function_name = func.get("function_name")
module_name = func["module_name"]
function_name = func["function_name"]
class_name = func.get("class_name", "")
if class_name:
function_imports.append(
f"from {module_name} import {class_name} as {get_function_alias(module_name, class_name)}"
)
# Only alias imports once per unique combo (rely on LRU cache)
append_func_import(f"from {module_name} import {class_name} as {get_alias(module_name, class_name)}")
else:
function_imports.append(
f"from {module_name} import {function_name} as {get_function_alias(module_name, function_name)}"
)
append_func_import(f"from {module_name} import {function_name} as {get_alias(module_name, function_name)}")

imports += "\n".join(function_imports)
imports.append("\n".join(function_imports))

# Build sorted functions_to_optimize (skip __init__)
functions_to_optimize = sorted(
{func.get("function_name") for func in functions_data if func.get("function_name") != "__init__"}
)
metadata = f"""functions = {functions_to_optimize}
trace_file_path = r"{trace_file}"
"""
# Templates for different types of tests
test_function_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl)
ret = {function_name}(*args, **kwargs)
"""
)

test_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
function_name = "{orig_function_name}"
if not args:
raise ValueError("No arguments provided for the method.")
if function_name == "__init__":
ret = {class_name_alias}(*args[1:], **kwargs)
else:
ret = {class_name_alias}{method_name}(*args, **kwargs)
"""
{func["function_name"] for func in functions_data if func["function_name"] != "__init__"}
)
metadata = f'functions = {functions_to_optimize}\ntrace_file_path = r"{trace_file}"\n'

test_class_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
if not args:
raise ValueError("No arguments provided for the method.")
ret = {class_name_alias}{method_name}(*args[1:], **kwargs)
"""
)
test_static_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
ret = {class_name_alias}{method_name}(*args, **kwargs)
"""
)
# Pointer to templates
templates = {
"function": _test_function_body,
"method": _test_method_body,
"classmethod": _test_class_method_body,
"staticmethod": _test_static_method_body,
}

# Create main body
# Setup for main test generation
tests = []

if test_framework == "unittest":
self = "self"
test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
test_class_header = "\nclass TestTracedFunctions(unittest.TestCase):\n"
tests.append(test_class_header)
indent = " "
self_arg = "self"
else:
test_template = ""
self = ""
indent = " "
self_arg = ""

# Inline access for performance
get_unique_name = get_unique_test_name

for func in functions_data:
module_name = func.get("module_name")
function_name = func.get("function_name")
module_name = func["module_name"]
function_name = func["function_name"]
class_name = func.get("class_name")
file_path = func.get("file_path")
benchmark_function_name = func.get("benchmark_function_name")
function_properties = func.get("function_properties")
file_path = func["file_path"]
benchmark_function_name = func["benchmark_function_name"]
function_properties = func["function_properties"]

if not class_name:
alias = get_function_alias(module_name, function_name)
test_body = test_function_body.format(
alias = get_alias(module_name, function_name)
test_body = templates["function"].format(
benchmark_function_name=benchmark_function_name,
orig_function_name=function_name,
function_name=alias,
file_path=file_path,
max_run_count=max_run_count,
)
else:
class_name_alias = get_function_alias(module_name, class_name)
alias = get_function_alias(module_name, class_name + "_" + function_name)

class_name_alias = get_alias(module_name, class_name)
alias = get_alias(module_name, class_name + "_" + function_name)
filter_variables = ""
# filter_variables = '\n args.pop("cls", None)'
method_name = "." + function_name if function_name != "__init__" else ""
if function_properties.is_classmethod:
test_body = test_class_method_body.format(
if getattr(function_properties, "is_classmethod", False):
test_body = templates["classmethod"].format(
benchmark_function_name=benchmark_function_name,
orig_function_name=function_name,
file_path=file_path,
class_name_alias=class_name_alias,
class_name=class_name,
class_name_alias=class_name_alias,
method_name=method_name,
max_run_count=max_run_count,
filter_variables=filter_variables,
)
elif function_properties.is_staticmethod:
test_body = test_static_method_body.format(
elif getattr(function_properties, "is_staticmethod", False):
test_body = templates["staticmethod"].format(
benchmark_function_name=benchmark_function_name,
orig_function_name=function_name,
file_path=file_path,
class_name_alias=class_name_alias,
class_name=class_name,
class_name_alias=class_name_alias,
method_name=method_name,
max_run_count=max_run_count,
filter_variables=filter_variables,
)
else:
test_body = test_method_body.format(
test_body = templates["method"].format(
benchmark_function_name=benchmark_function_name,
orig_function_name=function_name,
file_path=file_path,
class_name_alias=class_name_alias,
class_name=class_name,
class_name_alias=class_name_alias,
method_name=method_name,
max_run_count=max_run_count,
filter_variables=filter_variables,
)
# Indent the block only once
formatted_test_body = textwrap.indent(test_body, indent)
unique_test_name = get_unique_name(module_name, function_name, benchmark_function_name, class_name)
# Compose function definition
if test_framework == "unittest":
tests.append(f" def test_{unique_test_name}({self_arg}):\n{formatted_test_body}\n")
else:
tests.append(f"def test_{unique_test_name}({self_arg}):\n{formatted_test_body}\n")

formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ")

test_template += " " if test_framework == "unittest" else ""
unique_test_name = get_unique_test_name(module_name, function_name, benchmark_function_name, class_name)
test_template += f"def test_{unique_test_name}({self}):\n{formatted_test_body}\n"

return imports + "\n" + metadata + "\n" + test_template
# Final string build (list join for speed)
return "\n".join(imports) + "\n" + metadata + "\n" + "".join(tests)


def generate_replay_test(
Expand Down Expand Up @@ -308,3 +283,48 @@ def generate_replay_test(
logger.info(f"Error generating replay tests: {e}")

return count


_test_function_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl)
ret = {function_name}(*args, **kwargs)
"""
)

_test_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
function_name = "{orig_function_name}"
if not args:
raise ValueError("No arguments provided for the method.")
if function_name == "__init__":
ret = {class_name_alias}(*args[1:], **kwargs)
else:
ret = {class_name_alias}{method_name}(*args, **kwargs)
"""
)

_test_class_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
if not args:
raise ValueError("No arguments provided for the method.")
ret = {class_name_alias}{method_name}(*args[1:], **kwargs)
"""
)

_test_static_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
ret = {class_name_alias}{method_name}(*args, **kwargs)
"""
)
Loading