
Commit 6ed8873

JacobSzwejbka authored and facebook-github-bot committed
Add training readme
Summary: Add first pass at documentation

bypass-github-export-checks

Reviewed By: Jack-Khuu

Differential Revision: D62926334

fbshipit-source-id: cd0b3898dc7c2b7bd41489b55d6b24cb8afb25f7
1 parent d2a38cc commit 6ed8873

4 files changed

+280
-43
lines changed

extension/training/README.md

Lines changed: 277 additions & 0 deletions
@@ -0,0 +1,277 @@
# ExecuTorch On-device Training

This subtree contains infrastructure to facilitate on-device training using ExecuTorch.
This feature is experimental and under heavy active development. All of the APIs are
subject to change, and many things may not work out of the box, or at all, in the
current state.

## Layout
- `examples/`: Example end-to-end flows from model definition to `optimizer.step()`.
- `module/`: Utility class to provide an improved UX when using ExecuTorch for training.
- `optimizer/`: C++ implementations of various optimizers. Currently only SGD is implemented, though Adam is planned.
- `test/`: Tests that cover multiple subdirectories.

## Technical Bird's-Eye View

At a high level, ExecuTorch training follows a similar flow to inference, with a few extra steps.

Instead of relying on autograd at runtime to dynamically generate the backward graph and then walk it,
we capture the backward graph ahead of time. This lets us be much leaner on-device and
gives backends more direct control over more of the model execution. Currently the optimizer is not
captured, though this may change over time.

Loss functions must be embedded inside the model definition, and the loss must be the first output. The loss is used during
capture to generate the backward graph.

Gradients become explicit graph outputs rather than hidden tensor state.

Since the weights now need to be mutable during execution, they are memory planned ahead of time and copied
from the .pte into the HierarchicalAllocator arenas during Method init.

Integration with backends/delegates is still a work in progress.
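
To make the idea of an ahead-of-time backward graph concrete, here is a minimal pure-Python sketch. It does not use any ExecuTorch or PyTorch API; the function name `joint_forward_backward` and the one-weight model are hypothetical and chosen only for illustration. The point is that the loss comes first and the gradients are ordinary return values rather than state attached to tensors.

```python
# Illustrative sketch only: a hand-written "joint" function for y = w * x + b
# with a squared-error loss. Nothing here is an ExecuTorch API.


def joint_forward_backward(w, b, x, target):
    """Return (loss, prediction, grad_w, grad_b) from a single call."""
    # Forward pass.
    pred = w * x + b
    diff = pred - target
    loss = diff * diff

    # Backward pass, written out ahead of time instead of being
    # discovered by autograd at runtime.
    grad_pred = 2.0 * diff
    grad_w = grad_pred * x
    grad_b = grad_pred

    # Loss first, then other user outputs, then gradients.
    return loss, pred, grad_w, grad_b


loss, pred, grad_w, grad_b = joint_forward_backward(w=0.5, b=0.0, x=2.0, target=3.0)
```

A caller that receives these values can update `w` and `b` directly, which is why no autograd engine is needed at runtime.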

## End to End Example

To further understand the features of ExecuTorch Training and how to leverage them,
consider the following end-to-end example with a neural network learning the XOR function.

### Lowering a joint-graph model to ExecuTorch

After following the [setting up ExecuTorch] guide, you can run

```bash
python3 extension/training/examples/XOR/export_model.py --outdir /tmp/foobar
```
to generate the model file. Below is a walkthrough of how that script works.

First, let's define our model.
```python
import torch
import torch.nn as nn
from torch.nn import functional as F

from executorch.exir import to_edge
from torch.export import export
from torch.export.experimental import _export_forward_backward


# Basic Net for XOR
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 10)
        self.linear2 = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear2(F.sigmoid(self.linear(x)))
```

The first big difference from the normal ExecuTorch flow is that for training we must embed
the loss function into the model and return the loss as our first output.

We don't want to modify the original model definition, so we will just wrap it.

```python
class TrainingNet(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input, label):
        pred = self.net(input)
        return self.loss(pred, label), pred.detach().argmax(dim=1)
```

Now that we have our model, we can lower it to ExecuTorch. To do that we just have to follow
a few simple steps.

```python
net = TrainingNet(Net())

# Create our inputs. Only the shapes and dtypes of these matter, not the values.
input = torch.randn(1, 2)
label = torch.ones(1, dtype=torch.int64)

# Captures the forward graph. The graph will look similar to the model definition now.
# Will move to export_for_training soon, which is the API planned to be supported in the long term.
ep = export(net, (input, label))
```

This is what the graph looks like after export:
```python
>>> print(ep.graph_module.graph)

graph():
    %p_net_linear_weight : [num_users=1] = placeholder[target=p_net_linear_weight]
    %p_net_linear_bias : [num_users=1] = placeholder[target=p_net_linear_bias]
    %p_net_linear2_weight : [num_users=1] = placeholder[target=p_net_linear2_weight]
    %p_net_linear2_bias : [num_users=1] = placeholder[target=p_net_linear2_bias]
    %input : [num_users=1] = placeholder[target=input]
    %label : [num_users=1] = placeholder[target=label]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%input, %p_net_linear_weight, %p_net_linear_bias), kwargs = {})
    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})
    %linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%sigmoid, %p_net_linear2_weight, %p_net_linear2_bias), kwargs = {})
    %cross_entropy_loss : [num_users=1] = call_function[target=torch.ops.aten.cross_entropy_loss.default](args = (%linear_1, %label), kwargs = {})
    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%linear_1,), kwargs = {})
    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%detach, 1), kwargs = {})
    return (cross_entropy_loss, argmax)
```

It should look pretty similar to our model's forward function. Now we need to capture the backward graph.

```python
ep = _export_forward_backward(ep)
```

And now the graph is:

```python
>>> print(ep.graph_module.graph)

graph():
    %p_net_linear_weight : [num_users=1] = placeholder[target=p_net_linear_weight]
    %p_net_linear_bias : [num_users=1] = placeholder[target=p_net_linear_bias]
    %p_net_linear2_weight : [num_users=1] = placeholder[target=p_net_linear2_weight]
    %p_net_linear2_bias : [num_users=1] = placeholder[target=p_net_linear2_bias]
    %input : [num_users=2] = placeholder[target=input]
    %label : [num_users=5] = placeholder[target=label]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%p_net_linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%p_net_linear_bias, %input, %permute), kwargs = {})
    %sigmoid : [num_users=3] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})
    %alias : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%sigmoid,), kwargs = {})
    %alias_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias,), kwargs = {})
    %permute_1 : [num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%p_net_linear2_weight, [1, 0]), kwargs = {})
    %addmm_1 : [num_users=2] = call_function[target=torch.ops.aten.addmm.default](args = (%p_net_linear2_bias, %sigmoid, %permute_1), kwargs = {})
    %_log_softmax : [num_users=3] = call_function[target=torch.ops.aten._log_softmax.default](args = (%addmm_1, 1, False), kwargs = {})
    %alias_2 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%_log_softmax,), kwargs = {})
    %alias_3 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_2,), kwargs = {})
    %ne : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu})
    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne, %label, %scalar_tensor), kwargs = {})
    %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%where, 1), kwargs = {})
    %gather : [num_users=1] = call_function[target=torch.ops.aten.gather.default](args = (%_log_softmax, 1, %unsqueeze), kwargs = {})
    %squeeze : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%gather, [1]), kwargs = {})
    %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%squeeze,), kwargs = {})
    %ne_1 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %scalar_tensor_1 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
    %where_1 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_1, %neg, %scalar_tensor_1), kwargs = {})
    %ne_2 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%ne_2, []), kwargs = {})
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%sum_1,), kwargs = {dtype: torch.float32, device: cpu})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%where_1, []), kwargs = {})
    %div : [num_users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_2, %_to_copy), kwargs = {})
    %alias_4 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%addmm_1,), kwargs = {})
    %alias_5 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_4,), kwargs = {})
    %alias_6 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_5,), kwargs = {})
    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%alias_6, 1), kwargs = {})
    %full_like : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%div, 1), kwargs = {pin_memory: False, memory_format: torch.preserve_format})
    %div_1 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%full_like, %_to_copy), kwargs = {})
    %unsqueeze_1 : [num_users=3] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%label, 1), kwargs = {})
    %ne_3 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_1, -100), kwargs = {})
    %scalar_tensor_2 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu})
    %where_2 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %unsqueeze_1, %scalar_tensor_2), kwargs = {})
    %full_like_1 : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%_log_softmax, 0), kwargs = {pin_memory: False, memory_format: torch.preserve_format})
    %scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.value](args = (%full_like_1, 1, %where_2, -1.0), kwargs = {})
    %ne_4 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_1, -100), kwargs = {})
    %scalar_tensor_3 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
    %where_3 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_4, %div_1, %scalar_tensor_3), kwargs = {})
    %mul : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%scatter, %where_3), kwargs = {})
    %alias_7 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_3,), kwargs = {})
    %alias_8 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_7,), kwargs = {})
    %exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%alias_8,), kwargs = {})
    %sum_3 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [1], True), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%exp, %sum_3), kwargs = {})
    %sub : [num_users=3] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul, %mul_1), kwargs = {})
    %permute_2 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_1, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%sub, %permute_2), kwargs = {})
    %permute_3 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sub, [1, 0]), kwargs = {})
    %mm_1 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%permute_3, %sigmoid), kwargs = {})
    %permute_4 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mm_1, [1, 0]), kwargs = {})
    %sum_4 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%sub, [0], True), kwargs = {})
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_4, [2]), kwargs = {})
    %permute_5 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_4, [1, 0]), kwargs = {})
    %alias_9 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_1,), kwargs = {})
    %alias_10 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%alias_9,), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (1, %alias_10), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%alias_10, %sub_1), kwargs = {})
    %mul_3 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mm, %mul_2), kwargs = {})
    %permute_6 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mul_3, [1, 0]), kwargs = {})
    %mm_2 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%permute_6, %input), kwargs = {})
    %permute_7 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mm_2, [1, 0]), kwargs = {})
    %sum_5 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_3, [0], True), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_5, [10]), kwargs = {})
    %permute_8 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_7, [1, 0]), kwargs = {})
    return (div, argmax, permute_8, view_1, permute_5, view)
```

It's a lot bigger! We call this the 'joint graph' or the 'forward-backward graph'. We have explicitly captured the backward graph
alongside the forward graph, and now our model returns [loss, any other user outputs, gradients].
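
The backward half of the joint graph can look opaque, but its first step reduces to a well-known identity: the gradient of the mean cross-entropy loss with respect to the logits is `(softmax(logits) - one_hot(label)) / N`. In the graph this appears to be the `%sub` node, which combines the `-1.0` entries scattered at the label positions with `%exp` of the log-softmax. The pure-Python sketch below checks that identity against a finite-difference estimate. The helper names are made up for illustration, and nothing here calls into PyTorch or ExecuTorch.

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax for one row of logits.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def mean_cross_entropy(batch_logits, labels):
    # Mean over the batch of -log_softmax(logits)[label].
    losses = [-log_softmax(row)[y] for row, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


def logits_gradient(batch_logits, labels):
    # Analytic gradient: (softmax - one_hot) / N for every row.
    n = len(batch_logits)
    grads = []
    for row, y in zip(batch_logits, labels):
        probs = [math.exp(v) for v in log_softmax(row)]
        grads.append([(p - (1.0 if j == y else 0.0)) / n for j, p in enumerate(probs)])
    return grads


def finite_difference(batch_logits, labels, i, j, eps=1e-6):
    # Central-difference estimate of d(loss) / d(logits[i][j]).
    plus = [list(r) for r in batch_logits]
    minus = [list(r) for r in batch_logits]
    plus[i][j] += eps
    minus[i][j] -= eps
    return (mean_cross_entropy(plus, labels) - mean_cross_entropy(minus, labels)) / (2 * eps)


logits = [[0.2, -1.3], [1.5, 0.4]]
labels = [1, 0]
analytic = logits_gradient(logits, labels)
numeric = [[finite_difference(logits, labels, i, j) for j in range(2)] for i in range(2)]
```

The remaining backward nodes (`%mm`, `%permute_*`, `%sum_*`) propagate this logits gradient through the two linear layers and the sigmoid to produce one gradient per parameter.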

From here we can lower the rest of the way to ExecuTorch:
```python
ep = to_edge(ep)

# After calling to_executorch, the weights themselves are also appended to the model outputs. This is to make
# some downstream passes like memory planning a little easier. A couple of hidden utility functions are also
# embedded in the model: __et_training_gradients_index_<method_name>,
# __et_training_parameters_index_<method_name>, and __et_training_fqn_<method_name>.
#
# These help us partition the huge list of model outputs into meaningful sections as well as assign names to each weight/gradient.
ep = ep.to_executorch()

with open("xor.pte", "wb") as file:
    ep.write_to_file(file)
```
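
As a rough illustration of how the `__et_training_*` helpers mentioned in the comments above might be used, the sketch below partitions a flat output list into loss, user outputs, gradients, and parameters, then pairs each gradient and parameter with a name. It is plain Python with placeholder strings standing in for tensors. The function `partition_outputs` is hypothetical, and the index values (2 and 6) and the name order are assumptions inferred from the joint graph shown earlier, not values queried from a real `.pte` file.

```python
# Hypothetical sketch: placeholder strings stand in for tensors.


def partition_outputs(outputs, gradients_index, parameters_index, fqns):
    """Split a flat training-method output list into named sections."""
    loss = outputs[0]
    user_outputs = outputs[1:gradients_index]
    gradients = outputs[gradients_index:parameters_index]
    parameters = outputs[parameters_index:]
    if not (len(gradients) == len(parameters) == len(fqns)):
        raise ValueError("gradients, parameters, and names must line up one-to-one")
    return {
        "loss": loss,
        "user_outputs": user_outputs,
        "gradients": dict(zip(fqns, gradients)),
        "parameters": dict(zip(fqns, parameters)),
    }


# Assumed layout for the XOR example: the six joint-graph outputs followed by
# the four weights appended during to_executorch().
fqns = ["net.linear.weight", "net.linear.bias", "net.linear2.weight", "net.linear2.bias"]
outputs = [
    "loss", "argmax",
    "grad_linear_w", "grad_linear_b", "grad_linear2_w", "grad_linear2_b",
    "linear_w", "linear_b", "linear2_w", "linear2_b",
]
sections = partition_outputs(outputs, gradients_index=2, parameters_index=6, fqns=fqns)
```

Keying gradients and parameters by the same names lets an optimizer look up the matching pair without depending on output order.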

### Run the model train script with CMake
After exporting the model for training, we can now try learning using CMake. We can build and use `train_xor`, which is a sample wrapper for the ExecuTorch Runtime, TrainingModule, and SGD optimizer. We begin by configuring the CMake build as follows:
```bash
# cd to the root of executorch repo
cd executorch

# Get a clean cmake-out directory
rm -rf cmake-out
mkdir cmake-out

# Configure cmake
cmake \
    -DCMAKE_INSTALL_PREFIX=cmake-out \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON \
    -DEXECUTORCH_ENABLE_LOGGING=ON \
    -DPYTHON_EXECUTABLE=python \
    -Bcmake-out .
```
Then you can build the runtime components with

```bash
cmake --build cmake-out -j9 --target install --config Release
```

Now you should be able to find the executable at `./cmake-out/extension/training/train_xor`. You can run the executable with the model you generated as follows:
```bash
./cmake-out/extension/training/train_xor --model_path=./xor.pte
```
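
For intuition about what an SGD optimizer does with the gradients on each iteration, here is a minimal sketch of a plain SGD update in Python. It is not the C++ optimizer from `optimizer/`. The function name `sgd_step`, the learning rate, and the list-based stand-ins for tensors are illustrative assumptions, and features such as momentum or weight decay are deliberately left out.

```python
# Illustrative sketch of vanilla SGD; lists of floats stand in for tensors.


def sgd_step(parameters, gradients, lr):
    """Apply param <- param - lr * grad in place for every named parameter."""
    for name, values in parameters.items():
        grad = gradients[name]
        if len(grad) != len(values):
            raise ValueError(f"shape mismatch for {name}")
        for i, g in enumerate(grad):
            values[i] -= lr * g


params = {"net.linear2.bias": [0.5, -0.25]}
grads = {"net.linear2.bias": [1.0, -2.0]}
sgd_step(params, grads, lr=0.25)
```

Updating in place mirrors the README's point that the weights are mutable during execution.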

## What is missing? / What is next?
A ton! ExecuTorch training is still quite experimental and under heavy active development. What's here currently is more of a technical preview.

`_export_forward_backward` is not very stable yet and may fail on more complicated model architectures, though we have verified that it works for LoRA with LLMs.

The ExecuTorch portable operator library does not yet have full coverage of the ops that might show up in backward graphs.

We don't yet have a way to serialize the newly trained weights natively in ExecuTorch (though you can convert them to ATen tensors using `extension/aten_util` and then serialize them using ATen APIs).

We plan to add a way to update models in place on-device (this will be needed for fine-tuning).

We are looking to integrate with many of the existing delegates/backends in ExecuTorch to enable accelerated training.

And so much more!

## Help & Improvements
If you have problems or questions, or have suggestions for ways to make
implementation and testing better, please reach out to the PyTorch Edge team or
create an issue on [GitHub](https://www.github.com/pytorch/executorch/issues).

extension/training/examples/XOR/export_model.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@
 from executorch.exir import to_edge

 from executorch.extension.training.examples.XOR.model import Net, TrainingNet
-from torch.export._trace import _export
+from torch.export import export
 from torch.export.experimental import _export_forward_backward


@@ -37,7 +37,7 @@ def main() -> None:

     # Captures the forward graph. The graph will look similar to the model definition now.
     # Will move to export_for_training soon which is the api planned to be supported in the long term.
-    ep = _export(net, (x, torch.ones(1, dtype=torch.int64)), pre_dispatch=True)
+    ep = export(net, (x, torch.ones(1, dtype=torch.int64)))
     # Captures the backward graph. The exported_program now contains the joint forward and backward graph.
     ep = _export_forward_backward(ep)
     # Lower the graph to edge dialect.

extension/training/examples/XOR/export_model_lib.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

extension/training/examples/XOR/targets.bzl

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def define_common_targets():

     runtime.python_library(
         name = "export_model_lib",
-        srcs = ["export_model_lib.py", "export_model.py"],
+        srcs = ["export_model.py"],
         visibility = [],
         deps = [
             ":model",
