Skip to content

Commit 96d95b5

Browse files
committed
Generate pt2 archive with pte file
Differential Revision: [D81992612](https://our.internmc.facebook.com/intern/diff/D81992612/) ghstack-source-id: 308410221 Pull Request resolved: #14124
1 parent 6f89131 commit 96d95b5

File tree

2 files changed

+89
-1
lines changed

2 files changed

+89
-1
lines changed

extension/pt2_archive/targets.bzl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,27 @@ def define_common_targets():
88
TARGETS and BUCK files that call this function.
99
"""
1010

11+
runtime.python_binary(
12+
name = "export",
13+
main_module = "executorch.extension.pt2_archive.test.pt2_archive_export",
14+
srcs = ["test/pt2_archive_export.py"],
15+
deps = [
16+
"//caffe2:torch",
17+
"//executorch/exir:lib",
18+
"//executorch/exir/_serialize:lib",
19+
],
20+
visibility = [], # Private
21+
)
22+
23+
runtime.genrule(
24+
name = "gen_pt2_archive",
25+
cmd = "$(exe :export) --outdir $OUT",
26+
outs = {
27+
"model": ["model.pt2"],
28+
},
29+
default_outs = ["."],
30+
)
31+
1132
runtime.cxx_library(
1233
name = "pt2_archive_data_map",
1334
srcs = [
@@ -60,7 +81,7 @@ def define_common_targets():
6081
"//executorch/runtime/platform:platform",
6182
],
6283
env = {
63-
"TEST_LINEAR_PT2": "$(location :linear)",
84+
"TEST_LINEAR_PT2": "$(location :gen_pt2_archive[model])",
6485
"ET_MODULE_LINEAR_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.pte])",
6586
},
6687
# Not available for mobile with miniz and json dependencies.
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import argparse
10+
import os
11+
12+
import torch
13+
from executorch.exir import ExecutorchBackendConfig, to_edge
14+
15+
from torch.export import ExportedProgram
16+
from torch.export.pt2_archive._package import package_pt2
17+
18+
19+
class ModuleLinear(torch.nn.Module):
20+
def __init__(self):
21+
super().__init__()
22+
self.linear = torch.nn.Linear(3, 3)
23+
24+
def forward(self, x: torch.Tensor):
25+
return self.linear(x)
26+
27+
def get_random_inputs(self):
28+
return (torch.randn(3),)
29+
30+
31+
def main() -> None:
32+
torch.manual_seed(0)
33+
parser = argparse.ArgumentParser()
34+
parser.add_argument(
35+
"--outdir",
36+
type=str,
37+
required=True,
38+
help="Path to the directory to write model.pt2 files to",
39+
)
40+
args = parser.parse_args()
41+
42+
m = ModuleLinear()
43+
sample_inputs = m.get_random_inputs()
44+
ep = torch.export.export(m, sample_inputs)
45+
46+
# Lower to ExecuTorch
47+
exec_prog = to_edge(ep).to_executorch(
48+
ExecutorchBackendConfig(external_constants=True)
49+
)
50+
51+
if not isinstance(ep, ExportedProgram):
52+
raise TypeError(
53+
f"The 'ep' parameter must be an instance of 'ExportedProgram', got '{type(ep).__name__}' instead."
54+
)
55+
56+
# Create PT2 archive file
57+
os.makedirs(args.outdir, exist_ok=True)
58+
filename = os.path.join(args.outdir, "model.pt2")
59+
package_pt2(
60+
filename,
61+
exported_programs={"model": ep},
62+
executorch_files={"model.pte": exec_prog.buffer},
63+
)
64+
65+
66+
if __name__ == "__main__":
67+
main()

0 commit comments

Comments
 (0)