Skip to content

Commit 26c736e

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
Training demo (#5445)
Summary: Allows the XOR model training demo to be runnable in OSS. Will follow up with a documentation PR about training and how to run this demo. Im sure my cmakelist.txt changes have issues so if anyone sees ways to improve them please let me know. Only hack I had to do was the optimizer was calling an ET op directly. I don't think we have enabled this in OSS yet so I will follow up with larryliu0820 when hes back and in the meantime open up an issue. Repro of demo: Pull Request resolved: #5445 Test Plan: python3 extension/training/examples/XOR/export_model.py --outdir /tmp/xor rm -rf cmake-out mkdir cmake-out cmake \ -DCMAKE_INSTALL_PREFIX=cmake-out \ -DCMAKE_BUILD_TYPE=Release \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON \ -DEXECUTORCH_ENABLE_LOGGING=ON \ -DPYTHON_EXECUTABLE=python \ -Bcmake-out . cmake --build cmake-out -j9 --target install --config Release ./cmake-out/extension/training/train_xor --model_path=/tmp/xor/xor.pte Reviewed By: dvorjackz Differential Revision: D62905840 Pulled By: JacobSzwejbka fbshipit-source-id: 622e68637ee7a0bb1b323e777d60e9516be115cd
1 parent 53c1a5f commit 26c736e

File tree

12 files changed

+198
-84
lines changed

12 files changed

+198
-84
lines changed

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ option(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL "Build the Runner Util extension"
183183

184184
option(EXECUTORCH_BUILD_EXTENSION_TENSOR "Build the Tensor extension" OFF)
185185

186+
option(EXECUTORCH_BUILD_EXTENSION_TRAINING "Build the training extension" OFF)
187+
186188
option(EXECUTORCH_BUILD_GTESTS "Build googletest based test binaries" OFF)
187189

188190
option(EXECUTORCH_BUILD_MPS "Build the MPS backend" OFF)
@@ -636,6 +638,10 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE)
636638
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
637639
endif()
638640

641+
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
642+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/training)
643+
endif()
644+
639645
if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL)
640646
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/runner_util)
641647
endif()

build/Utils.cmake

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ function(executorch_print_configuration_summary)
6868
message(STATUS " EXECUTORCH_BUILD_EXTENSION_TENSOR : "
6969
"${EXECUTORCH_BUILD_EXTENSION_TENSOR}"
7070
)
71+
message(STATUS " EXECUTORCH_BUILD_EXTENSION_TRAINING : "
72+
"${EXECUTORCH_BUILD_EXTENSION_TRAINING}"
73+
)
7174
message(
7275
STATUS
7376
" EXECUTORCH_BUILD_FLATC : ${EXECUTORCH_BUILD_FLATC}"

build/cmake_deps.toml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,34 @@ deps = [
210210
"executorch",
211211
"executorch_no_prim_ops",
212212
]
213+
214+
[targets.extension_training]
215+
buck_targets = [
216+
"//extension/training/module:training_module",
217+
"//extension/training/optimizer:sgd",
218+
]
219+
filters = [
220+
".cpp$",
221+
]
222+
deps = [
223+
"executorch_no_prim_ops",
224+
]
225+
226+
[targets.train_xor]
227+
buck_targets = [
228+
"//extension/training/examples/XOR:train_xor",
229+
]
230+
filters = [
231+
".cpp$",
232+
]
233+
excludes = [
234+
"^codegen",
235+
]
236+
deps = [
237+
"executorch",
238+
"executorch_no_prim_ops",
239+
"portable_kernels",
240+
]
213241
# ---------------------------------- extension end ----------------------------------
214242
# ---------------------------------- binary start ----------------------------------
215243

build/executorch-config.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ set(lib_list
4848
extension_runner_util
4949
extension_tensor
5050
extension_threadpool
51+
extension_training
5152
xnnpack_backend
5253
XNNPACK
5354
cpuinfo

extension/training/CMakeLists.txt

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Please this file formatted by running:
8+
# ~~~
9+
# cmake-format -i CMakeLists.txt
10+
# ~~~
11+
12+
cmake_minimum_required(VERSION 3.19)
13+
14+
# Source root directory for executorch.
15+
if(NOT EXECUTORCH_ROOT)
16+
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
17+
endif()
18+
19+
list(TRANSFORM _extension_training__srcs PREPEND "${EXECUTORCH_ROOT}/")
20+
21+
add_library(extension_training ${_extension_training__srcs})
22+
target_include_directories(
23+
extension_training PUBLIC ${_common_include_directories}
24+
)
25+
26+
target_include_directories(extension_training PUBLIC ${EXECUTORCH_ROOT}/..)
27+
target_compile_options(extension_training PUBLIC ${_common_compile_options})
28+
target_link_libraries(extension_training executorch_no_prim_ops
29+
extension_data_loader extension_module extension_tensor)
30+
31+
32+
list(TRANSFORM _train_xor__srcs PREPEND "${EXECUTORCH_ROOT}/")
33+
add_executable(train_xor ${_train_xor__srcs})
34+
target_include_directories(
35+
train_xor PUBLIC ${_common_include_directories}
36+
)
37+
target_link_libraries(
38+
train_xor gflags executorch_no_prim_ops portable_ops_lib extension_tensor
39+
extension_training program_schema
40+
)
41+
target_compile_options(train_xor PUBLIC ${_common_compile_options})
42+
43+
# Install libraries
44+
install(
45+
TARGETS extension_training
46+
DESTINATION lib
47+
INCLUDES
48+
DESTINATION ${_common_include_directories}
49+
)

extension/training/__init__.py

Whitespace-only changes.
Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,8 @@
11
# Any targets that should be shared between fbcode and xplat must be defined in
22
# targets.bzl. This file can contain fbcode-only targets.
33

4-
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
5-
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
64
load(":targets.bzl", "define_common_targets")
75

86
oncall("executorch")
97

108
define_common_targets()
11-
12-
python_library(
13-
name = "model",
14-
srcs = ["model.py"],
15-
visibility = [], # Private
16-
deps = [
17-
"//caffe2:torch",
18-
],
19-
)
20-
21-
python_library(
22-
name = "export_model_lib",
23-
srcs = ["export_model_lib.py"],
24-
visibility = [],
25-
deps = [
26-
":model",
27-
"//caffe2:torch",
28-
"//executorch/exir:lib",
29-
],
30-
)
31-
32-
python_binary(
33-
name = "export_model",
34-
main_function = ".export_model.main",
35-
main_src = "export_model.py",
36-
deps = [
37-
":export_model_lib",
38-
"//caffe2:torch",
39-
],
40-
)

extension/training/examples/XOR/export_model.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,14 @@
88

99
import argparse
1010

11+
import os
12+
1113
import torch
14+
from executorch.exir import to_edge
1215

13-
from .export_model_lib import export_model
16+
from executorch.extension.training.examples.XOR.model import Net, TrainingNet
17+
from torch.export._trace import _export
18+
from torch.export.experimental import _export_forward_backward
1419

1520

1621
def main() -> None:
@@ -26,7 +31,27 @@ def main() -> None:
2631
help="Path to the directory to write xor.pte files to",
2732
)
2833
args = parser.parse_args()
29-
export_model(args.outdir)
34+
35+
net = TrainingNet(Net())
36+
x = torch.randn(1, 2)
37+
38+
# Captures the forward graph. The graph will look similar to the model definition now.
39+
# Will move to export_for_training soon which is the api planned to be supported in the long term.
40+
ep = _export(net, (x, torch.ones(1, dtype=torch.int64)), pre_dispatch=True)
41+
# Captures the backward graph. The exported_program now contains the joint forward and backward graph.
42+
ep = _export_forward_backward(ep)
43+
# Lower the graph to edge dialect.
44+
ep = to_edge(ep)
45+
# Lower the graph to executorch.
46+
ep = ep.to_executorch()
47+
48+
# Write out the .pte file.
49+
os.makedirs(args.outdir, exist_ok=True)
50+
outfile = os.path.join(args.outdir, "xor.pte")
51+
with open(outfile, "wb") as fp:
52+
fp.write(
53+
ep.buffer,
54+
)
3055

3156

3257
if __name__ == "__main__":

extension/training/examples/XOR/targets.bzl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,31 @@ def define_common_targets():
2121
external_deps = ["gflags"],
2222
define_static_target = True,
2323
)
24+
25+
runtime.python_library(
26+
name = "model",
27+
srcs = ["model.py"],
28+
visibility = [], # Private
29+
deps = [
30+
"//caffe2:torch",
31+
],
32+
)
33+
34+
runtime.python_library(
35+
name = "export_model_lib",
36+
srcs = ["export_model_lib.py", "export_model.py"],
37+
visibility = [],
38+
deps = [
39+
":model",
40+
"//caffe2:torch",
41+
"//executorch/exir:lib",
42+
],
43+
)
44+
45+
runtime.python_binary(
46+
name = "export_model",
47+
main_module = "executorch.extension.training.examples.XOR.export_model",
48+
deps = [
49+
":export_model_lib",
50+
],
51+
)

extension/training/examples/XOR/train.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,16 @@ int main(int argc, char** argv) {
5454
data_set;
5555
data_set.push_back( // XOR(1, 1) = 0
5656
{executorch::extension::make_tensor_ptr<float>({1, 2}, {1, 1}),
57-
executorch::extension::make_tensor_ptr<long>({1}, {0})});
57+
executorch::extension::make_tensor_ptr<int64_t>({1}, {0})});
5858
data_set.push_back( // XOR(0, 0) = 0
5959
{executorch::extension::make_tensor_ptr<float>({1, 2}, {0, 0}),
60-
executorch::extension::make_tensor_ptr<long>({1}, {0})});
60+
executorch::extension::make_tensor_ptr<int64_t>({1}, {0})});
6161
data_set.push_back( // XOR(1, 0) = 1
6262
{executorch::extension::make_tensor_ptr<float>({1, 2}, {1, 0}),
63-
executorch::extension::make_tensor_ptr<long>({1}, {1})});
63+
executorch::extension::make_tensor_ptr<int64_t>({1}, {1})});
6464
data_set.push_back( // XOR(0, 1) = 1
6565
{executorch::extension::make_tensor_ptr<float>({1, 2}, {0, 1}),
66-
executorch::extension::make_tensor_ptr<long>({1}, {1})});
66+
executorch::extension::make_tensor_ptr<int64_t>({1}, {1})});
6767

6868
// Create optimizer.
6969
// Get the params and names

0 commit comments

Comments
 (0)