Skip to content

Commit f0baf1d

Browse files
JacobSzwejbka authored and facebook-github-bot committed
training module takes .ptd
Summary: Allow TrainingModule to take in a .ptd. Also realized I was only caching the gradient tensors, not the params, so went ahead and fixed that. Updated the export script to generate training modules with separated weights. Fixed a bug in tensor parsing for external mutable tensors. Differential Revision: D69547105
1 parent 84273f4 commit f0baf1d

File tree

7 files changed

+129
-60
lines changed

7 files changed

+129
-60
lines changed

extension/training/module/test/targets.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ def define_common_targets(is_fbcode = False):
1717
# intentionally don't work in xplat (since they're host-only tools).
1818
"ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
1919
"ET_MODULE_SIMPLE_TRAIN_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleSimpleTrain.pte])",
20+
"ET_MODULE_TRAIN_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleSimpleTrainProgram.pte])",
21+
"ET_MODULE_TRAIN_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleSimpleTrain.ptd])",
2022
}
2123

2224
runtime.cxx_test(
@@ -28,6 +30,7 @@ def define_common_targets(is_fbcode = False):
2830
"//executorch/extension/training/module:training_module",
2931
"//executorch/extension/data_loader:file_data_loader",
3032
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
33+
"//executorch/extension/flat_tensor:flat_tensor_data_map",
3134
"//executorch/kernels/portable:generated_lib",
3235
],
3336
env = modules_env,

extension/training/module/test/training_module_test.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/extension/data_loader/file_data_loader.h>
1010
#include <executorch/extension/training/module/training_module.h>
11+
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
1112

1213
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
1314
#include <executorch/runtime/platform/runtime.h>
@@ -21,6 +22,14 @@ using executorch::aten::Tensor;
2122
using torch::executor::Error;
2223
using torch::executor::Span;
2324
using torch::executor::testing::TensorFactory;
25+
using executorch::extension::FlatTensorDataMap;
26+
using executorch::extension::FlatTensorHeader;
27+
using executorch::runtime::DataLoader;
28+
using executorch::runtime::Error;
29+
using executorch::runtime::FreeableBuffer;
30+
using executorch::runtime::Result;
31+
using executorch::runtime::TensorLayout;
32+
using torch::executor::util::FileDataLoader;
2433

2534
class TrainingModuleTest : public ::testing::Test {
2635
protected:
@@ -105,3 +114,36 @@ TEST_F(TrainingModuleTest, NonTrainingModuleTest) {
105114
auto res = mod.execute_forward_backward("forward", inputs);
106115
ASSERT_EQ(res.error(), Error::InvalidArgument);
107116
}
117+
118+
TEST_F(TrainingModuleTest, SeperateDataTest) {
119+
// Load data map.
120+
// The eager linear model is defined at:
121+
// //executorch/test/models/linear_model.py
122+
const char* ptd_path = std::getenv("ET_MODULE_TRAIN_DATA_PATH");
123+
Result<FileDataLoader> data_map_loader_res = FileDataLoader::from(ptd_path);
124+
ASSERT_EQ(data_map_loader_res.error(), Error::Ok);
125+
126+
auto data_map_loader = std::make_unique<torch::executor::util::FileDataLoader>(
127+
std::move(data_map_loader_res.get()));
128+
129+
const char* pte_path = std::getenv("ET_MODULE_TRAIN_PROGRAM_PATH");
130+
Result<FileDataLoader> pte_loader_res = FileDataLoader::from(pte_path);
131+
ASSERT_EQ(pte_loader_res.error(), Error::Ok);
132+
133+
auto pte_loader = std::make_unique<torch::executor::util::FileDataLoader>(
134+
std::move(pte_loader_res.get()));
135+
136+
auto mod = executorch::extension::training::TrainingModule(std::move(pte_loader), nullptr, nullptr, nullptr, std::move(data_map_loader));
137+
138+
TensorFactory<ScalarType::Float> tf;
139+
Tensor input = tf.make({3}, {1.0, 1.0, 1.0});
140+
Tensor label = tf.make({3}, {1.0, 0.0, 0.0});
141+
142+
std::vector<executorch::runtime::EValue> inputs;
143+
inputs.push_back(input);
144+
inputs.push_back(label);
145+
146+
auto res = mod.execute_forward_backward("forward", inputs);
147+
ASSERT_EQ(res.error(), Error::Ok);
148+
ASSERT_EQ(res.get().size(), 1);
149+
}

extension/training/module/training_module.cpp

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ TrainingModule::execute_forward_backward(
4343
uint64_t param_start = param_res.get()[0].toInt();
4444

4545
// Execute the forward and backward pass.
46-
4746
auto outputs = torch::executor::Module::execute(method_name, input);
4847
if (!outputs.ok()) {
4948
return outputs.error();
@@ -56,19 +55,22 @@ TrainingModule::execute_forward_backward(
5655
user_outputs.push_back(outputs.get().at(i));
5756
}
5857

59-
// Extract and store the gradients.
58+
// Extract and store the gradients and params if this is the first time seeing this method.
6059
if (method_named_gradients_.find(method_name) ==
6160
method_named_gradients_.end()) {
61+
// Fully qualified names
62+
std::vector<runtime::EValue> fqn_list;
6263
method_named_gradients_.insert({method_name, {}});
6364

6465
auto& gradients_map = method_named_gradients_.at(method_name);
65-
// Get names.
66+
67+
// Get names if we haven't seen this method before.
6668
const std::string fqn_method_name = fqn_method_prefix + method_name;
6769
auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
6870
if (!fqn_res.ok()) {
6971
return fqn_res.error();
7072
}
71-
const auto& fqn_list = fqn_res.get();
73+
fqn_list = fqn_res.get();
7274

7375
// Only have to initialize the dict once because the tensors in the dict and
7476
// the tensors in the method alias the same TensorImpl, so updating one will
@@ -87,43 +89,48 @@ TrainingModule::execute_forward_backward(
8789
runtime::Result<
8890
const std::map<executorch::aten::string_view, executorch::aten::Tensor>>
8991
TrainingModule::named_parameters(const std::string& method_name) {
90-
std::map<executorch::aten::string_view, executorch::aten::Tensor>
91-
named_parameters;
92-
const std::string fqn_method_name = fqn_method_prefix + method_name;
93-
const std::string parameters_method_name =
94-
parameters_method_prefix + method_name;
92+
// If we haven't seen this method before, populate the dict.
93+
if (method_named_parameters_.find(method_name) ==
94+
method_named_parameters_.end()) {
95+
const std::string fqn_method_name = fqn_method_prefix + method_name;
96+
const std::string parameters_method_name =
97+
parameters_method_prefix + method_name;
9598

96-
// get names.
97-
auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
98-
if (!fqn_res.ok()) {
99-
return fqn_res.error();
100-
}
101-
const auto& fqn_list = fqn_res.get();
99+
method_named_parameters_.insert({method_name, {}});
102100

103-
// get params start.
104-
auto param_res =
105-
executorch::extension::Module::execute(parameters_method_name);
106-
if (!param_res.ok()) {
107-
return param_res.error();
108-
}
101+
// get names.
102+
auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
103+
if (!fqn_res.ok()) {
104+
return fqn_res.error();
105+
}
106+
const auto& fqn_list = fqn_res.get();
109107

110-
uint64_t param_start = param_res.get()[0].toInt();
108+
// get params start.
109+
auto param_res =
110+
executorch::extension::Module::execute(parameters_method_name);
111+
if (!param_res.ok()) {
112+
return param_res.error();
113+
}
111114

112-
auto e = executorch::extension::Module::load_method(method_name);
113-
if (e != runtime::Error::Ok) {
114-
return e;
115-
}
116-
auto& method = methods_.at(method_name).method;
117-
118-
// create dict
119-
size_t name_index = 0;
120-
for (size_t param_index = param_start; param_index < method->outputs_size();
121-
++param_index, ++name_index) {
122-
executorch::aten::string_view fqn = fqn_list.at(name_index).toString();
123-
executorch::aten::Tensor param = method->get_output(param_index).toTensor();
124-
named_parameters.insert({fqn, param});
115+
uint64_t param_start = param_res.get()[0].toInt();
116+
117+
// Load the method if it is not already loaded.
118+
auto e = executorch::extension::Module::load_method(method_name);
119+
if (e != runtime::Error::Ok) {
120+
return e;
121+
}
122+
auto& method = methods_.at(method_name).method;
123+
124+
// populate dict
125+
size_t name_index = 0;
126+
for (size_t param_index = param_start; param_index < method->outputs_size();
127+
++param_index, ++name_index) {
128+
executorch::aten::string_view fqn = fqn_list.at(name_index).toString();
129+
executorch::aten::Tensor param = method->get_output(param_index).toTensor();
130+
method_named_parameters_.at(method_name).insert({fqn, param});
131+
}
125132
}
126-
return named_parameters;
133+
return method_named_parameters_.at(method_name);
127134
}
128135

129136
runtime::Result<

extension/training/module/training_module.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,16 @@ class ET_EXPERIMENTAL TrainingModule final
3333
std::unique_ptr<runtime::DataLoader> data_loader,
3434
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
3535
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
36-
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr)
36+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
37+
std::unique_ptr<runtime::DataLoader> data_map_data_loader = nullptr)
3738
: executorch::extension::Module(
3839
std::move(data_loader),
3940
std::move(memory_allocator),
4041
std::move(temp_allocator),
41-
std::move(event_tracer)),
42-
method_named_gradients_({}) {}
42+
std::move(event_tracer),
43+
std::move(data_map_data_loader)),
44+
method_named_gradients_({}),
45+
method_named_parameters_({}) {}
4346

4447
explicit TrainingModule(const Module&) = delete;
4548
TrainingModule& operator=(const Module&) = delete;
@@ -97,6 +100,11 @@ class ET_EXPERIMENTAL TrainingModule final
97100
std::string,
98101
std::map<executorch::aten::string_view, executorch::aten::Tensor>>
99102
method_named_gradients_;
103+
104+
std::unordered_map<
105+
std::string,
106+
std::map<executorch::aten::string_view, executorch::aten::Tensor>>
107+
method_named_parameters_;
100108
};
101109

102110
} // namespace training

runtime/executor/tensor_parser_exec_aten.cpp

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -169,23 +169,8 @@ ET_NODISCARD Result<void*> getTensorDataPtr(
169169
const executorch_flatbuffer::AllocationDetails* allocation_info =
170170
s_tensor->allocation_info();
171171

172-
// Memory Planned, with initial state
173-
if (data_buffer_idx > 0 && allocation_info != nullptr) {
174-
auto planned_ptr = getMemPlannedPtr(allocation_info, nbytes, allocator);
175-
if (!planned_ptr.ok()) {
176-
return planned_ptr.error();
177-
}
178-
auto err = TensorParser::load_mutable_subsegment_into(
179-
program, 0, s_tensor->data_buffer_idx(), nbytes, planned_ptr.get());
180-
181-
if (err != Error::Ok) {
182-
return err;
183-
}
184-
return planned_ptr;
185-
}
186-
187172
// External tensors.
188-
else if (
173+
if (
189174
s_tensor->extra_tensor_info() != nullptr &&
190175
s_tensor->extra_tensor_info()->location() ==
191176
executorch_flatbuffer::TensorDataLocation::EXTERNAL) {
@@ -232,10 +217,9 @@ ET_NODISCARD Result<void*> getTensorDataPtr(
232217

233218
return planned_ptr;
234219
}
235-
}
236220

237221
// Constant, stored in PTE file.
238-
else if (data_buffer_idx > 0 && allocation_info == nullptr) {
222+
} else if (data_buffer_idx > 0 && allocation_info == nullptr) {
239223
auto const_data =
240224
program->get_constant_buffer_data(data_buffer_idx, nbytes);
241225
if (!const_data.ok()) {
@@ -246,7 +230,21 @@ ET_NODISCARD Result<void*> getTensorDataPtr(
246230
// guarantee that this data is never modified.
247231
return const_cast<void*>(const_data.get());
248232

249-
// Memory planned, no initial state
233+
// Memory Planned, with initial state
234+
} else if (data_buffer_idx > 0 && allocation_info != nullptr) {
235+
auto planned_ptr = getMemPlannedPtr(allocation_info, nbytes, allocator);
236+
if (!planned_ptr.ok()) {
237+
return planned_ptr.error();
238+
}
239+
auto err = TensorParser::load_mutable_subsegment_into(
240+
program, 0, s_tensor->data_buffer_idx(), nbytes, planned_ptr.get());
241+
242+
if (err != Error::Ok) {
243+
return err;
244+
}
245+
return planned_ptr;
246+
247+
// Memory planned, no initial state
250248
} else if (data_buffer_idx == 0 && allocation_info != nullptr) {
251249
return getMemPlannedPtr(allocation_info, nbytes, allocator);
252250

test/models/export_program.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,9 @@ def main() -> None:
276276
prog.write_to_file(fp)
277277
print(f"Exported {module_name} and wrote program data to {outfile}")
278278

279+
if args.external_constants:
280+
# current infra doesn't easily allow renaming this file, so just hackily do it here.
281+
prog._tensor_data[f"{module_name}"] = prog._tensor_data.pop("_default_external_constant")
279282
prog.write_tensor_data_to_file(args.outdir)
280283

281284

test/models/targets.bzl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,21 @@ def define_common_targets():
9090
# case, and typically shouldn't be done.
9191
_is_external_target = True,
9292
)
93+
94+
# Class names of nn.Modules for :exported_program_and_data to export.
95+
MODULES_AND_DATA_TO_EXPORT = [
96+
"ModuleLinear",
97+
"ModuleSimpleTrain",
98+
]
9399

94100
runtime.genrule(
95101
name = "exported_program_and_data",
96-
cmd = "$(exe :export_program) --modules ModuleLinear --external-constants --outdir $OUT",
102+
cmd = "$(exe :export_program) --modules " + ",".join(MODULES_AND_DATA_TO_EXPORT) + " --external-constants --outdir $OUT",
97103
outs = {
98104
"ModuleLinear.pte": ["ModuleLinearProgram.pte"],
99-
"ModuleLinear.ptd": ["_default_external_constant.ptd"],
105+
"ModuleLinear.ptd": ["ModuleLinearProgram.ptd"],
106+
"ModuleSimpleTrainProgram.pte": ["ModuleSimpleTrainProgram.pte"],
107+
"ModuleSimpleTrain.ptd": ["ModuleSimpleTrainProgram.ptd"],
100108
},
101109
default_outs = ["."],
102110
visibility = [

0 commit comments

Comments
 (0)