Skip to content

Commit 7d8cb1b

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
SGD and TrainingModule in Python (pytorch#5847)
Summary: Add short term pybindings and training module api. The optimizer bindings are probably fairly stable in practice, but the training module is not. Training module is a wrapper around the existing ET pybindings which are under active development. A future PR will update the training ones to match the long term inference bindings. Reviewed By: dvorjackz Differential Revision: D63650449
1 parent cbfdf78 commit 7d8cb1b

File tree

12 files changed

+454
-4
lines changed

12 files changed

+454
-4
lines changed

examples/llm_pte_finetuning/runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def main() -> None:
9898
# for us to update with the gradients in-place.
9999
# See https://github.com/pytorch/executorch/blob/main/extension/pybindings/pybindings.cpp#L736
100100
# for more info.
101-
out = et_mod.forward((tokens, labels), clone_outputs=False) # pyre-ignore
101+
out = et_mod.forward((tokens, labels), clone_outputs=False)
102102

103103
loss = out[0]
104104
losses.append(loss.item())

extension/pybindings/pybindings.pyi

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,20 @@ class ExecuTorchModule:
3333
"""
3434

3535
# pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
36-
def __call__(self, inputs: Any) -> List[Any]: ...
36+
def __call__(self, inputs: Any, clone_outputs: bool = True) -> List[Any]: ...
3737
# pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
38-
def run_method(self, method_name: str, inputs: Sequence[Any]) -> List[Any]: ...
38+
def run_method(
39+
self,
40+
method_name: str,
41+
inputs: Sequence[Any], # pyre-ignore[2]: "Any" in parameter type annotations.
42+
clone_outputs: bool = True,
43+
) -> List[Any]: ...
3944
# pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
40-
def forward(self, inputs: Sequence[Any]) -> List[Any]: ...
45+
def forward(
46+
self,
47+
inputs: Sequence[Any], # pyre-ignore[2]: "Any" in parameter type annotations.
48+
clone_outputs: bool = True,
49+
) -> List[Any]: ...
4150
# pyre-ignore[3]: "Any" in return type annotations.
4251
def plan_execute(self) -> List[Any]: ...
4352
# Bundled program methods.

extension/training/TARGETS

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
5+
load(":targets.bzl", "define_common_targets")
6+
7+
oncall("executorch")
8+
9+
define_common_targets()
10+
11+
python_library(
12+
name = "lib",
13+
srcs = [
14+
"__init__.py",
15+
],
16+
deps = [
17+
"//executorch/extension/training/pybindings:_training_lib",
18+
"//executorch/extension/training/pybindings:_training_module",
19+
],
20+
)

extension/training/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from executorch.extension.training.pybindings._training_lib import get_sgd_optimizer
8+
9+
from executorch.extension.training.pybindings._training_module import (
10+
_load_for_executorch_for_training,
11+
_load_for_executorch_for_training_from_buffer,
12+
TrainingModule,
13+
)
14+
15+
__all__ = [
16+
"get_sgd_optimizer",
17+
"TrainingModule",
18+
"_load_for_executorch_for_training_from_buffer",
19+
"_load_for_executorch_for_training",
20+
]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
5+
load(":targets.bzl", "define_common_targets")
6+
7+
oncall("executorch")
8+
9+
define_common_targets()
10+
11+
runtime.cxx_python_extension(
12+
name = "_training_lib",
13+
srcs = [
14+
"_training_lib.cpp",
15+
],
16+
base_module = "executorch.extension.training.pybindings",
17+
types = ["_training_lib.pyi"],
18+
visibility = ["//executorch/extension/training/..."],
19+
deps = [
20+
"//executorch/extension/aten_util:aten_bridge",
21+
"//executorch/extension/training/optimizer:sgd",
22+
],
23+
external_deps = [
24+
"pybind11",
25+
"libtorch_python",
26+
],
27+
)
28+
29+
runtime.python_library(
30+
name = "_training_module",
31+
srcs = [
32+
"_training_module.py",
33+
],
34+
base_module = "executorch.extension.training.pybindings",
35+
visibility = ["//executorch/extension/training/..."],
36+
deps = [
37+
"//caffe2:torch",
38+
"//executorch/extension/pybindings:portable_lib",
39+
],
40+
)
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <pybind11/pybind11.h>
10+
#include <pybind11/stl.h>
11+
#include <memory>
12+
13+
#include <ATen/Tensor.h>
14+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
15+
#include <torch/csrc/utils/pybind.h>
16+
#include "executorch/extension/tensor/tensor.h"
17+
#include "executorch/extension/training/optimizer/sgd.h"
18+
#ifndef USE_ATEN_LIB
19+
#include <executorch/extension/aten_util/aten_bridge.h>
20+
#endif
21+
22+
namespace py = pybind11;
23+
24+
namespace executorch {
25+
namespace extension {
26+
namespace training {
27+
28+
namespace {
29+
30+
struct PySGD final {
31+
explicit PySGD(
32+
const py::dict& named_params,
33+
double lr,
34+
double momentum,
35+
double dampening,
36+
double weight_decay,
37+
bool nesterov)
38+
: sgd_(nullptr),
39+
fqns_()
40+
#ifndef USE_ATEN_LIB
41+
,
42+
params_()
43+
#endif
44+
{
45+
std::map<exec_aten::string_view, exec_aten::Tensor> cpp_inputs;
46+
auto py_named_params =
47+
py::cast<std::unordered_map<std::string, at::Tensor>>(named_params);
48+
const auto params_size = py::len(named_params);
49+
fqns_ = std::vector<std::string>();
50+
fqns_.reserve(params_size);
51+
52+
for (auto pair : py_named_params) {
53+
fqns_.push_back(pair.first);
54+
exec_aten::string_view v{fqns_.back().c_str(), pair.first.size()};
55+
#ifndef USE_ATEN_LIB
56+
// convert at::Tensor to torch::executor::Tensor
57+
params_.emplace_back(alias_tensor_ptr_to_attensor(pair.second));
58+
cpp_inputs.insert({v, *params_.back()});
59+
#else
60+
cpp_inputs.insert({v, pair.second});
61+
#endif
62+
}
63+
sgd_ = std::make_unique<optimizer::SGD>(
64+
cpp_inputs,
65+
extension::training::optimizer::SGDOptions(
66+
lr, momentum, dampening, weight_decay, nesterov));
67+
}
68+
69+
// Not needed for now, so just delete.
70+
PySGD(const PySGD&) = delete;
71+
PySGD& operator=(const PySGD&) = delete;
72+
PySGD(PySGD&&) = delete;
73+
PySGD& operator=(PySGD&&) = delete;
74+
75+
void step(const py::dict& py_dict) {
76+
auto py_named_gradients =
77+
py::cast<std::unordered_map<std::string, at::Tensor>>(py_dict);
78+
const auto inputs_size = py::len(py_dict);
79+
std::map<exec_aten::string_view, exec_aten::Tensor> cpp_inputs;
80+
81+
std::vector<std::string> fqn;
82+
#ifndef USE_ATEN_LIB
83+
std::vector<TensorPtr> et_tensors;
84+
#endif
85+
86+
// Convert python objects into cpp.
87+
for (const auto& pair : py_named_gradients) {
88+
fqn.push_back(pair.first);
89+
auto at_tensor = pair.second;
90+
// alias_etensor_to_attensor will assert on this later, so to better
91+
// propogate up to python we check early and throw an exception.
92+
if (!at_tensor.is_contiguous()) {
93+
auto error_msg = "Gradient is not contiguous.";
94+
throw std::runtime_error(error_msg);
95+
}
96+
#ifndef USE_ATEN_LIB
97+
// convert at::Tensor to torch::executor::Tensor
98+
auto temp = alias_tensor_ptr_to_attensor(at_tensor);
99+
et_tensors.push_back(temp);
100+
cpp_inputs.insert({pair.first.c_str(), *et_tensors.back()});
101+
#else
102+
cpp_inputs.insert({pair.first.c_str(), at_tensor});
103+
#endif
104+
}
105+
106+
auto err = sgd_->step(cpp_inputs);
107+
if (err != runtime::Error::Ok) {
108+
throw std::runtime_error("SGD step failed");
109+
}
110+
}
111+
112+
private:
113+
// TODO(jakeszwe): Write an optimizer interface and use it here instead of SGD
114+
// specifically.
115+
std::unique_ptr<optimizer::SGD> sgd_ = nullptr;
116+
std::vector<std::string> fqns_;
117+
118+
#ifndef USE_ATEN_LIB // Portable mode
119+
std::vector<TensorPtr> params_;
120+
#endif
121+
;
122+
};
123+
124+
static std::unique_ptr<PySGD> get_sgd_optimizer(
125+
const py::dict& named_params,
126+
double lr,
127+
double momentum = 0,
128+
double dampening = 0,
129+
double weight_decay = 0,
130+
bool nesterov = false) {
131+
return std::make_unique<PySGD>(
132+
named_params, lr, momentum, dampening, weight_decay, nesterov);
133+
}
134+
135+
} // namespace
136+
137+
PYBIND11_MODULE(_training_lib, m) {
138+
m.def(
139+
"get_sgd_optimizer",
140+
&get_sgd_optimizer,
141+
py::arg("named_params"),
142+
py::arg("lr") = 0.1,
143+
py::arg("momentum") = 0.0,
144+
py::arg("dampening") = 0.0,
145+
py::arg("weight_decay") = 0.0,
146+
py::arg("nesterov") = false);
147+
py::class_<PySGD>(m, "ExecuTorchSGD").def("step", &PySGD::step);
148+
}
149+
150+
} // namespace training
151+
} // namespace extension
152+
} // namespace executorch
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
from typing import Any, Dict, List, Optional, Sequence, Tuple
10+
11+
from executorch.exir._warnings import experimental
12+
from torch import Tensor
13+
14+
@experimental("This API is experimental and subject to change without notice.")
15+
class ExecuTorchSGD:
16+
"""SGD Optimizer.
17+
18+
.. warning::
19+
20+
This API is experimental and subject to change without notice.
21+
"""
22+
23+
def step(self, named_gradients: Dict[str, Tensor]) -> None:
24+
"""Take a step in the direction of the gradients."""
25+
...
26+
27+
@experimental("This API is experimental and subject to change without notice.")
28+
def get_sgd_optimizer(
29+
named_parameters: Dict[str, Tensor],
30+
lr: float,
31+
momentum: float = 0,
32+
dampening: float = 0,
33+
weight_decay: float = 0,
34+
nesterov: bool = False,
35+
) -> ExecuTorchSGD:
36+
"""Creates an sgd optimizer that operates on the passed in named_parameters according to the specified hyper parameters.
37+
38+
.. warning::
39+
40+
This API is experimental and subject to change without notice.
41+
...
42+
"""
43+
...
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, Dict, List, Sequence
8+
9+
from executorch.exir._warnings import experimental
10+
11+
from executorch.extension.pybindings.portable_lib import (
12+
_load_for_executorch,
13+
_load_for_executorch_from_buffer,
14+
ExecuTorchModule,
15+
)
16+
from torch import Tensor
17+
18+
19+
@experimental("This API is experimental and subject to change without notice.")
20+
class TrainingModule:
21+
def __init__(self, module: ExecuTorchModule):
22+
self.model = module
23+
24+
self.gradients_method_prefix = "__et_training_gradients_index_"
25+
self.parameters_method_prefix = "__et_training_parameters_index_"
26+
self.fqn_method_prefix = "__et_training_fqn_"
27+
28+
self.named_grads = None
29+
self.named_params = None
30+
31+
def forward_backward(self, method_name: str, inputs: Sequence[Any]) -> List[Any]:
32+
# The default ET model returns a large list of outputs that can logically be
33+
# separated into [user outputs, gradients, parameters]. Can use these metadata
34+
# methods to slice the list into the correct parts.
35+
grad_start_idx = self.model.run_method(
36+
self.gradients_method_prefix + method_name, ()
37+
)[0]
38+
params_start_idx = self.model.run_method(
39+
self.parameters_method_prefix + method_name, ()
40+
)[0]
41+
42+
full_outputs = self.model.run_method(method_name, inputs)
43+
44+
user_outs = full_outputs[:grad_start_idx]
45+
grads = full_outputs[grad_start_idx:params_start_idx]
46+
params = full_outputs[params_start_idx:]
47+
48+
# Important that the outputs are not cloned because we need the optimizer to
49+
# be able to mutate the actual weights and not clones of them.
50+
fqn = self.model.run_method(
51+
self.fqn_method_prefix + method_name, (), clone_outputs=False
52+
)
53+
54+
self.named_grads = dict(zip(fqn, grads))
55+
if self.named_params is None:
56+
self.named_params = dict(zip(fqn, params))
57+
58+
return user_outs
59+
60+
def named_gradients(self) -> Dict[str, Tensor]:
61+
if self.named_grads is None:
62+
raise RuntimeError("Must call forward_backward before named_grads")
63+
return self.named_grads
64+
65+
def named_parameters(self) -> Dict[str, Tensor]:
66+
if self.named_grads is None:
67+
raise RuntimeError(
68+
"Must call forward_backward before named_params. This will be fixed in a later version"
69+
)
70+
return self.named_params
71+
72+
73+
@experimental("This API is experimental and subject to change without notice.")
74+
def _load_for_executorch_for_training(path: str) -> TrainingModule:
75+
et_module = _load_for_executorch(path)
76+
return TrainingModule(et_module)
77+
78+
79+
@experimental("This API is experimental and subject to change without notice.")
80+
def _load_for_executorch_for_training_from_buffer(
81+
buffer: bytes,
82+
) -> TrainingModule:
83+
et_module = _load_for_executorch_from_buffer(buffer)
84+
return TrainingModule(et_module)

0 commit comments

Comments
 (0)