Skip to content

Commit 7add233

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
update demo scripts to use .ptd (#8834)
Summary: Pull Request resolved: #8834 Add support for .ptd Differential Revision: D70363616
1 parent 781b082 commit 7add233

File tree

4 files changed

+57
-16
lines changed

4 files changed

+57
-16
lines changed

CMakeLists.txt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,14 +248,15 @@ cmake_dependent_option(
248248
"NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF
249249
)
250250

251-
if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
251+
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
252252
set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
253+
set(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON)
254+
set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
255+
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
253256
endif()
254257

255-
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
256-
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
258+
if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
257259
set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
258-
set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
259260
endif()
260261

261262
if(EXECUTORCH_BUILD_EXTENSION_MODULE)

extension/training/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ target_include_directories(
2626
target_include_directories(extension_training PUBLIC ${EXECUTORCH_ROOT}/..)
2727
target_compile_options(extension_training PUBLIC ${_common_compile_options})
2828
target_link_libraries(extension_training executorch_core
29-
extension_data_loader extension_module extension_tensor)
29+
extension_data_loader extension_module extension_tensor extension_flat_tensor)
3030

3131

3232
list(TRANSFORM _train_xor__srcs PREPEND "${EXECUTORCH_ROOT}/")

extension/training/examples/XOR/export_model.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
from executorch.extension.training.examples.XOR.model import Net, TrainingNet
1717
from torch.export import export
1818
from torch.export.experimental import _export_forward_backward
19+
from executorch.exir import ExecutorchBackendConfig
1920

2021

21-
def _export_model():
22+
def _export_model(external_mutable_weights: bool = False):
2223
net = TrainingNet(Net())
2324
x = torch.randn(1, 2)
2425

@@ -30,7 +31,11 @@ def _export_model():
3031
# Lower the graph to edge dialect.
3132
ep = to_edge(ep)
3233
# Lower the graph to executorch.
33-
ep = ep.to_executorch()
34+
ep = ep.to_executorch(
35+
config=ExecutorchBackendConfig(
36+
external_mutable_weights=external_mutable_weights
37+
)
38+
)
3439
return ep
3540

3641

@@ -44,19 +49,30 @@ def main() -> None:
4449
"--outdir",
4550
type=str,
4651
required=True,
47-
help="Path to the directory to write xor.pte files to",
52+
help="Path to the directory to write xor.pte and xor.ptd files to",
53+
)
54+
parser.add_argument(
55+
"--external",
56+
action="store_true",
57+
help="Export the model with external weights",
4858
)
4959
args = parser.parse_args()
5060

51-
ep = _export_model()
61+
ep = _export_model(args.external)
5262

5363
# Write out the .pte file.
5464
os.makedirs(args.outdir, exist_ok=True)
5565
outfile = os.path.join(args.outdir, "xor.pte")
5666
with open(outfile, "wb") as fp:
57-
fp.write(
58-
ep.buffer,
67+
ep.write_to_file(fp)
68+
69+
if args.external:
70+
print("DUMPING")
71+
# current infra doesnt easily allow renaming this file, so just hackily do it here.
72+
ep._tensor_data["xor"] = ep._tensor_data.pop(
73+
"_default_external_constant"
5974
)
75+
ep.write_tensor_data_to_file(args.outdir)
6076

6177

6278
if __name__ == "__main__":

extension/training/examples/XOR/train.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,18 @@ using executorch::extension::training::optimizer::SGDOptions;
2323
using executorch::runtime::Error;
2424
using executorch::runtime::Result;
2525
DEFINE_string(model_path, "xor.pte", "Model serialized in flatbuffer format.");
26+
DEFINE_string(ptd_path, "", "Model weights serialized in flatbuffer format.");
2627

2728
int main(int argc, char** argv) {
2829
gflags::ParseCommandLineFlags(&argc, &argv, true);
29-
if (argc != 1) {
30+
if (argc == 0) {
31+
ET_LOG(Error, "Please provide a model path.");
32+
return 1;
33+
} else if (argc > 2) {
3034
std::string msg = "Extra commandline args: ";
31-
for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) {
35+
for (int i = 2 /* skip argv[0] (pte path) and argv[1] (ptd path) */;
36+
i < argc;
37+
i++) {
3238
msg += argv[i];
3339
}
3440
ET_LOG(Error, "%s", msg.c_str());
@@ -46,7 +52,21 @@ int main(int argc, char** argv) {
4652
auto loader = std::make_unique<executorch::extension::FileDataLoader>(
4753
std::move(loader_res.get()));
4854

49-
auto mod = executorch::extension::training::TrainingModule(std::move(loader));
55+
std::unique_ptr<executorch::extension::FileDataLoader> ptd_loader = nullptr;
56+
if (!FLAGS_ptd_path.empty()) {
57+
executorch::runtime::Result<executorch::extension::FileDataLoader>
58+
ptd_loader_res =
59+
executorch::extension::FileDataLoader::from(FLAGS_ptd_path.c_str());
60+
if (ptd_loader_res.error() != Error::Ok) {
61+
ET_LOG(Error, "Failed to open ptd file: %s", FLAGS_ptd_path.c_str());
62+
return 1;
63+
}
64+
ptd_loader = std::make_unique<executorch::extension::FileDataLoader>(
65+
std::move(ptd_loader_res.get()));
66+
}
67+
68+
auto mod = executorch::extension::training::TrainingModule(
69+
std::move(loader), nullptr, nullptr, nullptr, std::move(ptd_loader));
5070

5171
// Create full data set of input and labels.
5272
std::vector<std::pair<
@@ -70,7 +90,10 @@ int main(int argc, char** argv) {
7090
// Get the params and names
7191
auto param_res = mod.named_parameters("forward");
7292
if (param_res.error() != Error::Ok) {
73-
ET_LOG(Error, "Failed to get named parameters");
93+
ET_LOG(
94+
Error,
95+
"Failed to get named parameters, error: %d",
96+
static_cast<int>(param_res.error()));
7497
return 1;
7598
}
7699

@@ -112,5 +135,6 @@ int main(int argc, char** argv) {
112135
std::string(param.first.data()), param.second});
113136
}
114137

115-
executorch::extension::flat_tensor::save_ptd("xor.ptd", param_map, 16);
138+
executorch::extension::flat_tensor::save_ptd(
139+
"trained_xor.ptd", param_map, 16);
116140
}

0 commit comments

Comments
 (0)