Skip to content

Commit 1d0e0ee

Browse files
training module takes .ptd
Differential Revision: D69547105 Pull Request resolved: #8739
1 parent 0ab3499 commit 1d0e0ee

File tree

9 files changed

+147
-69
lines changed

9 files changed

+147
-69
lines changed

extension/flat_tensor/test/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ include(${EXECUTORCH_ROOT}/build/Test.cmake)
2020

2121
add_custom_command(
2222
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
23-
"${CMAKE_CURRENT_BINARY_DIR}/_default_external_constant.ptd"
23+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
2424
COMMAND
2525
python -m test.models.export_program --modules "ModuleLinear"
2626
--external-constants --outdir "${CMAKE_CURRENT_BINARY_DIR}" 2> /dev/null
@@ -30,12 +30,12 @@ add_custom_command(
3030
add_custom_target(
3131
extension_flat_tensor_test_resources
3232
DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
33-
"${CMAKE_CURRENT_BINARY_DIR}/_default_external_constant.ptd"
33+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
3434
)
3535

3636
set(test_env
3737
"ET_MODULE_LINEAR_PROGRAM_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
38-
"ET_MODULE_LINEAR_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/_default_external_constant.ptd"
38+
"ET_MODULE_LINEAR_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
3939
)
4040

4141
set(_test_srcs flat_tensor_data_map_test.cpp flat_tensor_header_test.cpp)

extension/training/module/test/targets.bzl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@ def define_common_targets(is_fbcode = False):
1616
# an fbcode target path because the authoring/export tools
1717
# intentionally don't work in xplat (since they're host-only tools).
1818
"ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
19-
"ET_MODULE_SIMPLE_TRAIN_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleSimpleTrain.pte])",
20-
"ET_MODULE_LINEAR_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.pte])",
2119
"ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])",
20+
"ET_MODULE_LINEAR_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.pte])",
21+
"ET_MODULE_TRAIN_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleSimpleTrain.ptd])",
22+
"ET_MODULE_TRAIN_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleSimpleTrainProgram.pte])",
23+
"ET_MODULE_SIMPLE_TRAIN_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleSimpleTrain.pte])",
2224
}
2325

2426
runtime.cxx_test(
@@ -29,6 +31,7 @@ def define_common_targets(is_fbcode = False):
2931
deps = [
3032
"//executorch/extension/training/module:training_module",
3133
"//executorch/extension/data_loader:file_data_loader",
34+
"//executorch/extension/flat_tensor:flat_tensor_data_map",
3235
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
3336
"//executorch/kernels/portable:generated_lib",
3437
],

extension/training/module/test/training_module_test.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/extension/data_loader/file_data_loader.h>
10+
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
1011
#include <executorch/extension/training/module/training_module.h>
1112

1213
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
@@ -18,9 +19,17 @@
1819
using namespace ::testing;
1920
using executorch::aten::ScalarType;
2021
using executorch::aten::Tensor;
22+
using executorch::extension::FlatTensorDataMap;
23+
using executorch::extension::FlatTensorHeader;
24+
using executorch::runtime::DataLoader;
25+
using executorch::runtime::Error;
26+
using executorch::runtime::FreeableBuffer;
27+
using executorch::runtime::Result;
28+
using executorch::runtime::TensorLayout;
2129
using torch::executor::Error;
2230
using torch::executor::Span;
2331
using torch::executor::testing::TensorFactory;
32+
using torch::executor::util::FileDataLoader;
2433

2534
class TrainingModuleTest : public ::testing::Test {
2635
protected:
@@ -105,3 +114,42 @@ TEST_F(TrainingModuleTest, NonTrainingModuleTest) {
105114
auto res = mod.execute_forward_backward("forward", inputs);
106115
ASSERT_EQ(res.error(), Error::InvalidArgument);
107116
}
117+
118+
TEST_F(TrainingModuleTest, SeperateDataTest) {
119+
// Load data map.
120+
// NOTE(review): this test loads ModuleSimpleTrain data (ET_MODULE_TRAIN_DATA_PATH),
121+
// not the linear model — the reference to //executorch/test/models/linear_model.py looks like stale copy-paste; confirm.
122+
const char* ptd_path = std::getenv("ET_MODULE_TRAIN_DATA_PATH");
123+
Result<FileDataLoader> data_map_loader_res = FileDataLoader::from(ptd_path);
124+
ASSERT_EQ(data_map_loader_res.error(), Error::Ok);
125+
126+
auto data_map_loader =
127+
std::make_unique<torch::executor::util::FileDataLoader>(
128+
std::move(data_map_loader_res.get()));
129+
130+
const char* pte_path = std::getenv("ET_MODULE_TRAIN_PROGRAM_PATH");
131+
Result<FileDataLoader> pte_loader_res = FileDataLoader::from(pte_path);
132+
ASSERT_EQ(pte_loader_res.error(), Error::Ok);
133+
134+
auto pte_loader = std::make_unique<torch::executor::util::FileDataLoader>(
135+
std::move(pte_loader_res.get()));
136+
137+
auto mod = executorch::extension::training::TrainingModule(
138+
std::move(pte_loader),
139+
nullptr,
140+
nullptr,
141+
nullptr,
142+
std::move(data_map_loader));
143+
144+
TensorFactory<ScalarType::Float> tf;
145+
Tensor input = tf.make({3}, {1.0, 1.0, 1.0});
146+
Tensor label = tf.make({3}, {1.0, 0.0, 0.0});
147+
148+
std::vector<executorch::runtime::EValue> inputs;
149+
inputs.push_back(input);
150+
inputs.push_back(label);
151+
152+
auto res = mod.execute_forward_backward("forward", inputs);
153+
ASSERT_EQ(res.error(), Error::Ok);
154+
ASSERT_EQ(res.get().size(), 1);
155+
}

extension/training/module/training_module.cpp

Lines changed: 45 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ TrainingModule::execute_forward_backward(
4343
uint64_t param_start = param_res.get()[0].toInt();
4444

4545
// Execute the forward and backward pass.
46-
4746
auto outputs = torch::executor::Module::execute(method_name, input);
4847
if (!outputs.ok()) {
4948
return outputs.error();
@@ -56,19 +55,23 @@ TrainingModule::execute_forward_backward(
5655
user_outputs.push_back(outputs.get().at(i));
5756
}
5857

59-
// Extract and store the gradients.
58+
// Extract and store the gradients and params if this is the first time seeing
59+
// this method.
6060
if (method_named_gradients_.find(method_name) ==
6161
method_named_gradients_.end()) {
62+
// Fully qualified names
63+
std::vector<runtime::EValue> fqn_list;
6264
method_named_gradients_.insert({method_name, {}});
6365

6466
auto& gradients_map = method_named_gradients_.at(method_name);
65-
// Get names.
67+
68+
// Get names if we haven't seen this method before.
6669
const std::string fqn_method_name = fqn_method_prefix + method_name;
6770
auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
6871
if (!fqn_res.ok()) {
6972
return fqn_res.error();
7073
}
71-
const auto& fqn_list = fqn_res.get();
74+
fqn_list = fqn_res.get();
7275

7376
// Only have to initialize the dict once because the tensors in the dict and
7477
// the tensors in the method alias the same TensorImpl, so updating one will
@@ -87,43 +90,49 @@ TrainingModule::execute_forward_backward(
8790
runtime::Result<
8891
const std::map<executorch::aten::string_view, executorch::aten::Tensor>>
8992
TrainingModule::named_parameters(const std::string& method_name) {
90-
std::map<executorch::aten::string_view, executorch::aten::Tensor>
91-
named_parameters;
92-
const std::string fqn_method_name = fqn_method_prefix + method_name;
93-
const std::string parameters_method_name =
94-
parameters_method_prefix + method_name;
93+
// If we haven't seen this method before, populate the dict.
94+
if (method_named_parameters_.find(method_name) ==
95+
method_named_parameters_.end()) {
96+
const std::string fqn_method_name = fqn_method_prefix + method_name;
97+
const std::string parameters_method_name =
98+
parameters_method_prefix + method_name;
9599

96-
// get names.
97-
auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
98-
if (!fqn_res.ok()) {
99-
return fqn_res.error();
100-
}
101-
const auto& fqn_list = fqn_res.get();
100+
method_named_parameters_.insert({method_name, {}});
102101

103-
// get params start.
104-
auto param_res =
105-
executorch::extension::Module::execute(parameters_method_name);
106-
if (!param_res.ok()) {
107-
return param_res.error();
108-
}
102+
// get names.
103+
auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
104+
if (!fqn_res.ok()) {
105+
return fqn_res.error();
106+
}
107+
const auto& fqn_list = fqn_res.get();
109108

110-
uint64_t param_start = param_res.get()[0].toInt();
109+
// get params start.
110+
auto param_res =
111+
executorch::extension::Module::execute(parameters_method_name);
112+
if (!param_res.ok()) {
113+
return param_res.error();
114+
}
111115

112-
auto e = executorch::extension::Module::load_method(method_name);
113-
if (e != runtime::Error::Ok) {
114-
return e;
115-
}
116-
auto& method = methods_.at(method_name).method;
117-
118-
// create dict
119-
size_t name_index = 0;
120-
for (size_t param_index = param_start; param_index < method->outputs_size();
121-
++param_index, ++name_index) {
122-
executorch::aten::string_view fqn = fqn_list.at(name_index).toString();
123-
executorch::aten::Tensor param = method->get_output(param_index).toTensor();
124-
named_parameters.insert({fqn, param});
116+
uint64_t param_start = param_res.get()[0].toInt();
117+
118+
// Load the method if it is not already loaded.
119+
auto e = executorch::extension::Module::load_method(method_name);
120+
if (e != runtime::Error::Ok) {
121+
return e;
122+
}
123+
auto& method = methods_.at(method_name).method;
124+
125+
// populate dict
126+
size_t name_index = 0;
127+
for (size_t param_index = param_start; param_index < method->outputs_size();
128+
++param_index, ++name_index) {
129+
executorch::aten::string_view fqn = fqn_list.at(name_index).toString();
130+
executorch::aten::Tensor param =
131+
method->get_output(param_index).toTensor();
132+
method_named_parameters_.at(method_name).insert({fqn, param});
133+
}
125134
}
126-
return named_parameters;
135+
return method_named_parameters_.at(method_name);
127136
}
128137

129138
runtime::Result<

extension/training/module/training_module.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,16 @@ class ET_EXPERIMENTAL TrainingModule final
3333
std::unique_ptr<runtime::DataLoader> data_loader,
3434
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
3535
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
36-
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr)
36+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
37+
std::unique_ptr<runtime::DataLoader> data_map_data_loader = nullptr)
3738
: executorch::extension::Module(
3839
std::move(data_loader),
3940
std::move(memory_allocator),
4041
std::move(temp_allocator),
41-
std::move(event_tracer)),
42-
method_named_gradients_({}) {}
42+
std::move(event_tracer),
43+
std::move(data_map_data_loader)),
44+
method_named_gradients_({}),
45+
method_named_parameters_({}) {}
4346

4447
explicit TrainingModule(const Module&) = delete;
4548
TrainingModule& operator=(const Module&) = delete;
@@ -97,6 +100,11 @@ class ET_EXPERIMENTAL TrainingModule final
97100
std::string,
98101
std::map<executorch::aten::string_view, executorch::aten::Tensor>>
99102
method_named_gradients_;
103+
104+
std::unordered_map<
105+
std::string,
106+
std::map<executorch::aten::string_view, executorch::aten::Tensor>>
107+
method_named_parameters_;
100108
};
101109

102110
} // namespace training

runtime/executor/tensor_parser_exec_aten.cpp

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -169,24 +169,8 @@ ET_NODISCARD Result<void*> getTensorDataPtr(
169169
const executorch_flatbuffer::AllocationDetails* allocation_info =
170170
s_tensor->allocation_info();
171171

172-
// Memory Planned, with initial state
173-
if (data_buffer_idx > 0 && allocation_info != nullptr) {
174-
auto planned_ptr = getMemPlannedPtr(allocation_info, nbytes, allocator);
175-
if (!planned_ptr.ok()) {
176-
return planned_ptr.error();
177-
}
178-
auto err = TensorParser::load_mutable_subsegment_into(
179-
program, 0, s_tensor->data_buffer_idx(), nbytes, planned_ptr.get());
180-
181-
if (err != Error::Ok) {
182-
return err;
183-
}
184-
return planned_ptr;
185-
}
186-
187172
// External tensors.
188-
else if (
189-
s_tensor->extra_tensor_info() != nullptr &&
173+
if (s_tensor->extra_tensor_info() != nullptr &&
190174
s_tensor->extra_tensor_info()->location() ==
191175
executorch_flatbuffer::TensorDataLocation::EXTERNAL) {
192176
// Check that fqn is not null.
@@ -232,10 +216,9 @@ ET_NODISCARD Result<void*> getTensorDataPtr(
232216

233217
return planned_ptr;
234218
}
235-
}
236219

237-
// Constant, stored in PTE file.
238-
else if (data_buffer_idx > 0 && allocation_info == nullptr) {
220+
// Constant, stored in PTE file.
221+
} else if (data_buffer_idx > 0 && allocation_info == nullptr) {
239222
auto const_data =
240223
program->get_constant_buffer_data(data_buffer_idx, nbytes);
241224
if (!const_data.ok()) {
@@ -246,6 +229,20 @@ ET_NODISCARD Result<void*> getTensorDataPtr(
246229
// guarantee that this data is never modified.
247230
return const_cast<void*>(const_data.get());
248231

232+
// Memory Planned, with initial state
233+
} else if (data_buffer_idx > 0 && allocation_info != nullptr) {
234+
auto planned_ptr = getMemPlannedPtr(allocation_info, nbytes, allocator);
235+
if (!planned_ptr.ok()) {
236+
return planned_ptr.error();
237+
}
238+
auto err = TensorParser::load_mutable_subsegment_into(
239+
program, 0, s_tensor->data_buffer_idx(), nbytes, planned_ptr.get());
240+
241+
if (err != Error::Ok) {
242+
return err;
243+
}
244+
return planned_ptr;
245+
249246
// Memory planned, no initial state
250247
} else if (data_buffer_idx == 0 && allocation_info != nullptr) {
251248
return getMemPlannedPtr(allocation_info, nbytes, allocator);

runtime/executor/test/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ add_custom_command(
2424
"${CMAKE_CURRENT_BINARY_DIR}/ModuleIndex.pte"
2525
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinear.pte"
2626
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
27-
"${CMAKE_CURRENT_BINARY_DIR}/_default_external_constant.ptd"
27+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
2828
"${CMAKE_CURRENT_BINARY_DIR}/ModuleMultipleEntry.pte"
2929
"${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrain.pte"
3030
COMMAND
@@ -48,7 +48,7 @@ add_custom_target(
4848
"${CMAKE_CURRENT_BINARY_DIR}/ModuleIndex.pte"
4949
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinear.pte"
5050
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
51-
"${CMAKE_CURRENT_BINARY_DIR}/_default_external_constant.ptd"
51+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
5252
"${CMAKE_CURRENT_BINARY_DIR}/ModuleMultipleEntry.pte"
5353
"${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrain.pte"
5454
)
@@ -61,7 +61,7 @@ set(test_env
6161
"ET_MODULE_INDEX_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleIndex.pte"
6262
"ET_MODULE_LINEAR_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinear.pte"
6363
"ET_MODULE_LINEAR_PROGRAM_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
64-
"ET_MODULE_LINEAR_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/_default_external_constant.ptd"
64+
"ET_MODULE_LINEAR_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
6565
"ET_MODULE_MULTI_ENTRY_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleMultipleEntry.pte"
6666
"ET_MODULE_SIMPLE_TRAIN_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrain.pte"
6767
)

test/models/export_program.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,11 @@ def main() -> None:
276276
prog.write_to_file(fp)
277277
print(f"Exported {module_name} and wrote program data to {outfile}")
278278

279+
if args.external_constants:
280+
# current infra doesn't easily allow renaming this file, so just hackily do it here.
281+
prog._tensor_data[f"{module_name}"] = prog._tensor_data.pop(
282+
"_default_external_constant"
283+
)
279284
prog.write_tensor_data_to_file(args.outdir)
280285

281286

test/models/targets.bzl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,21 @@ def define_common_targets():
9090
# case, and typically shouldn't be done.
9191
_is_external_target = True,
9292
)
93+
94+
# Class names of nn.Modules for :exported_programs to export.
95+
MODULES_AND_DATA_TO_EXPORT = [
96+
"ModuleLinear",
97+
"ModuleSimpleTrain",
98+
]
9399

94100
runtime.genrule(
95101
name = "exported_program_and_data",
96-
cmd = "$(exe :export_program) --modules ModuleLinear --external-constants --outdir $OUT",
102+
cmd = "$(exe :export_program) --modules " + ",".join(MODULES_AND_DATA_TO_EXPORT) + " --external-constants --outdir $OUT",
97103
outs = {
98104
"ModuleLinear.pte": ["ModuleLinearProgram.pte"],
99-
"ModuleLinear.ptd": ["_default_external_constant.ptd"],
105+
"ModuleLinear.ptd": ["ModuleLinearProgram.ptd"],
106+
"ModuleSimpleTrainProgram.pte": ["ModuleSimpleTrainProgram.pte"],
107+
"ModuleSimpleTrain.ptd": ["ModuleSimpleTrainProgram.ptd"],
100108
},
101109
default_outs = ["."],
102110
visibility = [

0 commit comments

Comments
 (0)