Skip to content

Commit b28fe99

Browse files
JacobSzwejbka authored and
facebook-github-bot committed
add .ptd support to extension/module (pytorch#8421)
Summary: Pull Request resolved: pytorch#8421. Title. Reviewed By: lucylq. Differential Revision: D69478424.
1 parent 148832e commit b28fe99

File tree

10 files changed

+107
-9
lines changed

10 files changed

+107
-9
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,11 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
254254
set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
255255
endif()
256256

257+
if(EXECUTORCH_BUILD_EXTENSION_MODULE)
258+
set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
259+
set(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON)
260+
endif()
261+
257262
if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
258263
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
259264
set(EXECUTORCH_BUILD_KERNELS_CUSTOM ON)

extension/flat_tensor/targets.bzl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ def define_common_targets():
99
exported_headers = ["flat_tensor_data_map.h"],
1010
deps = [
1111
"//executorch/extension/flat_tensor/serialize:generated_headers",
12-
"//executorch/extension/flat_tensor/serialize:flat_tensor_header",
1312
"//executorch/runtime/core:core",
1413
"//executorch/runtime/core:evalue",
1514
"//executorch/runtime/core:named_data_map",
1615
"//executorch/runtime/core/exec_aten:lib",
1716
"//executorch/runtime/core/exec_aten/util:tensor_util",
1817
],
18+
exported_deps = [
19+
"//executorch/extension/flat_tensor/serialize:flat_tensor_header",
20+
],
1921
visibility = [
2022
"//executorch/...",
2123
],

extension/module/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ if(CMAKE_TOOLCHAIN_IOS
2727
else()
2828
add_library(extension_module SHARED ${_extension_module__srcs})
2929
endif()
30-
target_link_libraries(extension_module PRIVATE executorch extension_data_loader)
30+
target_link_libraries(extension_module PRIVATE executorch extension_data_loader extension_flat_tensor)
3131
target_include_directories(extension_module PUBLIC ${EXECUTORCH_ROOT}/..)
3232
target_compile_options(
3333
extension_module PUBLIC -Wno-deprecated-declarations -fPIC
@@ -37,7 +37,7 @@ target_compile_options(
3737
# after cleaning up CMake targets.
3838
add_library(extension_module_static STATIC ${_extension_module__srcs})
3939
target_link_libraries(
40-
extension_module_static PRIVATE executorch extension_data_loader
40+
extension_module_static PRIVATE executorch extension_data_loader extension_flat_tensor
4141
)
4242
target_include_directories(extension_module_static PUBLIC ${EXECUTORCH_ROOT}/..)
4343
target_compile_options(

extension/module/module.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,24 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
125125

126126
runtime::Error Module::load_method(
127127
const std::string& method_name,
128-
torch::executor::EventTracer* event_tracer) {
128+
torch::executor::EventTracer* event_tracer,
129+
std::unique_ptr<runtime::DataLoader> data_map_data_loader) {
129130
if (!is_method_loaded(method_name)) {
130131
ET_CHECK_OK_OR_RETURN_ERROR(load());
131132

132133
MethodHolder method_holder;
134+
135+
// If we have a .ptd load it.
136+
const runtime::NamedDataMap* named_data_map = nullptr;
137+
if (data_map_data_loader != nullptr) {
138+
auto data_map =
139+
ET_UNWRAP_UNIQUE(executorch::extension::FlatTensorDataMap::load(
140+
data_map_data_loader.get()));
141+
method_holder.data_map_loader = std::move(data_map_data_loader);
142+
method_holder.data_map = std::move(data_map);
143+
}
144+
named_data_map = method_holder.data_map.get();
145+
133146
const auto method_metadata =
134147
ET_UNWRAP(program_->method_meta(method_name.c_str()));
135148
const auto planned_buffersCount =
@@ -155,10 +168,30 @@ runtime::Error Module::load_method(
155168
method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
156169
method_name.c_str(),
157170
method_holder.memory_manager.get(),
158-
event_tracer ? event_tracer : this->event_tracer()));
171+
event_tracer ? event_tracer : this->event_tracer(),
172+
named_data_map));
159173
method_holder.inputs.resize(method_holder.method->inputs_size());
160174
methods_.emplace(method_name, std::move(method_holder));
161175
}
176+
return runtime::Error::Ok;
177+
}
178+
179+
runtime::Error Module::load_method(
180+
const std::string& method_name,
181+
const std::string& data_map_path,
182+
torch::executor::EventTracer* event_tracer) {
183+
if (!is_method_loaded(method_name)) {
184+
ET_CHECK_OK_OR_RETURN_ERROR(load());
185+
186+
// If we have a .ptd get a dataloader for it.
187+
std::unique_ptr<extension::FileDataLoader> data_map_data_loader;
188+
if (!data_map_path.empty()) {
189+
data_map_data_loader =
190+
ET_UNWRAP_UNIQUE(FileDataLoader::from(data_map_path.c_str()));
191+
}
192+
193+
return load_method(method_name, event_tracer, std::move(data_map_data_loader));
194+
}
162195
return runtime::Error::Ok;
163196
}
164197

extension/module/module.h

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <unordered_set>
1515
#include <vector>
1616

17+
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
1718
#include <executorch/runtime/executor/program.h>
1819

1920
namespace executorch {
@@ -133,6 +134,8 @@ class Module {
133134
* needed. The loaded method is cached to reuse the next time it's executed.
134135
*
135136
* @param[in] method_name The name of the method to load.
137+
* @param[in] data_map_path Path to a .ptd file containing weights
138+
* for this method.
136139
* @param[in] event_tracer Per-method event tracer to profile/trace methods
137140
* individually. When not given, the event tracer passed to the Module
138141
* constructor is used. Otherwise, this per-method event tracer takes
@@ -143,8 +146,29 @@ class Module {
143146
ET_NODISCARD
144147
runtime::Error load_method(
145148
const std::string& method_name,
149+
const std::string& data_map_path,
146150
torch::executor::EventTracer* event_tracer = nullptr);
147151

152+
/**
153+
* Load a specific method from the program and set up memory management if
154+
* needed. The loaded method is cached to reuse the next time it's executed.
155+
*
156+
* @param[in] method_name The name of the method to load.
157+
* @param[in] event_tracer Per-method event tracer to profile/trace methods
158+
* individually. When not given, the event tracer passed to the Module
159+
* constructor is used. Otherwise, this per-method event tracer takes
160+
* precedence.
161+
* @param[in] data_map_data_loader Optional data loader for the .ptd file
162+
* for this method.
163+
*
164+
* @returns An Error to indicate success or failure.
165+
*/
166+
ET_NODISCARD
167+
runtime::Error load_method(
168+
const std::string& method_name,
169+
torch::executor::EventTracer* event_tracer = nullptr,
170+
std::unique_ptr<runtime::DataLoader> data_map_data_loader = nullptr);
171+
148172
/**
149173
* Load the 'forward' method from the program and set up memory management if
150174
* needed. The loaded method is cached to reuse the next time it's executed.
@@ -155,8 +179,9 @@ class Module {
155179
* @returns An Error to indicate success or failure.
156180
*/
157181
ET_NODISCARD inline runtime::Error load_forward(
158-
torch::executor::EventTracer* event_tracer = nullptr) {
159-
return load_method("forward", event_tracer);
182+
torch::executor::EventTracer* event_tracer = nullptr,
183+
std::unique_ptr<runtime::DataLoader> data_map_data_loader = nullptr) {
184+
return load_method("forward", event_tracer, std::move(data_map_data_loader));
160185
}
161186

162187
/**
@@ -430,10 +455,11 @@ class Module {
430455
std::unique_ptr<runtime::HierarchicalAllocator> planned_memory;
431456
std::unique_ptr<runtime::MemoryManager> memory_manager;
432457
std::unique_ptr<runtime::Method> method;
458+
std::unique_ptr<runtime::DataLoader> data_map_loader;
459+
std::unique_ptr<extension::FlatTensorDataMap> data_map;
433460
std::vector<runtime::EValue> inputs;
434461
};
435462

436-
private:
437463
std::string file_path_;
438464
LoadMode load_mode_{LoadMode::MmapUseMlock};
439465
std::shared_ptr<runtime::Program> program_;

extension/module/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,6 @@ def define_common_targets():
2828
],
2929
exported_deps = [
3030
"//executorch/runtime/executor:program" + aten_suffix,
31+
"//executorch/extension/flat_tensor:flat_tensor_data_map",
3132
],
3233
)

extension/module/test/module_test.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,19 @@ class ModuleTest : public ::testing::Test {
2323
protected:
2424
static void SetUpTestSuite() {
2525
model_path_ = std::getenv("RESOURCES_PATH") + std::string("/add.pte");
26+
linear_path_ = std::getenv("RESOURCES_PATH") + std::string("/linear.pte");
27+
linear_data_path_ =
28+
std::getenv("RESOURCES_PATH") + std::string("/linear.ptd");
2629
}
2730

2831
static std::string model_path_;
32+
static std::string linear_path_;
33+
static std::string linear_data_path_;
2934
};
3035

3136
std::string ModuleTest::model_path_;
37+
std::string ModuleTest::linear_path_;
38+
std::string ModuleTest::linear_data_path_;
3239

3340
TEST_F(ModuleTest, TestLoad) {
3441
Module module(model_path_);
@@ -435,3 +442,15 @@ TEST_F(ModuleTest, TestSetOutputInvalidType) {
435442

436443
EXPECT_NE(module.set_output(EValue()), Error::Ok);
437444
}
445+
446+
TEST_F(ModuleTest, TestPTD) {
447+
Module module(linear_path_);
448+
449+
ASSERT_EQ(
450+
module.load_method("forward", linear_data_path_), Error::Ok);
451+
452+
auto tensor1 =
453+
make_tensor_ptr({3, 3}, {2.f, 3.f, 4.f, 2.f, 3.f, 4.f, 2.f, 3.f, 4.f});
454+
455+
ASSERT_EQ(module.forward(tensor1).error(), Error::Ok);
456+
}
Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
11
## Resources
22

3-
### model.pte
3+
### add.pte, linear.pte, linear.ptd
44
- Internally generated after D62209852, 2024-09-06 with:
55
```
66
buck2 run fbcode//executorch/examples/portable/scripts:export -- --model_name="add"
77
```
8+
9+
and
10+
11+
```
12+
buck2 run fbcode//executorch/examples/portable/scripts:export -- --model_name="linear" -e
13+
```
814
- In OSS, the same files can be generated after [#5145](https://github.com/pytorch/executorch/pull/5145), 2024-09-06 with:
915
```
1016
python -m examples.portable.scripts.export --model_name="add"
1117
```
18+
19+
and
20+
21+
```
22+
python -m examples.portable.scripts.export --model_name="linear" -e
23+
```
336 Bytes
Binary file not shown.
1.18 KB
Binary file not shown.

0 commit comments

Comments
 (0)