diff --git a/examples/sdk/sdk_example_runner/sdk_example_runner.py b/examples/sdk/sdk_example_runner/sdk_example_runner.py new file mode 100644 index 00000000000..447fb633270 --- /dev/null +++ b/examples/sdk/sdk_example_runner/sdk_example_runner.py @@ -0,0 +1,185 @@ +import argparse +import contextlib +import os +import subprocess +import tempfile +from typing import List, Tuple, Union + +import torch +from executorch.exir import ExecutorchProgramManager, to_edge +from executorch.exir.tracer import Value +from executorch.sdk.bundled_program.config import MethodTestCase, MethodTestSuite + +from executorch.sdk.bundled_program.core import create_bundled_program +from executorch.sdk.bundled_program.serialize import ( + serialize_from_bundled_program_to_flatbuffer, +) +from executorch.sdk.inspector import Inspector +from executorch.sdk.inspector._inspector_utils import compare_results +from torch.export import export + +@contextlib.contextmanager +def change_directory(path: str): + # record cwd (current working directory) + cwd = os.getcwd() + try: + os.chdir(path) + yield os.getcwd() + finally: + # restore cwd + os.chdir(cwd) + + +def run_command(command): + command = command.split() + try: + result = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) + except subprocess.CalledProcessError as e: + print("Command failed with return code", e.returncode) + print("Output:") + print(e.output.decode()) + else: + for line in result.stdout.decode().split("\n"): + print(line) + + +# A simple model for a test case. +class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, arg): + return self.linear(arg) + + def get_eager_model(self) -> torch.nn.Module: + return self + + def get_example_inputs(self): + return (torch.randn(3, 3),) + + +# Builds the sdk_example_runner and returns the path to it. +def build_executor_runner(executorch_root_dir): + with change_directory(executorch_root_dir): + # Clean any existing cmake caches and configure cmake to get ready for a build. + # run_command("rm -rf cmake-out") + run_command("mkdir cmake-out") + run_command("cd cmake-out") + run_command( + "cmake -DBUCK2=buck2 -DEXECUTORCH_BUILD_SDK=1 -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=1 -B cmake-out ." + ) + + # Build the sdk_example_runner + run_command("cmake --build cmake-out -j8 -t sdk_example_runner") + + # Return the path to the sdk_example_runner binary. + return "cmake-out/examples/sdk/sdk_example_runner" + + +# Take in an eager mode model and convert it to an ExecuTorch program. +def export_to_exec_prog( + model: Union[torch.fx.GraphModule, torch.nn.Module], + example_inputs: Tuple[Value, ...], +) -> ExecutorchProgramManager: + model.eval() + core_aten_ep = export(model, example_inputs) + edge_manager = to_edge(core_aten_ep) + executorch_manager = edge_manager.to_executorch() + return executorch_manager + + +# Take in an ExecuTorch program and bundle along some input test cases with it +# to produce a bundled program. This bundled program can be consumed by the +# sdk_example_runner to run the model along with the bundled input test cases. +def generate_bundled_program(executorch_program, model, example_inputs): + method_test_suites: List[MethodTestSuite] = [] + method_test_cases: List[MethodTestCase] = [] + + method_test_cases = [MethodTestCase(inputs=example_inputs)] + + method_test_suites.append( + MethodTestSuite( + method_name="forward", + test_cases=method_test_cases, + ) + ) + + bundled_program = create_bundled_program(executorch_program, method_test_suites) + bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer( + bundled_program + ) + + return bundled_program_buffer + + +# Runs the sdk_example_runner on the given bundled program and returns the paths +# to the etdump and debug_output files generated by the sdk_example_runner. These +# will be used later for comparison against the outputs of the eager mode model. +def run_and_generate_etdump(executorch_root_dir, working_dir_name, binary_path, bundled_program_buffer): + with change_directory(executorch_root_dir): + bundled_program_path = f"{working_dir_name}/bundled_program.pt" + f = open(bundled_program_path, "wb") + f.write(bundled_program_buffer) + f.close() + + etdump_path = f"{working_dir_name}/etdump.etdp" + debug_output_path = f"{working_dir_name}/debug_output.bin" + + cmd = f"{binary_path} --bundled_program_path {bundled_program_path} --etdump_path {etdump_path} --debug_output_path {debug_output_path} --dump_outputs" + print(f"Running executor runner: {cmd}") + run_command(cmd) + + return etdump_path, debug_output_path + + +# Takes in the etdump file and the debug_output file generated by the sdk_example_runner +# and compares the outputs of the eager mode model and the executorch program using the +# Inspector API's that are explained here in more detail. +# https://pytorch.org/executorch/main/sdk-inspector.html +def verify_outputs(etdump_path, debug_output_path, model, example_inputs): + inspector = Inspector(etdump_path=etdump_path, debug_buffer_path=debug_output_path) + for event_block in inspector.event_blocks: + if event_block.name == "Execute": + # Disable gradient computation since we are only interested in verifying outputs. + with torch.no_grad(): + model.eval() + ref_output = model(*example_inputs) + + # If the output is a single tensor then convert it into a list for convenience. + if isinstance(ref_output, torch.Tensor): + ref_output = [ref_output] + + # Compare the outputs of the eager mode and the executorch program. + # This function will return three stats SNR, MSE, and cosine similarity. + # For a model that is performing weel SNR should be as high as possible, + # MSE should be as close to zero as possible, and cosine similarity should + # be as close to one as possible. + compare_results( + reference_output=ref_output, + run_output=event_block.run_output, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--executorch_root_dir", + required=True, + help="Set the path to the root of the executorch repo directory.", + ) + args = parser.parse_args() + + model = TestModel() + example_inputs = model.get_example_inputs() + + exec_prog = export_to_exec_prog(model, example_inputs) + bundled_program_buffer = generate_bundled_program(exec_prog, model, example_inputs) + binary_path = build_executor_runner(args.executorch_root_dir) + with tempfile.TemporaryDirectory() as tmpdirname: + etdump_path, debug_output_path = run_and_generate_etdump( + args.executorch_root_dir, tmpdirname, binary_path, bundled_program_buffer + ) + verify_outputs(etdump_path, debug_output_path, model, example_inputs)