 | 1 | +# ExecuTorch On-device Training  | 
 | 2 | + | 
 | 3 | +This subtree contains infrastructure to facilitate on-device training using ExecuTorch.  | 
 | 4 | +This feature is experimental and under heavy active development. All APIs are  | 
 | 5 | +subject to change, and many things may not work out of the box, or at all, in the  | 
 | 6 | +current state.  | 
 | 7 | + | 
 | 8 | +## Layout  | 
 | 9 | +- `examples/`: Example end-to-end flows from model definition to `optimizer.step()`.  | 
 | 10 | +- `module/`: Utility class that provides an improved UX when using ExecuTorch for training.  | 
 | 11 | +- `optimizer/`: C++ implementations of optimizers. Currently only SGD is implemented; Adam is planned.  | 
 | 12 | +- `test/`: Tests that cover multiple subdirs.  | 
 | 13 | + | 
 | 14 | +## Technical Bird's-Eye View  | 
 | 15 | + | 
 | 16 | +At a high level, ExecuTorch training follows a flow similar to inference, with a few extra steps.  | 
 | 17 | + | 
 | 18 | +Instead of relying on autograd at runtime to dynamically generate the backward graph and then walk it,  | 
 | 19 | +we capture the backward graph ahead of time. This keeps the on-device runtime leaner and gives  | 
 | 20 | +backends more direct control over more of the model execution. Currently the optimizer is not  | 
 | 21 | +captured, though this may change over time.  | 
 | 22 | + | 
 | 23 | +The loss function must be embedded inside the model definition, and the loss must be the first output.  | 
 | 24 | +The loss is used during capture to generate the backward graph.  | 
 | 25 | + | 
 | 26 | +Gradients become explicit graph outputs rather than hidden tensor state.  | 
 | 27 | + | 
 | 28 | +Since the weights now need to be mutable during execution, they are memory planned ahead of time and copied  | 
 | 29 | +from the `.pte` file into the `HierarchicalAllocator` arenas during `Method` initialization.  | 
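
Because the optimizer is not captured and gradients are plain graph outputs, the weight update is applied by the caller after each execution. The following is a minimal, framework-free sketch of a plain SGD step (`p <- p - lr * g`) over named parameters. The function name `sgd_step`, the use of flat Python lists, and the parameter name in the usage lines are illustrative assumptions; this is not the C++ optimizer in `optimizer/`.

```python
from typing import Dict, List


def sgd_step(
    params: Dict[str, List[float]],
    grads: Dict[str, List[float]],
    lr: float,
) -> None:
    """Apply p <- p - lr * g in place for every named parameter (hypothetical helper)."""
    for name, grad in grads.items():
        weights = params[name]
        if len(weights) != len(grad):
            raise ValueError(f"size mismatch for {name}")
        # Update in place, mirroring the fact that the weights are mutable.
        for i, g in enumerate(grad):
            weights[i] -= lr * g


# Illustrative values only; the key is a made-up fully qualified name.
params = {"net.linear.bias": [0.5, -0.25]}
grads = {"net.linear.bias": [1.0, -2.0]}
sgd_step(params, grads, lr=0.5)
```

The update mutates the parameter storage rather than returning new values, which is the same reason the weights have to live in mutable, memory-planned buffers.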
 | 30 | + | 
 | 31 | +Integration with backends/delegates is still a work in progress.  | 
 | 32 | + | 
 | 33 | + | 
 | 34 | +## End to End Example  | 
 | 35 | + | 
 | 36 | +To further understand the features of ExecuTorch training and how to use them,  | 
 | 37 | +consider the following end-to-end example, in which a neural network learns the XOR function.  | 
 | 38 | + | 
 | 39 | +### Lowering a joint-graph model to ExecuTorch  | 
 | 40 | + | 
 | 41 | +After following the [setting up ExecuTorch] guide, you can run  | 
 | 42 | + | 
 | 43 | +```bash  | 
 | 44 | +python3 extension/training/examples/XOR/export_model.py --outdir /tmp/foobar  | 
 | 45 | +```  | 
 | 46 | +to generate the model file. Below is a walkthrough of how that script works.  | 
 | 47 | + | 
 | 48 | +First, let's define our model.  | 
 | 49 | +```python  | 
 | 50 | +import torch  | 
 | 51 | +import torch.nn as nn  | 
 | 52 | +from executorch.exir import to_edge  | 
 | 53 | +from torch.export import export  | 
 | 54 | +from torch.export.experimental import _export_forward_backward  | 
 | 55 | +from torch.nn import functional as F  | 
 | 56 | + | 
 | 57 | +# Basic Net for XOR  | 
 | 58 | +class Net(nn.Module):  | 
 | 59 | +    def __init__(self):  | 
 | 60 | +        super().__init__()  | 
 | 61 | +        self.linear = nn.Linear(2, 10)  | 
 | 62 | +        self.linear2 = nn.Linear(10, 2)  | 
 | 63 | + | 
 | 64 | +    def forward(self, x):  | 
 | 65 | +        return self.linear2(F.sigmoid(self.linear(x)))  | 
 | 66 | +```  | 
 | 67 | + | 
 | 68 | +The first big difference from the normal ExecuTorch flow is that for training we must embed  | 
 | 69 | +the loss function into the model and return the loss as the first output.  | 
 | 70 | + | 
 | 71 | +We don't want to modify the original model definition, so we wrap it instead.  | 
 | 72 | + | 
 | 73 | +```python  | 
 | 74 | +class TrainingNet(nn.Module):  | 
 | 75 | +    def __init__(self, net):  | 
 | 76 | +        super().__init__()  | 
 | 77 | +        self.net = net  | 
 | 78 | +        self.loss = nn.CrossEntropyLoss()  | 
 | 79 | + | 
 | 80 | +    def forward(self, input, label):  | 
 | 81 | +        pred = self.net(input)  | 
 | 82 | +        return self.loss(pred, label), pred.detach().argmax(dim=1)  | 
 | 83 | +```  | 
 | 84 | + | 
 | 85 | +Now that we have our model we can lower it to ExecuTorch. To do that we just have to follow  | 
 | 86 | +a few simple steps.  | 
 | 87 | + | 
 | 88 | +```python  | 
 | 89 | +net = TrainingNet(Net())  | 
 | 90 | + | 
 | 91 | +# Create example inputs; only their shapes and dtypes matter.  | 
 | 92 | +input = torch.randn(1, 2)  | 
 | 93 | +label = torch.ones(1, dtype=torch.int64)  | 
 | 94 | + | 
 | 95 | +# Capture the forward graph. It will look similar to the model definition.  | 
 | 96 | +# This will move to export_for_training soon, which is the API planned for long-term support.  | 
 | 97 | +ep = export(net, (input, label))  | 
 | 98 | +```  | 
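
The example inputs above only fix what is captured: a `(1, 2)` float input and a `(1,)` integer class label. At training time, real samples with that layout are fed in. As a framework-free sketch, a labelled XOR sample could be generated as below. The helper name `xor_sample` and the use of nested Python lists in place of tensors are assumptions for illustration; this is not taken from the example scripts.

```python
import random
from typing import List, Tuple


def xor_sample(rng: random.Random) -> Tuple[List[List[float]], List[int]]:
    """Return one (input, label) pair shaped like the export-time example inputs."""
    a = rng.randint(0, 1)
    b = rng.randint(0, 1)
    # Input has shape (1, 2); label has shape (1,) and holds the class index a XOR b.
    return [[float(a), float(b)]], [a ^ b]


rng = random.Random(0)  # fixed seed so the samples are reproducible
samples = [xor_sample(rng) for _ in range(8)]
```

The label is a class index (0 or 1) rather than a one-hot vector, which is the form `nn.CrossEntropyLoss` expects for integer targets.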
 | 99 | + | 
 | 100 | +This is what the graph looks like after export:  | 
 | 101 | +```python  | 
 | 102 | +>>> print(ep.graph_module.graph)  | 
 | 103 | + | 
 | 104 | +graph():  | 
 | 105 | +    %p_net_linear_weight : [num_users=1] = placeholder[target=p_net_linear_weight]  | 
 | 106 | +    %p_net_linear_bias : [num_users=1] = placeholder[target=p_net_linear_bias]  | 
 | 107 | +    %p_net_linear2_weight : [num_users=1] = placeholder[target=p_net_linear2_weight]  | 
 | 108 | +    %p_net_linear2_bias : [num_users=1] = placeholder[target=p_net_linear2_bias]  | 
 | 109 | +    %input : [num_users=1] = placeholder[target=input]  | 
 | 110 | +    %label : [num_users=1] = placeholder[target=label]  | 
 | 111 | +    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%input, %p_net_linear_weight, %p_net_linear_bias), kwargs = {})  | 
 | 112 | +    %sigmoid : [num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%linear,), kwargs = {})  | 
 | 113 | +    %linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%sigmoid, %p_net_linear2_weight, %p_net_linear2_bias), kwargs = {})  | 
 | 114 | +    %cross_entropy_loss : [num_users=1] = call_function[target=torch.ops.aten.cross_entropy_loss.default](args = (%linear_1, %label), kwargs = {})  | 
 | 115 | +    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%linear_1,), kwargs = {})  | 
 | 116 | +    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%detach, 1), kwargs = {})  | 
 | 117 | +    return (cross_entropy_loss, argmax)  | 
 | 118 | +```  | 
 | 119 | + | 
 | 120 | +It should look very similar to our model's forward function. Now we need to capture the backward graph.  | 
 | 121 | + | 
 | 122 | +```python  | 
 | 123 | +ep = _export_forward_backward(ep)  | 
 | 124 | +```  | 
 | 125 | + | 
 | 126 | +The graph now looks like this:  | 
 | 127 | + | 
 | 128 | +```python  | 
 | 129 | +>>> print(ep.graph_module.graph)  | 
 | 130 | + | 
 | 131 | +graph():  | 
 | 132 | +    %p_net_linear_weight : [num_users=1] = placeholder[target=p_net_linear_weight]  | 
 | 133 | +    %p_net_linear_bias : [num_users=1] = placeholder[target=p_net_linear_bias]  | 
 | 134 | +    %p_net_linear2_weight : [num_users=1] = placeholder[target=p_net_linear2_weight]  | 
 | 135 | +    %p_net_linear2_bias : [num_users=1] = placeholder[target=p_net_linear2_bias]  | 
 | 136 | +    %input : [num_users=2] = placeholder[target=input]  | 
 | 137 | +    %label : [num_users=5] = placeholder[target=label]  | 
 | 138 | +    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%p_net_linear_weight, [1, 0]), kwargs = {})  | 
 | 139 | +    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%p_net_linear_bias, %input, %permute), kwargs = {})  | 
 | 140 | +    %sigmoid : [num_users=3] = call_function[target=torch.ops.aten.sigmoid.default](args = (%addmm,), kwargs = {})  | 
 | 141 | +    %alias : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%sigmoid,), kwargs = {})  | 
 | 142 | +    %alias_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias,), kwargs = {})  | 
 | 143 | +    %permute_1 : [num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%p_net_linear2_weight, [1, 0]), kwargs = {})  | 
 | 144 | +    %addmm_1 : [num_users=2] = call_function[target=torch.ops.aten.addmm.default](args = (%p_net_linear2_bias, %sigmoid, %permute_1), kwargs = {})  | 
 | 145 | +    %_log_softmax : [num_users=3] = call_function[target=torch.ops.aten._log_softmax.default](args = (%addmm_1, 1, False), kwargs = {})  | 
 | 146 | +    %alias_2 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%_log_softmax,), kwargs = {})  | 
 | 147 | +    %alias_3 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_2,), kwargs = {})  | 
 | 148 | +    %ne : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})  | 
 | 149 | +    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu})  | 
 | 150 | +    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne, %label, %scalar_tensor), kwargs = {})  | 
 | 151 | +    %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%where, 1), kwargs = {})  | 
 | 152 | +    %gather : [num_users=1] = call_function[target=torch.ops.aten.gather.default](args = (%_log_softmax, 1, %unsqueeze), kwargs = {})  | 
 | 153 | +    %squeeze : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dims](args = (%gather, [1]), kwargs = {})  | 
 | 154 | +    %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%squeeze,), kwargs = {})  | 
 | 155 | +    %ne_1 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})  | 
 | 156 | +    %scalar_tensor_1 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})  | 
 | 157 | +    %where_1 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_1, %neg, %scalar_tensor_1), kwargs = {})  | 
 | 158 | +    %ne_2 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%label, -100), kwargs = {})  | 
 | 159 | +    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%ne_2, []), kwargs = {})  | 
 | 160 | +    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%sum_1,), kwargs = {dtype: torch.float32, device: cpu})  | 
 | 161 | +    %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%where_1, []), kwargs = {})  | 
 | 162 | +    %div : [num_users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_2, %_to_copy), kwargs = {})  | 
 | 163 | +    %alias_4 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%addmm_1,), kwargs = {})  | 
 | 164 | +    %alias_5 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_4,), kwargs = {})  | 
 | 165 | +    %alias_6 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_5,), kwargs = {})  | 
 | 166 | +    %argmax : [num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%alias_6, 1), kwargs = {})  | 
 | 167 | +    %full_like : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%div, 1), kwargs = {pin_memory: False, memory_format: torch.preserve_format})  | 
 | 168 | +    %div_1 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%full_like, %_to_copy), kwargs = {})  | 
 | 169 | +    %unsqueeze_1 : [num_users=3] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%label, 1), kwargs = {})  | 
 | 170 | +    %ne_3 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_1, -100), kwargs = {})  | 
 | 171 | +    %scalar_tensor_2 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu})  | 
 | 172 | +    %where_2 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %unsqueeze_1, %scalar_tensor_2), kwargs = {})  | 
 | 173 | +    %full_like_1 : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%_log_softmax, 0), kwargs = {pin_memory: False, memory_format: torch.preserve_format})  | 
 | 174 | +    %scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.value](args = (%full_like_1, 1, %where_2, -1.0), kwargs = {})  | 
 | 175 | +    %ne_4 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_1, -100), kwargs = {})  | 
 | 176 | +    %scalar_tensor_3 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (0,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})  | 
 | 177 | +    %where_3 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_4, %div_1, %scalar_tensor_3), kwargs = {})  | 
 | 178 | +    %mul : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%scatter, %where_3), kwargs = {})  | 
 | 179 | +    %alias_7 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_3,), kwargs = {})  | 
 | 180 | +    %alias_8 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_7,), kwargs = {})  | 
 | 181 | +    %exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%alias_8,), kwargs = {})  | 
 | 182 | +    %sum_3 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [1], True), kwargs = {})  | 
 | 183 | +    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%exp, %sum_3), kwargs = {})  | 
 | 184 | +    %sub : [num_users=3] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul, %mul_1), kwargs = {})  | 
 | 185 | +    %permute_2 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_1, [1, 0]), kwargs = {})  | 
 | 186 | +    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%sub, %permute_2), kwargs = {})  | 
 | 187 | +    %permute_3 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sub, [1, 0]), kwargs = {})  | 
 | 188 | +    %mm_1 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%permute_3, %sigmoid), kwargs = {})  | 
 | 189 | +    %permute_4 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mm_1, [1, 0]), kwargs = {})  | 
 | 190 | +    %sum_4 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%sub, [0], True), kwargs = {})  | 
 | 191 | +    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_4, [2]), kwargs = {})  | 
 | 192 | +    %permute_5 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_4, [1, 0]), kwargs = {})  | 
 | 193 | +    %alias_9 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%alias_1,), kwargs = {})  | 
 | 194 | +    %alias_10 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%alias_9,), kwargs = {})  | 
 | 195 | +    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (1, %alias_10), kwargs = {})  | 
 | 196 | +    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%alias_10, %sub_1), kwargs = {})  | 
 | 197 | +    %mul_3 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mm, %mul_2), kwargs = {})  | 
 | 198 | +    %permute_6 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mul_3, [1, 0]), kwargs = {})  | 
 | 199 | +    %mm_2 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%permute_6, %input), kwargs = {})  | 
 | 200 | +    %permute_7 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%mm_2, [1, 0]), kwargs = {})  | 
 | 201 | +    %sum_5 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_3, [0], True), kwargs = {})  | 
 | 202 | +    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%sum_5, [10]), kwargs = {})  | 
 | 203 | +    %permute_8 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute_7, [1, 0]), kwargs = {})  | 
 | 204 | +    return (div, argmax, permute_8, view_1, permute_5, view)  | 
 | 205 | +```  | 
 | 206 | + | 
 | 207 | +It's a lot bigger! We call this the "joint graph" or the "forward-backward graph". We have explicitly captured the backward graph  | 
 | 208 | +alongside the forward graph, and the model now returns `[loss, any other user outputs, gradients]`.  | 
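
Reading the joint graph, the nodes from `_log_softmax` through `div` compute the mean cross-entropy loss, and the `scatter`, `exp`, `mul`, and `sub` nodes appear to compute its gradient with respect to the logits, `(softmax(logits) - one_hot(label)) / N`. The sketch below reproduces that arithmetic in plain Python for illustration. The function name is hypothetical, and the `ignore_index` (`-100`) masking that the graph performs with `ne` and `where` is deliberately left out.

```python
import math
from typing import List, Tuple


def cross_entropy_with_grad(
    logits: List[List[float]], labels: List[int]
) -> Tuple[float, List[List[float]]]:
    """Mean cross-entropy loss and its gradient w.r.t. the logits (no ignore_index)."""
    n = len(logits)
    loss = 0.0
    grad: List[List[float]] = []
    for row, label in zip(logits, labels):
        # Numerically stable log-softmax over the class dimension.
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        log_probs = [v - log_z for v in row]
        # Negative log-likelihood of the true class, averaged over the batch.
        loss -= log_probs[label] / n
        # d(loss)/d(logit_j) = (softmax_j - 1[j == label]) / n
        grad.append(
            [(math.exp(lp) - (1.0 if j == label else 0.0)) / n
             for j, lp in enumerate(log_probs)]
        )
    return loss, grad


# Two equal logits give a uniform prediction, so the loss is ln(2).
loss, grad = cross_entropy_with_grad([[0.0, 0.0]], [1])
```

The remaining `mm`, `permute`, `sum`, and `view` nodes then propagate this gradient back through the two linear layers and the sigmoid to produce one gradient per weight and bias.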
 | 209 | + | 
 | 210 | +From here we can lower the rest of the way to ExecuTorch:  | 
 | 211 | +```python  | 
 | 212 | +ep = to_edge(ep)  | 
 | 213 | + | 
 | 214 | +# After calling to_executorch, the weights themselves are also appended to the model outputs. This makes  | 
 | 215 | +# some downstream passes, such as memory planning, a little easier. A few hidden utility methods are also  | 
 | 216 | +# embedded in the model: __et_training_gradients_index_<method_name>,  | 
 | 217 | +# __et_training_parameters_index_<method_name>, and __et_training_fqn_<method_name>.  | 
 | 218 | +#  | 
 | 219 | +# These help partition the long list of model outputs into meaningful sections and assign a name to each weight/gradient.  | 
 | 220 | +ep = ep.to_executorch()  | 
 | 221 | + | 
 | 222 | +with open("xor.pte", "wb") as file:  | 
 | 223 | +    ep.write_to_file(file)  | 
 | 224 | +```  | 
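
To make the role of those utility methods concrete, here is a framework-free sketch of how the flat output list could be split into loss, user outputs, gradients, and parameters. It assumes the two index methods return the offsets at which the gradient and parameter sections start, and that the FQN method returns one name per weight in matching order. The function `partition_outputs` and the placeholder strings are hypothetical and only illustrate the layout for the XOR model.

```python
from typing import Any, Dict, List, Sequence, Tuple


def partition_outputs(
    outputs: Sequence[Any],
    grad_start: int,
    param_start: int,
    fqns: Sequence[str],
) -> Tuple[Any, List[Any], Dict[str, Any], Dict[str, Any]]:
    """Split [loss, user outputs..., gradients..., parameters...] into named parts."""
    if not (1 <= grad_start <= param_start <= len(outputs)):
        raise ValueError("section offsets are out of order or out of range")
    loss = outputs[0]
    user_outputs = list(outputs[1:grad_start])
    grads = list(outputs[grad_start:param_start])
    params = list(outputs[param_start:])
    if len(grads) != len(fqns) or len(params) != len(fqns):
        raise ValueError("expected one gradient and one parameter per name")
    return loss, user_outputs, dict(zip(fqns, grads)), dict(zip(fqns, params))


# Placeholder strings stand in for tensors; the order follows the XOR joint graph.
outputs = ["loss", "pred", "g_w1", "g_b1", "g_w2", "g_b2", "w1", "b1", "w2", "b2"]
fqns = ["net.linear.weight", "net.linear.bias", "net.linear2.weight", "net.linear2.bias"]
loss, user_outputs, grads, params = partition_outputs(outputs, 2, 6, fqns)
```

With the gradients and parameters keyed by the same names, an optimizer can pair each weight with its gradient without knowing anything about the graph itself.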
 | 225 | + | 
 | 226 | +### Run the model training script with CMake  | 
 | 227 | +After exporting the model for training, we can train it with a binary built using CMake. The `train_xor` target is a sample wrapper around the ExecuTorch runtime, `TrainingModule`, and the SGD optimizer. First, configure the CMake build:  | 
 | 228 | +```bash  | 
 | 229 | +# cd to the root of executorch repo  | 
 | 230 | +cd executorch  | 
 | 231 | + | 
 | 232 | +# Get a clean cmake-out directory  | 
 | 233 | +rm -rf cmake-out  | 
 | 234 | +mkdir cmake-out  | 
 | 235 | + | 
 | 236 | +# Configure cmake  | 
 | 237 | +cmake \  | 
 | 238 | +    -DCMAKE_INSTALL_PREFIX=cmake-out \  | 
 | 239 | +    -DCMAKE_BUILD_TYPE=Release \  | 
 | 240 | +    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \  | 
 | 241 | +    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \  | 
 | 242 | +    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \  | 
 | 243 | +    -DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON \  | 
 | 244 | +    -DEXECUTORCH_ENABLE_LOGGING=ON \  | 
 | 245 | +    -DPYTHON_EXECUTABLE=python \  | 
 | 246 | +    -Bcmake-out .  | 
 | 247 | +```  | 
 | 248 | +Then you can build the runtime components with  | 
 | 249 | + | 
 | 250 | +```bash  | 
 | 251 | +cmake --build cmake-out -j9 --target install --config Release  | 
 | 252 | +```  | 
 | 253 | + | 
 | 254 | +You should now find the built executable at `./cmake-out/extension/training/train_xor`. Run it with the model you generated:  | 
 | 255 | +```bash  | 
 | 256 | +./cmake-out/extension/training/train_xor --model_path=./xor.pte  | 
 | 257 | +```  | 
 | 258 | + | 
 | 259 | +## What is missing? What is next?  | 
 | 260 | +A ton! ExecuTorch training is still quite experimental and under heavy active development. What's here currently is more of a technical preview.  | 
 | 261 | + | 
 | 262 | +`_export_forward_backward` is not very stable yet and may fail on more complicated model architectures, though we have verified that it works for LoRA with LLMs.  | 
 | 263 | + | 
 | 264 | +The ExecuTorch portable operator library does not yet have full coverage of the ops that might show up in backward graphs.  | 
 | 265 | + | 
 | 266 | +We don't yet have a way to serialize the newly trained weights natively in ExecuTorch (though you can convert them to ATen tensors using `extension/aten_util` and then serialize them using ATen APIs).  | 
 | 267 | + | 
 | 268 | +We plan to add a way to update models in place on-device (this will be needed for fine-tuning).  | 
 | 269 | + | 
 | 270 | +We are looking to integrate with many of the existing delegates/backends in ExecuTorch to enable accelerated training.  | 
 | 271 | + | 
 | 272 | +And so much more!  | 
 | 273 | + | 
 | 274 | +## Help & Improvements  | 
 | 275 | +If you have problems or questions, or have suggestions for ways to make  | 
 | 276 | +implementation and testing better, please reach out to the PyTorch Edge team or  | 
 | 277 | +create an issue on [GitHub](https://www.github.com/pytorch/executorch/issues).  | 