Skip to content

Commit d6f58cd

Browse files
committed
extension/training builds in OSS
1 parent 3befc8a commit d6f58cd

File tree

11 files changed

+137
-64
lines changed

11 files changed

+137
-64
lines changed

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ option(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL "Build the Runner Util extension"
183183

184184
option(EXECUTORCH_BUILD_EXTENSION_TENSOR "Build the Tensor extension" OFF)
185185

186+
option(EXECUTORCH_BUILD_EXTENSION_TRAINING "Build the training extension" OFF)
187+
186188
option(EXECUTORCH_BUILD_GTESTS "Build googletest based test binaries" OFF)
187189

188190
option(EXECUTORCH_BUILD_MPS "Build the MPS backend" OFF)
@@ -636,6 +638,10 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE)
636638
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
637639
endif()
638640

641+
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
642+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/training)
643+
endif()
644+
639645
if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL)
640646
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/runner_util)
641647
endif()

build/Utils.cmake

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ function(executorch_print_configuration_summary)
6868
message(STATUS " EXECUTORCH_BUILD_EXTENSION_TENSOR : "
6969
"${EXECUTORCH_BUILD_EXTENSION_TENSOR}"
7070
)
71+
message(STATUS " EXECUTORCH_BUILD_EXTENSION_TRAINING : "
72+
"${EXECUTORCH_BUILD_EXTENSION_TRAINING}"
73+
)
7174
message(
7275
STATUS
7376
" EXECUTORCH_BUILD_FLATC : ${EXECUTORCH_BUILD_FLATC}"

build/cmake_deps.toml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,31 @@ deps = [
210210
"executorch",
211211
"executorch_no_prim_ops",
212212
]
213+
214+
[targets.extension_training]
215+
buck_targets = [
216+
"//extension/training/module:training_module",
217+
"//extension/training/optimizer:sgd",
218+
]
219+
filters = [
220+
".cpp$",
221+
]
222+
deps = [
223+
"executorch",
224+
"portable_kernels",
225+
]
226+
227+
[targets.train_xor]
228+
buck_targets = [
229+
"//extension/training/examples/XOR:train_xor",
230+
]
231+
filters = [
232+
".cpp$",
233+
]
234+
deps = [
235+
"executorch",
236+
"portable_kernels",
237+
]
213238
# ---------------------------------- extension end ----------------------------------
214239
# ---------------------------------- binary start ----------------------------------
215240

build/executorch-config.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ set(lib_list
4848
extension_runner_util
4949
extension_tensor
5050
extension_threadpool
51+
extension_training
5152
xnnpack_backend
5253
XNNPACK
5354
cpuinfo

extension/training/CMakeLists.txt

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Please keep this file formatted by running:
8+
# ~~~
9+
# cmake-format -i CMakeLists.txt
10+
# ~~~
11+
12+
cmake_minimum_required(VERSION 3.19)
13+
14+
# Source root directory for executorch.
15+
if(NOT EXECUTORCH_ROOT)
16+
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
17+
endif()
18+
19+
list(TRANSFORM _extension_training__srcs PREPEND "${EXECUTORCH_ROOT}/")
20+
# Report the resolved source list at configure time to aid build debugging.
message(STATUS "extension_training sources: ${_extension_training__srcs}")
21+
add_library(extension_training ${_extension_training__srcs})
22+
target_link_libraries(extension_training executorch_no_prim_ops)
23+
target_include_directories(extension_training PUBLIC ${EXECUTORCH_ROOT}/..)
24+
target_compile_options(extension_training PUBLIC ${_common_compile_options})
25+
26+
list(TRANSFORM _train_xor__srcs PREPEND "${EXECUTORCH_ROOT}/")
27+
add_executable(train_xor ${_train_xor__srcs})
28+
target_link_libraries(
29+
train_xor gflags portable_ops_lib
30+
)
31+
target_compile_options(train_xor PUBLIC ${_common_compile_options})
32+
33+
# Install libraries
34+
install(
35+
TARGETS extension_training
36+
DESTINATION lib
37+
INCLUDES
38+
DESTINATION ${_common_include_directories}
39+
)

extension/training/__init__.py

Whitespace-only changes.
Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,8 @@
11
# Any targets that should be shared between fbcode and xplat must be defined in
22
# targets.bzl. This file can contain fbcode-only targets.
33

4-
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
5-
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
64
load(":targets.bzl", "define_common_targets")
75

86
oncall("executorch")
97

108
define_common_targets()
11-
12-
python_library(
13-
name = "model",
14-
srcs = ["model.py"],
15-
visibility = [], # Private
16-
deps = [
17-
"//caffe2:torch",
18-
],
19-
)
20-
21-
python_library(
22-
name = "export_model_lib",
23-
srcs = ["export_model_lib.py"],
24-
visibility = [],
25-
deps = [
26-
":model",
27-
"//caffe2:torch",
28-
"//executorch/exir:lib",
29-
],
30-
)
31-
32-
python_binary(
33-
name = "export_model",
34-
main_function = ".export_model.main",
35-
main_src = "export_model.py",
36-
deps = [
37-
":export_model_lib",
38-
"//caffe2:torch",
39-
],
40-
)

extension/training/examples/XOR/export_model.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,14 @@
88

99
import argparse
1010

11+
import os
12+
1113
import torch
14+
from executorch.exir import to_edge
1215

13-
from .export_model_lib import export_model
16+
from executorch.extension.training.examples.XOR.model import Net, TrainingNet
17+
from torch.export._trace import _export
18+
from torch.export.experimental import _export_forward_backward
1419

1520

1621
def main() -> None:
@@ -26,7 +31,27 @@ def main() -> None:
2631
help="Path to the directory to write xor.pte files to",
2732
)
2833
args = parser.parse_args()
29-
export_model(args.outdir)
34+
35+
net = TrainingNet(Net())
36+
x = torch.randn(1, 2)
37+
38+
# Captures the forward graph. The graph will look similar to the model definition now.
39+
# Will move to export_for_training soon, which is the API planned to be supported long term.
40+
ep = _export(net, (x, torch.ones(1, dtype=torch.int64)), pre_dispatch=True)
41+
# Captures the backward graph. The exported_program now contains the joint forward and backward graph.
42+
ep = _export_forward_backward(ep)
43+
# Lower the graph to edge dialect.
44+
ep = to_edge(ep)
45+
# Lower the graph to executorch.
46+
ep = ep.to_executorch()
47+
48+
# Write out the .pte file.
49+
os.makedirs(args.outdir, exist_ok=True)
50+
outfile = os.path.join(args.outdir, "xor.pte")
51+
with open(outfile, "wb") as fp:
52+
fp.write(
53+
ep.buffer,
54+
)
3055

3156

3257
if __name__ == "__main__":

extension/training/examples/XOR/targets.bzl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,31 @@ def define_common_targets():
2121
external_deps = ["gflags"],
2222
define_static_target = True,
2323
)
24+
25+
runtime.python_library(
26+
name = "model",
27+
srcs = ["model.py"],
28+
visibility = [], # Private
29+
deps = [
30+
"//caffe2:torch",
31+
],
32+
)
33+
34+
runtime.python_library(
35+
name = "export_model_lib",
36+
srcs = ["export_model_lib.py", "export_model.py"],
37+
visibility = [],
38+
deps = [
39+
":model",
40+
"//caffe2:torch",
41+
"//executorch/exir:lib",
42+
],
43+
)
44+
45+
runtime.python_binary(
46+
name = "export_model",
47+
main_module = "executorch.extension.training.examples.XOR.export_model",
48+
deps = [
49+
":export_model_lib",
50+
],
51+
)

extension/training/optimizer/sgd.cpp

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,14 @@
77
*/
88

99
#include <executorch/extension/training/optimizer/sgd.h>
10-
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/portable/NativeFunctions.h>
1111

1212
#include <executorch/runtime/core/error.h>
1313
#include <executorch/runtime/kernel/kernel_runtime_context.h>
1414

1515
using exec_aten::Tensor;
1616
using exec_aten::TensorImpl;
1717
using ::executorch::runtime::Error;
18-
using ::executorch::runtime::KernelRuntimeContext;
1918

2019
namespace executorch {
2120
namespace extension {
@@ -73,10 +72,7 @@ Error SGD::step(const std::map<exec_aten::string_view, exec_aten::Tensor>&
7372
auto p = param_iter->second;
7473
if (weight_decay != 0) {
7574
// uses weight_decay specified and adds it to the gradient
76-
torch::executor::aten::add_outf(context, d_p, p, weight_decay, d_p);
77-
if (context.failure_state() != Error::Ok) {
78-
return context.failure_state();
79-
}
75+
torch::executor::native::add_out(d_p, p, weight_decay, d_p);
8076
}
8177
if (momentum != 0) {
8278
Tensor buf(nullptr);
@@ -100,11 +96,8 @@ Error SGD::step(const std::map<exec_aten::string_view, exec_aten::Tensor>&
10096
const_cast<TensorImpl::DimOrderType*>(d_p.dim_order().data()));
10197
buf = Tensor(buf_impl);
10298
#endif
103-
torch::executor::aten::clone_outf(
104-
context, d_p, exec_aten::MemoryFormat::Contiguous, buf);
105-
if (context.failure_state() != Error::Ok) {
106-
return context.failure_state();
107-
}
99+
torch::executor::native::clone_out(
100+
d_p, exec_aten::MemoryFormat::Contiguous, buf);
108101

109102
// save the state of the momentum buffer to be reused in later
110103
// epochs
@@ -115,31 +108,18 @@ Error SGD::step(const std::map<exec_aten::string_view, exec_aten::Tensor>&
115108
.momentum_buffer();
116109

117110
// update the momentum buffer and apply dampening
118-
torch::executor::aten::mul_outf(context, buf, momentum, buf);
119-
if (context.failure_state() != Error::Ok) {
120-
return context.failure_state();
121-
}
122-
torch::executor::aten::add_outf(
123-
context, buf, d_p, 1 - dampening, buf);
124-
if (context.failure_state() != Error::Ok) {
125-
return context.failure_state();
126-
}
111+
torch::executor::native::mul_out(context, buf, momentum, buf);
112+
torch::executor::native::add_out(buf, d_p, 1 - dampening, buf);
127113
}
128114
if (nesterov) {
129115
// apply nesterov momentum
130-
torch::executor::aten::add_outf(context, d_p, buf, momentum, d_p);
131-
if (context.failure_state() != Error::Ok) {
132-
return context.failure_state();
133-
}
116+
torch::executor::native::add_out(d_p, buf, momentum, d_p);
134117
} else {
135118
d_p = buf;
136119
}
137120
}
138121
// update the parameter using the gradient and learning rate
139-
torch::executor::aten::add_outf(context, p, d_p, -1 * options.lr(), p);
140-
if (context.failure_state() != Error::Ok) {
141-
return context.failure_state();
142-
}
122+
torch::executor::native::add_out(p, d_p, -1 * options.lr(), p);
143123
}
144124
}
145125
}

0 commit comments

Comments
 (0)