Skip to content

Commit 83ceb0f

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
add .ptd support to extension/module (pytorch#8421)
Summary: Pull Request resolved: pytorch#8421 Title Reviewed By: lucylq Differential Revision: D69478424
1 parent 89dc36c commit 83ceb0f

File tree

10 files changed

+69
-10
lines changed

10 files changed

+69
-10
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,11 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
254254
set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
255255
endif()
256256

257+
if(EXECUTORCH_BUILD_EXTENSION_MODULE)
258+
set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
259+
set(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON)
260+
endif()
261+
257262
if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
258263
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
259264
set(EXECUTORCH_BUILD_KERNELS_CUSTOM ON)

extension/flat_tensor/targets.bzl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ def define_common_targets():
99
exported_headers = ["flat_tensor_data_map.h"],
1010
deps = [
1111
"//executorch/extension/flat_tensor/serialize:generated_headers",
12-
"//executorch/extension/flat_tensor/serialize:flat_tensor_header",
1312
"//executorch/runtime/core:core",
1413
"//executorch/runtime/core:evalue",
1514
"//executorch/runtime/core:named_data_map",
1615
"//executorch/runtime/core/exec_aten:lib",
1716
"//executorch/runtime/core/exec_aten/util:tensor_util",
1817
],
18+
exported_deps = [
19+
"//executorch/extension/flat_tensor/serialize:flat_tensor_header",
20+
],
1921
visibility = [
2022
"//executorch/...",
2123
],

extension/module/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ if(CMAKE_TOOLCHAIN_IOS
2727
else()
2828
add_library(extension_module SHARED ${_extension_module__srcs})
2929
endif()
30-
target_link_libraries(extension_module PRIVATE executorch extension_data_loader)
30+
target_link_libraries(extension_module PRIVATE executorch extension_data_loader extension_flat_tensor)
3131
target_include_directories(extension_module PUBLIC ${EXECUTORCH_ROOT}/..)
3232
target_compile_options(
3333
extension_module PUBLIC -Wno-deprecated-declarations -fPIC
@@ -37,7 +37,7 @@ target_compile_options(
3737
# after cleaning up CMake targets.
3838
add_library(extension_module_static STATIC ${_extension_module__srcs})
3939
target_link_libraries(
40-
extension_module_static PRIVATE executorch extension_data_loader
40+
extension_module_static PRIVATE executorch extension_data_loader extension_flat_tensor
4141
)
4242
target_include_directories(extension_module_static PUBLIC ${EXECUTORCH_ROOT}/..)
4343
target_compile_options(

extension/module/module.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,26 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
125125

126126
runtime::Error Module::load_method(
127127
const std::string& method_name,
128-
torch::executor::EventTracer* event_tracer) {
128+
torch::executor::EventTracer* event_tracer,
129+
const std::string& data_map_path) {
129130
if (!is_method_loaded(method_name)) {
130131
ET_CHECK_OK_OR_RETURN_ERROR(load());
131132

132133
MethodHolder method_holder;
134+
135+
// If we have a .ptd load it.
136+
const runtime::NamedDataMap* named_data_map = nullptr;
137+
if (!data_map_path.empty()) {
138+
auto data_map_data_loader =
139+
ET_UNWRAP_UNIQUE(FileDataLoader::from(data_map_path.c_str()));
140+
auto data_map =
141+
ET_UNWRAP_UNIQUE(executorch::extension::FlatTensorDataMap::load(
142+
data_map_data_loader.get()));
143+
method_holder.data_map_loader = std::move(data_map_data_loader);
144+
method_holder.data_map = std::move(data_map);
145+
}
146+
named_data_map = method_holder.data_map.get();
147+
133148
const auto method_metadata =
134149
ET_UNWRAP(program_->method_meta(method_name.c_str()));
135150
const auto planned_buffersCount =
@@ -155,7 +170,8 @@ runtime::Error Module::load_method(
155170
method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
156171
method_name.c_str(),
157172
method_holder.memory_manager.get(),
158-
event_tracer ? event_tracer : this->event_tracer()));
173+
event_tracer ? event_tracer : this->event_tracer(),
174+
named_data_map));
159175
method_holder.inputs.resize(method_holder.method->inputs_size());
160176
methods_.emplace(method_name, std::move(method_holder));
161177
}

extension/module/module.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <unordered_set>
1515
#include <vector>
1616

17+
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
1718
#include <executorch/runtime/executor/program.h>
1819

1920
namespace executorch {
@@ -143,7 +144,8 @@ class Module {
143144
ET_NODISCARD
144145
runtime::Error load_method(
145146
const std::string& method_name,
146-
torch::executor::EventTracer* event_tracer = nullptr);
147+
torch::executor::EventTracer* event_tracer = nullptr,
148+
const std::string& data_map_path = "");
147149

148150
/**
149151
* Load the 'forward' method from the program and set up memory management if
@@ -155,8 +157,9 @@ class Module {
155157
* @returns An Error to indicate success or failure.
156158
*/
157159
ET_NODISCARD inline runtime::Error load_forward(
158-
torch::executor::EventTracer* event_tracer = nullptr) {
159-
return load_method("forward", event_tracer);
160+
torch::executor::EventTracer* event_tracer = nullptr,
161+
const std::string& data_map_path = "") {
162+
return load_method("forward", event_tracer, data_map_path);
160163
}
161164

162165
/**
@@ -430,10 +433,11 @@ class Module {
430433
std::unique_ptr<runtime::HierarchicalAllocator> planned_memory;
431434
std::unique_ptr<runtime::MemoryManager> memory_manager;
432435
std::unique_ptr<runtime::Method> method;
436+
std::unique_ptr<runtime::DataLoader> data_map_loader;
437+
std::unique_ptr<extension::FlatTensorDataMap> data_map;
433438
std::vector<runtime::EValue> inputs;
434439
};
435440

436-
private:
437441
std::string file_path_;
438442
LoadMode load_mode_{LoadMode::MmapUseMlock};
439443
std::shared_ptr<runtime::Program> program_;

extension/module/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,6 @@ def define_common_targets():
2828
],
2929
exported_deps = [
3030
"//executorch/runtime/executor:program" + aten_suffix,
31+
"//executorch/extension/flat_tensor:flat_tensor_data_map",
3132
],
3233
)

extension/module/test/module_test.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,19 @@ class ModuleTest : public ::testing::Test {
2323
protected:
2424
static void SetUpTestSuite() {
2525
model_path_ = std::getenv("RESOURCES_PATH") + std::string("/add.pte");
26+
linear_path_ = std::getenv("RESOURCES_PATH") + std::string("/linear.pte");
27+
linear_data_path_ =
28+
std::getenv("RESOURCES_PATH") + std::string("/linear.ptd");
2629
}
2730

2831
static std::string model_path_;
32+
static std::string linear_path_;
33+
static std::string linear_data_path_;
2934
};
3035

3136
std::string ModuleTest::model_path_;
37+
std::string ModuleTest::linear_path_;
38+
std::string ModuleTest::linear_data_path_;
3239

3340
TEST_F(ModuleTest, TestLoad) {
3441
Module module(model_path_);
@@ -435,3 +442,15 @@ TEST_F(ModuleTest, TestSetOutputInvalidType) {
435442

436443
EXPECT_NE(module.set_output(EValue()), Error::Ok);
437444
}
445+
446+
TEST_F(ModuleTest, TestPTD) {
447+
Module module(linear_path_);
448+
449+
ASSERT_EQ(
450+
module.load_method("forward", nullptr, linear_data_path_), Error::Ok);
451+
452+
auto tensor1 =
453+
make_tensor_ptr({3, 3}, {2.f, 3.f, 4.f, 2.f, 3.f, 4.f, 2.f, 3.f, 4.f});
454+
455+
ASSERT_EQ(module.forward(tensor1).error(), Error::Ok);
456+
}
Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
11
## Resources
22

3-
### model.pte
3+
### add.pte, linear.pte, linear.ptd
44
- Internally generated after D62209852, 2024-09-06 with:
55
```
66
buck2 run fbcode//executorch/examples/portable/scripts:export -- --model_name="add"
77
```
8+
9+
and
10+
11+
```
12+
buck2 run fbcode//executorch/examples/portable/scripts:export -- --model_name="linear" -examples
13+
```
814
- In OSS, the same file can be generated after [#5145](https://github.com/pytorch/executorch/pull/5145), 2024-09-06 with:
915
```
1016
python -m examples.portable.scripts.export --model_name="add"
1117
```
18+
19+
and
20+
21+
```
22+
python -m examples.portable.scripts.export --model_name="linear" -e
23+
```
336 Bytes
Binary file not shown.
1.18 KB
Binary file not shown.

0 commit comments

Comments
 (0)