Skip to content

Commit c783c21

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
update demo scripts to use .ptd
Summary: Add support for .ptd Differential Revision: D70363616
1 parent a0c0d2b commit c783c21

File tree

2 files changed

+44
-11
lines changed

2 files changed

+44
-11
lines changed

extension/training/examples/XOR/export_model.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
from executorch.extension.training.examples.XOR.model import Net, TrainingNet
1717
from torch.export import export
1818
from torch.export.experimental import _export_forward_backward
19+
from executorch.exir import ExecutorchBackendConfig
1920

2021

21-
def _export_model():
22+
def _export_model(external_mutable_weights: bool = False):
2223
net = TrainingNet(Net())
2324
x = torch.randn(1, 2)
2425

@@ -30,7 +31,11 @@ def _export_model():
3031
# Lower the graph to edge dialect.
3132
ep = to_edge(ep)
3233
# Lower the graph to executorch.
33-
ep = ep.to_executorch()
34+
ep = ep.to_executorch(
35+
config=ExecutorchBackendConfig(
36+
external_mutable_weights=external_mutable_weights
37+
)
38+
)
3439
return ep
3540

3641

@@ -44,19 +49,30 @@ def main() -> None:
4449
"--outdir",
4550
type=str,
4651
required=True,
47-
help="Path to the directory to write xor.pte files to",
52+
help="Path to the directory to write xor.pte and xor.ptd files to",
53+
)
54+
parser.add_argument(
55+
"--external",
56+
action="store_true",
57+
help="Export the model with external weights",
4858
)
4959
args = parser.parse_args()
5060

51-
ep = _export_model()
61+
ep = _export_model(args.external)
5262

5363
# Write out the .pte file.
5464
os.makedirs(args.outdir, exist_ok=True)
5565
outfile = os.path.join(args.outdir, "xor.pte")
5666
with open(outfile, "wb") as fp:
57-
fp.write(
58-
ep.buffer,
67+
ep.write_to_file(fp)
68+
69+
if args.external:
70+
print("DUMPING")
71+
# current infra doesnt easily allow renaming this file, so just hackily do it here.
72+
ep._tensor_data["xor"] = ep._tensor_data.pop(
73+
"_default_external_constant"
5974
)
75+
ep.write_tensor_data_to_file(args.outdir)
6076

6177

6278
if __name__ == "__main__":

extension/training/examples/XOR/train.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,16 @@ using executorch::extension::training::optimizer::SGDOptions;
2323
using executorch::runtime::Error;
2424
using executorch::runtime::Result;
2525
DEFINE_string(model_path, "xor.pte", "Model serialized in flatbuffer format.");
26+
DEFINE_string(ptd_path, "", "Model weights serialized in flatbuffer format.");
2627

2728
int main(int argc, char** argv) {
2829
gflags::ParseCommandLineFlags(&argc, &argv, true);
29-
if (argc != 1) {
30+
if (argc == 0) {
31+
ET_LOG(Error, "Please provide a model path.");
32+
return 1;
33+
} else if (argc > 2) {
3034
std::string msg = "Extra commandline args: ";
31-
for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) {
35+
for (int i = 2 /* skip argv[0] (pte path) and argv[1] (ptd path) */; i < argc; i++) {
3236
msg += argv[i];
3337
}
3438
ET_LOG(Error, "%s", msg.c_str());
@@ -46,7 +50,20 @@ int main(int argc, char** argv) {
4650
auto loader = std::make_unique<executorch::extension::FileDataLoader>(
4751
std::move(loader_res.get()));
4852

49-
auto mod = executorch::extension::training::TrainingModule(std::move(loader));
53+
std::unique_ptr<executorch::extension::FileDataLoader> ptd_loader = nullptr;
54+
if (!FLAGS_ptd_path.empty()) {
55+
executorch::runtime::Result<executorch::extension::FileDataLoader>
56+
ptd_loader_res =
57+
executorch::extension::FileDataLoader::from(FLAGS_ptd_path.c_str());
58+
if (ptd_loader_res.error() != Error::Ok) {
59+
ET_LOG(Error, "Failed to open ptd file: %s", FLAGS_ptd_path.c_str());
60+
return 1;
61+
}
62+
ptd_loader = std::make_unique<executorch::extension::FileDataLoader>(
63+
std::move(ptd_loader_res.get()));
64+
}
65+
66+
auto mod = executorch::extension::training::TrainingModule(std::move(loader), nullptr, nullptr, nullptr, std::move(ptd_loader));
5067

5168
// Create full data set of input and labels.
5269
std::vector<std::pair<
@@ -70,7 +87,7 @@ int main(int argc, char** argv) {
7087
// Get the params and names
7188
auto param_res = mod.named_parameters("forward");
7289
if (param_res.error() != Error::Ok) {
73-
ET_LOG(Error, "Failed to get named parameters");
90+
ET_LOG(Error, "Failed to get named parameters, error: %d", static_cast<int>(param_res.error()));
7491
return 1;
7592
}
7693

@@ -112,5 +129,5 @@ int main(int argc, char** argv) {
112129
std::string(param.first.data()), param.second});
113130
}
114131

115-
executorch::extension::flat_tensor::save_ptd("xor.ptd", param_map, 16);
132+
executorch::extension::flat_tensor::save_ptd("trained_xor.ptd", param_map, 16);
116133
}

0 commit comments

Comments
 (0)