Commit 7a54213

JacobSzwejbka authored and facebook-github-bot committed
SGD and TrainingModule in Python (#5847)
Summary: Adds short-term pybindings and a training module API. The optimizer bindings should be fairly stable in practice, but the training module is not: it wraps the existing ExecuTorch pybindings, which are under active development. A future PR will update the training bindings to match the long-term inference bindings.

Differential Revision: D63650449
1 parent 98c5efa commit 7a54213
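
A rough usage sketch of the new surface area, assuming a program exported with a joint forward/backward graph whose trainable method is named "forward"; the file name "model.pte", the input shapes, and the hyperparameters are placeholders, not part of this commit:

import torch

from executorch.extension.training import (
    _load_for_executorch_for_training,
    get_sgd_optimizer,
)

# Load a .pte file that contains a joint forward/backward graph (hypothetical path).
tm = _load_for_executorch_for_training("model.pte")

# Run one forward + backward pass; only the user-visible outputs are returned.
x = torch.randn(1, 4)      # placeholder input
label = torch.tensor([0])  # placeholder target
user_outputs = tm.forward_backward("forward", (x, label))

# Build an SGD optimizer over the parameters captured by forward_backward,
# then apply the gradients from the same pass.
opt = get_sgd_optimizer(tm.named_parameters(), lr=0.1, momentum=0.9)
opt.step(tm.named_gradients())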

File tree

11 files changed: +446 -3 lines changed


extension/pybindings/pybindings.pyi

Lines changed: 7 additions & 3 deletions
@@ -21,11 +21,15 @@ class ExecuTorchModule:
     """
 
     # pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
-    def __call__(self, inputs: Any) -> List[Any]: ...
+    def __call__(self, inputs: Any, clone_outputs: bool = True) -> List[Any]: ...
     # pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
-    def run_method(self, method_name: str, inputs: Sequence[Any]) -> List[Any]: ...
+    def run_method(
+        self, method_name: str, inputs: Sequence[Any], clone_outputs: bool = True  # pyre-ignore[2]: "Any" in parameter type annotations.
+    ) -> List[Any]: ...
     # pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
-    def forward(self, inputs: Sequence[Any]) -> List[Any]: ...
+    def forward(
+        self, inputs: Sequence[Any], clone_outputs: bool = True  # pyre-ignore[2]: "Any" in parameter type annotations.
+    ) -> List[Any]: ...
     # pyre-ignore[3]: "Any" in return type annotations.
     def plan_execute(self) -> List[Any]: ...
     # Bundled program methods.

extension/training/TARGETS

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()

python_library(
    name = "lib",
    srcs = [
        "__init__.py",
    ],
    deps = [
        "//executorch/extension/training/pybindings:_training_lib",
        "//executorch/extension/training/pybindings:_training_module",
    ],
)

extension/training/__init__.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.extension.training.pybindings._training_lib import get_sgd_optimizer

from executorch.extension.training.pybindings._training_module import (
    _load_for_executorch_for_training,
    _load_for_executorch_for_training_from_buffer,
    TrainingModule,
)

__all__ = [
    "get_sgd_optimizer",
    "TrainingModule",
    "_load_for_executorch_for_training_from_buffer",
    "_load_for_executorch_for_training",
]

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()

runtime.cxx_python_extension(
    name = "_training_lib",
    srcs = [
        "_training_lib.cpp",
    ],
    base_module = "executorch.extension.training.pybindings",
    # types = ["_training_lib.pyi"],
    visibility = ["//executorch/extension/training/..."],
    deps = [
        "//executorch/extension/aten_util:aten_bridge",
        "//executorch/extension/training/optimizer:sgd",
    ],
    external_deps = [
        "pybind11",
        "libtorch_python",
    ],
)

runtime.python_library(
    name = "_training_module",
    srcs = [
        "_training_module.py",
    ],
    base_module = "executorch.extension.training.pybindings",
    visibility = ["//executorch/extension/training/..."],
    deps = [
        "//caffe2:torch",
        "//executorch/extension/pybindings:portable_lib",
    ],
)

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include <stack>

#include <ATen/Tensor.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <torch/csrc/utils/pybind.h>
#include "executorch/extension/tensor/tensor.h"
#include "executorch/extension/training/module/training_module.h"
#include "executorch/extension/training/optimizer/sgd.h"
#ifndef USE_ATEN_LIB
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <executorch/extension/aten_util/aten_bridge.h>
#endif

namespace py = pybind11;

namespace executorch {
namespace extension {
namespace training {

namespace {

struct PySGD final {
  explicit PySGD(
      const py::dict& named_params,
      double lr,
      double momentum,
      double dampening,
      double weight_decay,
      bool nesterov)
      : sgd_(nullptr),
        fqns_()
#ifndef USE_ATEN_LIB
        ,
        params_()
#endif
  {
    std::map<exec_aten::string_view, exec_aten::Tensor> cpp_inputs;
    auto py_named_params =
        py::cast<std::unordered_map<std::string, at::Tensor>>(named_params);
    const auto params_size = py::len(named_params);
    fqns_ = std::vector<std::string>();
    fqns_.reserve(params_size);

    for (auto pair : py_named_params) {
      fqns_.push_back(pair.first);
      exec_aten::string_view v{fqns_.back().c_str(), pair.first.size()};
#ifndef USE_ATEN_LIB
      // convert at::Tensor to torch::executor::Tensor
      params_.emplace_back(alias_tensor_ptr_to_attensor(pair.second));
      cpp_inputs.insert({v, *params_.back()});
#else
      cpp_inputs.insert({v, pair.second});
#endif
    }
    sgd_ = std::make_unique<optimizer::SGD>(
        cpp_inputs,
        extension::training::optimizer::SGDOptions(
            lr, momentum, dampening, weight_decay, nesterov));
  }

  PySGD(const PySGD&) = delete;
  PySGD& operator=(const PySGD&) = delete;
  PySGD(PySGD&&) = default;
  PySGD& operator=(PySGD&&) = default;

  void step(const py::dict& py_dict) {
    auto py_named_gradients =
        py::cast<std::unordered_map<std::string, at::Tensor>>(py_dict);
    const auto inputs_size = py::len(py_dict);
    std::map<exec_aten::string_view, exec_aten::Tensor> cpp_inputs;

    std::vector<std::string> fqn;
#ifndef USE_ATEN_LIB
    std::vector<TensorPtr> et_tensors;
#endif

    // Convert python objects into cpp.
    for (const auto& pair : py_named_gradients) {
      fqn.push_back(pair.first);
      auto at_tensor = pair.second;
      // alias_etensor_to_attensor will assert on this later, so to better
      // propagate up to python we check early and throw an exception.
      if (!at_tensor.is_contiguous()) {
        auto error_msg = "Gradient is not contiguous.";
        throw std::runtime_error(error_msg);
      }
#ifndef USE_ATEN_LIB
      // convert at::Tensor to torch::executor::Tensor
      auto temp = alias_tensor_ptr_to_attensor(at_tensor);
      et_tensors.push_back(temp);
      cpp_inputs.insert({pair.first.c_str(), *et_tensors.back()});
#else
      cpp_inputs.insert({pair.first.c_str(), at_tensor});
#endif
    }

    auto err = sgd_->step(cpp_inputs);
    if (err != runtime::Error::Ok) {
      throw std::runtime_error("SGD step failed");
    }
  }

 private:
  std::unique_ptr<optimizer::SGD> sgd_ = nullptr;
  std::vector<std::string> fqns_;

#ifndef USE_ATEN_LIB // Portable mode
  std::vector<TensorPtr> params_;
#endif
};

static std::unique_ptr<PySGD> get(
    const py::dict& named_params,
    double lr,
    double momentum = 0,
    double dampening = 0,
    double weight_decay = 0,
    bool nesterov = false) {
  return std::make_unique<PySGD>(
      named_params, lr, momentum, dampening, weight_decay, nesterov);
}

} // namespace

PYBIND11_MODULE(_training_lib, m) {
  m.def(
      "get_sgd_optimizer",
      &get,
      py::arg("named_params"),
      py::arg("lr") = 0.1,
      py::arg("momentum") = 0.0,
      py::arg("dampening") = 0.0,
      py::arg("weight_decay") = 0.0,
      py::arg("nesterov") = false);
  py::class_<PySGD>(m, "ExecuTorchSGD").def("step", &PySGD::step);
}

} // namespace training
} // namespace extension
} // namespace executorch

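One behavior worth calling out in the binding above: step() rejects non-contiguous gradients up front and raises a Python RuntimeError ("Gradient is not contiguous.") rather than letting the tensor conversion assert deeper in the stack. A small, hedged workaround sketch on the Python side, assuming opt is the ExecuTorchSGD returned by get_sgd_optimizer and grads is a dict mapping names to torch.Tensor gradients:

# Gradients produced through transpose/permute views are not contiguous;
# .contiguous() is a no-op for tensors that already are, so this is cheap
# in the common case. SGD only reads the gradients, so copying them is safe.
opt.step({name: g.contiguous() for name, g in grads.items()})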
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple

from executorch.exir._warnings import experimental
from torch import Tensor

@experimental("This API is experimental and subject to change without notice.")
class ExecuTorchSGD:
    """SGD Optimizer.

    .. warning::

        This API is experimental and subject to change without notice.
    """

    def step(self, named_gradients: Dict[str, Tensor]) -> None:
        """Take a step in the direction of the gradients."""
        ...

@experimental("This API is experimental and subject to change without notice.")
def get_sgd_optimizer(
    named_parameters: Dict[str, Tensor],
    lr: float,
    momentum: float = 0,
    dampening: float = 0,
    weight_decay: float = 0,
    nesterov: bool = False,
) -> ExecuTorchSGD:
    """Creates an SGD optimizer that operates on the passed-in named_parameters according to the specified hyperparameters.

    .. warning::

        This API is experimental and subject to change without notice.
    ...
    """
    ...

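To illustrate the argument shapes the stub above declares (a dict of fully qualified names to tensors plus the usual SGD hyperparameters), here is a hedged sketch that uses an ordinary torch module purely as a source of named tensors; in the real training flow the dict would normally come from TrainingModule.named_parameters():

import torch
from executorch.extension.training import get_sgd_optimizer

net = torch.nn.Linear(4, 2)                  # stand-in parameter source
named_params = dict(net.named_parameters())

# Only lr is required; the rest fall back to the defaults shown in the stub.
opt = get_sgd_optimizer(named_params, lr=0.01, momentum=0.9, weight_decay=1e-4)

# step() takes gradients keyed by the same fully qualified names.
named_grads = {n: torch.zeros_like(p) for n, p in named_params.items()}
opt.step(named_grads)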
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Sequence

from executorch.exir._warnings import experimental

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch,
    _load_for_executorch_from_buffer,
    ExecuTorchModule,
)
from torch import Tensor


@experimental("This API is experimental and subject to change without notice.")
class TrainingModule:
    def __init__(self, module: ExecuTorchModule):
        self.model = module

        self.gradients_method_prefix = "__et_training_gradients_index_"
        self.parameters_method_prefix = "__et_training_parameters_index_"
        self.fqn_method_prefix = "__et_training_fqn_"

        self.named_grads = None
        self.named_params = None

    def forward_backward(self, method_name: str, inputs: Sequence[Any]) -> List[Any]:
        # The default ET model returns a large list of outputs that can logically be
        # separated into [user outputs, gradients, parameters]. We can use these
        # metadata methods to slice the list into the correct parts.
        grad_start_idx = self.model.run_method(
            self.gradients_method_prefix + method_name, ()
        )[0]
        params_start_idx = self.model.run_method(
            self.parameters_method_prefix + method_name, ()
        )[0]

        full_outputs = self.model.run_method(method_name, inputs)

        user_outs = full_outputs[:grad_start_idx]
        grads = full_outputs[grad_start_idx:params_start_idx]
        params = full_outputs[params_start_idx:]

        fqn = self.model.run_method(
            self.fqn_method_prefix + method_name, (), clone_outputs=False
        )

        self.named_grads = dict(zip(fqn, grads))
        if self.named_params is None:
            self.named_params = dict(zip(fqn, params))

        return user_outs

    def named_gradients(self) -> Dict[str, Tensor]:
        if self.named_grads is None:
            raise RuntimeError("Must call forward_backward before named_grads")
        return self.named_grads

    def named_parameters(self) -> Dict[str, Tensor]:
        if self.named_grads is None:
            raise RuntimeError(
                "Must call forward_backward before named_params. This will be fixed in a later version"
            )
        return self.named_params


@experimental("This API is experimental and subject to change without notice.")
def _load_for_executorch_for_training(path: str) -> TrainingModule:
    et_module = _load_for_executorch(path)
    return TrainingModule(et_module)


@experimental("This API is experimental and subject to change without notice.")
def _load_for_executorch_for_training_from_buffer(
    buffer: bytes,
) -> TrainingModule:
    et_module = _load_for_executorch_from_buffer(buffer)
    return TrainingModule(et_module)

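Note the ordering constraint encoded in the RuntimeErrors above: named_gradients() and named_parameters() are only populated by the first forward_backward() call, because the FQNs, gradients, and parameters are all sliced out of that run's outputs. A hedged sketch of the resulting call order, with pte_buffer, x, and y as placeholders:

tm = _load_for_executorch_for_training_from_buffer(pte_buffer)
# Calling tm.named_parameters() here would raise RuntimeError: nothing has run yet.
outs = tm.forward_backward("forward", (x, y))
params = tm.named_parameters()  # populated from the first run
grads = tm.named_gradients()    # refreshed on every forward_backward call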
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """

    pass
