diff --git a/extension/pt2_archive/targets.bzl b/extension/pt2_archive/targets.bzl index a52b37d53ba..033cc69e5bf 100644 --- a/extension/pt2_archive/targets.bzl +++ b/extension/pt2_archive/targets.bzl @@ -8,6 +8,27 @@ def define_common_targets(): TARGETS and BUCK files that call this function. """ + runtime.python_binary( + name = "export", + main_module = "executorch.extension.pt2_archive.test.pt2_archive_export", + srcs = ["test/pt2_archive_export.py"], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir/_serialize:lib", + ], + visibility = [], # Private + ) + + runtime.genrule( + name = "gen_pt2_archive", + cmd = "$(exe :export) --outdir $OUT", + outs = { + "model": ["model.pt2"], + }, + default_outs = ["."], + ) + runtime.cxx_library( name = "pt2_archive_data_map", srcs = [ @@ -60,7 +81,7 @@ def define_common_targets(): "//executorch/runtime/platform:platform", ], env = { - "TEST_LINEAR_PT2": "$(location :linear)", + "TEST_LINEAR_PT2": "$(location :gen_pt2_archive[model])", "ET_MODULE_LINEAR_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.pte])", }, # Not available for mobile with miniz and json dependencies. diff --git a/extension/pt2_archive/test/pt2_archive_export.py b/extension/pt2_archive/test/pt2_archive_export.py new file mode 100644 index 00000000000..f611259d670 --- /dev/null +++ b/extension/pt2_archive/test/pt2_archive_export.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import argparse +import os + +import torch +from executorch.exir import ExecutorchBackendConfig, to_edge + +from torch.export import ExportedProgram +from torch.export.pt2_archive._package import package_pt2 + + +class ModuleLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + def get_random_inputs(self): + return (torch.randn(3),) + + +def main() -> None: + torch.manual_seed(0) + parser = argparse.ArgumentParser() + parser.add_argument( + "--outdir", + type=str, + required=True, + help="Path to the directory to write model.pt2 files to", + ) + args = parser.parse_args() + + m = ModuleLinear() + sample_inputs = m.get_random_inputs() + ep = torch.export.export(m, sample_inputs) + + # Lower to ExecuTorch + exec_prog = to_edge(ep).to_executorch( + ExecutorchBackendConfig(external_constants=True) + ) + + if not isinstance(ep, ExportedProgram): + raise TypeError( + f"The 'ep' parameter must be an instance of 'ExportedProgram', got '{type(ep).__name__}' instead." + ) + + # Create PT2 archive file + os.makedirs(args.outdir, exist_ok=True) + filename = os.path.join(args.outdir, "model.pt2") + package_pt2( + filename, + exported_programs={"model": ep}, + executorch_files={"model.pte": exec_prog.buffer}, + ) + + +if __name__ == "__main__": + main()