Skip to content

Commit ab625d3

Browse files
JacobSzwejbka and facebook-github-bot
authored and committed
add .ptd support to extension/module (pytorch#8421)
Summary: Pull Request resolved: pytorch#8421 Title Reviewed By: lucylq Differential Revision: D69478424
1 parent 00c1443 commit ab625d3

File tree

10 files changed

+191
-38
lines changed

10 files changed

+191
-38
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -258,6 +258,11 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
258258
set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
259259
endif()
260260

261+
if(EXECUTORCH_BUILD_EXTENSION_MODULE)
262+
set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
263+
set(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON)
264+
endif()
265+
261266
if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
262267
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
263268
set(EXECUTORCH_BUILD_KERNELS_CUSTOM ON)

extension/flat_tensor/targets.bzl

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,13 +9,15 @@ def define_common_targets():
99
exported_headers = ["flat_tensor_data_map.h"],
1010
deps = [
1111
"//executorch/extension/flat_tensor/serialize:generated_headers",
12-
"//executorch/extension/flat_tensor/serialize:flat_tensor_header",
1312
"//executorch/runtime/core:core",
1413
"//executorch/runtime/core:evalue",
1514
"//executorch/runtime/core:named_data_map",
1615
"//executorch/runtime/core/exec_aten:lib",
1716
"//executorch/runtime/core/exec_aten/util:tensor_util",
1817
],
18+
exported_deps = [
19+
"//executorch/extension/flat_tensor/serialize:flat_tensor_header",
20+
],
1921
visibility = [
2022
"//executorch/...",
2123
],

extension/module/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ if(CMAKE_TOOLCHAIN_IOS
2727
else()
2828
add_library(extension_module SHARED ${_extension_module__srcs})
2929
endif()
30-
target_link_libraries(extension_module PRIVATE executorch extension_data_loader)
30+
target_link_libraries(extension_module PRIVATE executorch extension_data_loader extension_flat_tensor)
3131
target_include_directories(extension_module PUBLIC ${EXECUTORCH_ROOT}/..)
3232
target_compile_options(
3333
extension_module PUBLIC -Wno-deprecated-declarations -fPIC
@@ -37,7 +37,7 @@ target_compile_options(
3737
# after cleaning up CMake targets.
3838
add_library(extension_module_static STATIC ${_extension_module__srcs})
3939
target_link_libraries(
40-
extension_module_static PRIVATE executorch extension_data_loader
40+
extension_module_static PRIVATE executorch extension_data_loader extension_flat_tensor
4141
)
4242
target_include_directories(extension_module_static PUBLIC ${EXECUTORCH_ROOT}/..)
4343
target_compile_options(

extension/module/module.cpp

Lines changed: 102 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/extension/data_loader/file_data_loader.h>
1212
#include <executorch/extension/data_loader/mmap_data_loader.h>
13+
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
1314
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
1415
#include <executorch/runtime/platform/runtime.h>
1516

@@ -36,73 +37,134 @@
3637
namespace executorch {
3738
namespace extension {
3839

40+
namespace {
41+
runtime::Result<std::unique_ptr<runtime::DataLoader>> load_file(
42+
const std::string& file_path,
43+
Module::LoadMode mode) {
44+
std::unique_ptr<runtime::DataLoader> res = nullptr;
45+
switch (mode) {
46+
case Module::LoadMode::File:
47+
res =
48+
ET_UNWRAP_UNIQUE(FileDataLoader::from(file_path.c_str()));
49+
break;
50+
case Module::LoadMode::Mmap:
51+
res = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
52+
file_path.c_str(), MmapDataLoader::MlockConfig::NoMlock));
53+
break;
54+
case Module::LoadMode::MmapUseMlock:
55+
res =
56+
ET_UNWRAP_UNIQUE(MmapDataLoader::from(file_path.c_str()));
57+
break;
58+
case Module::LoadMode::MmapUseMlockIgnoreErrors:
59+
res = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
60+
file_path.c_str(),
61+
MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
62+
break;
63+
}
64+
return res;
65+
}
66+
}
67+
3968
Module::Module(
4069
const std::string& file_path,
4170
const LoadMode load_mode,
4271
std::unique_ptr<runtime::EventTracer> event_tracer)
4372
: file_path_(file_path),
73+
data_map_path_(""),
4474
load_mode_(load_mode),
4575
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
4676
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
47-
event_tracer_(std::move(event_tracer)) {
77+
event_tracer_(std::move(event_tracer)),
78+
data_map_loader_(nullptr),
79+
data_map_(nullptr) {
4880
runtime::runtime_init();
4981
}
5082

83+
Module::Module(
84+
const std::string& file_path,
85+
const std::string& data_map_path,
86+
const LoadMode load_mode,
87+
std::unique_ptr<runtime::EventTracer> event_tracer)
88+
: file_path_(file_path),
89+
data_map_path_(data_map_path),
90+
load_mode_(load_mode),
91+
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
92+
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
93+
event_tracer_(std::move(event_tracer)),
94+
data_map_loader_(nullptr),
95+
data_map_(nullptr) {
96+
runtime::runtime_init();
97+
}
98+
99+
100+
51101
Module::Module(
52102
std::unique_ptr<runtime::DataLoader> data_loader,
53103
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
54104
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
55-
std::unique_ptr<runtime::EventTracer> event_tracer)
56-
: data_loader_(std::move(data_loader)),
105+
std::unique_ptr<runtime::EventTracer> event_tracer,
106+
std::unique_ptr<runtime::NamedDataMap> data_map)
107+
:
108+
file_path_(""),
109+
data_map_path_(""),
110+
data_loader_(std::move(data_loader)),
57111
memory_allocator_(
58112
memory_allocator ? std::move(memory_allocator)
59113
: std::make_unique<MallocMemoryAllocator>()),
60114
temp_allocator_(
61115
temp_allocator ? std::move(temp_allocator)
62116
: std::make_unique<MallocMemoryAllocator>()),
63-
event_tracer_(std::move(event_tracer)) {
117+
event_tracer_(std::move(event_tracer)),
118+
data_map_loader_(nullptr),
119+
data_map_(std::move(data_map)) {
64120
runtime::runtime_init();
65121
}
66122

67123
Module::Module(
68124
std::shared_ptr<runtime::Program> program,
69125
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
70126
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
71-
std::unique_ptr<runtime::EventTracer> event_tracer)
72-
: program_(std::move(program)),
127+
std::unique_ptr<runtime::EventTracer> event_tracer,
128+
std::unique_ptr<runtime::NamedDataMap> data_map)
129+
:
130+
file_path_(""),
131+
data_map_path_(""),
132+
program_(std::move(program)),
73133
memory_allocator_(
74134
memory_allocator ? std::move(memory_allocator)
75135
: std::make_unique<MallocMemoryAllocator>()),
76136
temp_allocator_(
77137
temp_allocator ? std::move(temp_allocator)
78138
: std::make_unique<MallocMemoryAllocator>()),
79-
event_tracer_(std::move(event_tracer)) {
139+
event_tracer_(std::move(event_tracer)),
140+
data_map_loader_(nullptr),
141+
data_map_(std::move(data_map)) {
80142
runtime::runtime_init();
81143
}
82144

83145
runtime::Error Module::load(const runtime::Program::Verification verification) {
84146
if (!is_loaded()) {
147+
// Load the program
85148
if (!data_loader_) {
86-
switch (load_mode_) {
87-
case LoadMode::File:
88-
data_loader_ =
89-
ET_UNWRAP_UNIQUE(FileDataLoader::from(file_path_.c_str()));
90-
break;
91-
case LoadMode::Mmap:
92-
data_loader_ = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
93-
file_path_.c_str(), MmapDataLoader::MlockConfig::NoMlock));
94-
break;
95-
case LoadMode::MmapUseMlock:
96-
data_loader_ =
97-
ET_UNWRAP_UNIQUE(MmapDataLoader::from(file_path_.c_str()));
98-
break;
99-
case LoadMode::MmapUseMlockIgnoreErrors:
100-
data_loader_ = ET_UNWRAP_UNIQUE(MmapDataLoader::from(
101-
file_path_.c_str(),
102-
MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
103-
break;
149+
auto res = load_file(file_path_, load_mode_);
150+
if (!res.ok()) {
151+
return res.error();
104152
}
105-
};
153+
data_loader_ = std::move(res.get());
154+
}
155+
// If a .ptd path was given, load it.
156+
if (data_map_path_ != ""){
157+
auto res = load_file(data_map_path_, load_mode_);
158+
if (!res.ok()) {
159+
return res.error();
160+
}
161+
data_map_loader_ = std::move(res.get());
162+
}
163+
// If we have a .ptd loader, then load the map.
164+
if (data_map_loader_) {
165+
data_map_ = ET_UNWRAP_UNIQUE(FlatTensorDataMap::load(data_map_loader_.get()));
166+
}
167+
// else: either the map itself was provided or we have no data map; either way, there is no work to do.
106168
auto program = ET_UNWRAP_UNIQUE(
107169
runtime::Program::load(data_loader_.get(), verification));
108170
program_ = std::shared_ptr<runtime::Program>(
@@ -130,6 +192,7 @@ runtime::Error Module::load_method(
130192
ET_CHECK_OK_OR_RETURN_ERROR(load());
131193

132194
MethodHolder method_holder;
195+
133196
const auto method_metadata =
134197
ET_UNWRAP(program_->method_meta(method_name.c_str()));
135198
const auto planned_buffersCount =
@@ -155,10 +218,22 @@ runtime::Error Module::load_method(
155218
method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
156219
method_name.c_str(),
157220
method_holder.memory_manager.get(),
158-
event_tracer ? event_tracer : this->event_tracer()));
221+
event_tracer ? event_tracer : this->event_tracer(),
222+
data_map_.get()));
159223
method_holder.inputs.resize(method_holder.method->inputs_size());
160224
methods_.emplace(method_name, std::move(method_holder));
161225
}
226+
return runtime::Error::Ok;
227+
}
228+
229+
runtime::Error Module::load_method(
230+
const std::string& method_name,
231+
const std::string& data_map_path,
232+
torch::executor::EventTracer* event_tracer) {
233+
if (!is_method_loaded(method_name)) {
234+
ET_CHECK_OK_OR_RETURN_ERROR(load());
235+
return load_method(method_name, event_tracer);
236+
}
162237
return runtime::Error::Ok;
163238
}
164239

extension/module/module.h

Lines changed: 43 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,22 @@ class Module {
5151
const LoadMode load_mode = LoadMode::MmapUseMlock,
5252
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
5353

54+
/**
55+
* Constructs an instance by loading a program from a file, along with an
56+
* external .ptd data file, with specified memory locking behavior.
57+
*
58+
* @param[in] file_path The path to the ExecuTorch program file to load.
59+
* @param[in] data_map_path The path to the .ptd data file to load.
60+
* @param[in] load_mode The loading mode to use.
61+
* @param[in] event_tracer An EventTracer used for tracking and logging events.
62+
*/
63+
explicit Module(
64+
const std::string& file_path,
65+
const std::string& data_map_path,
66+
const LoadMode load_mode = LoadMode::MmapUseMlock,
67+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
68+
69+
5470
/**
5571
* Constructs an instance with the provided data loader and memory allocator.
5672
*
@@ -64,7 +80,8 @@ class Module {
6480
std::unique_ptr<runtime::DataLoader> data_loader,
6581
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
6682
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
67-
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
83+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
84+
std::unique_ptr<runtime::NamedDataMap> data_map = nullptr);
6885

6986
/**
7087
* Constructs an instance using an existing shared program.
@@ -80,7 +97,8 @@ class Module {
8097
std::shared_ptr<runtime::Program> program,
8198
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
8299
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
83-
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
100+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
101+
std::unique_ptr<runtime::NamedDataMap> data_map = nullptr);
84102

85103
Module(const Module&) = delete;
86104
Module& operator=(const Module&) = delete;
@@ -128,6 +146,26 @@ class Module {
128146
*/
129147
runtime::Result<std::unordered_set<std::string>> method_names();
130148

149+
/**
150+
* Load a specific method from the program and set up memory management if
151+
* needed. The loaded method is cached to reuse the next time it's executed.
152+
*
153+
* @param[in] method_name The name of the method to load.
154+
* @param[in] data_map_path Path to a .ptd file containing weights
155+
* for this method.
156+
* @param[in] event_tracer Per-method event tracer to profile/trace methods
157+
* individually. When not given, the event tracer passed to the Module
158+
* constructor is used. Otherwise, this per-method event tracer takes
159+
* precedence.
160+
*
161+
* @returns An Error to indicate success or failure.
162+
*/
163+
ET_NODISCARD
164+
runtime::Error load_method(
165+
const std::string& method_name,
166+
const std::string& data_map_path,
167+
torch::executor::EventTracer* event_tracer = nullptr);
168+
131169
/**
132170
* Load a specific method from the program and set up memory management if
133171
* needed. The loaded method is cached to reuse the next time it's executed.
@@ -433,14 +471,16 @@ class Module {
433471
std::vector<runtime::EValue> inputs;
434472
};
435473

436-
private:
437474
std::string file_path_;
475+
std::string data_map_path_;
438476
LoadMode load_mode_{LoadMode::MmapUseMlock};
439477
std::shared_ptr<runtime::Program> program_;
440478
std::unique_ptr<runtime::DataLoader> data_loader_;
441479
std::unique_ptr<runtime::MemoryAllocator> memory_allocator_;
442480
std::unique_ptr<runtime::MemoryAllocator> temp_allocator_;
443481
std::unique_ptr<runtime::EventTracer> event_tracer_;
482+
std::unique_ptr<runtime::DataLoader> data_map_loader_;
483+
std::unique_ptr<runtime::NamedDataMap> data_map_;
444484

445485
protected:
446486
std::unordered_map<std::string, MethodHolder> methods_;

extension/module/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,5 +28,6 @@ def define_common_targets():
2828
],
2929
exported_deps = [
3030
"//executorch/runtime/executor:program" + aten_suffix,
31+
"//executorch/extension/flat_tensor:flat_tensor_data_map",
3132
],
3233
)

extension/module/test/module_test.cpp

Lines changed: 22 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,14 +22,20 @@ using namespace ::executorch::runtime;
2222
class ModuleTest : public ::testing::Test {
2323
protected:
2424
static void SetUpTestSuite() {
25-
model_path_ = std::getenv("RESOURCES_PATH") + std::string("/add.pte");
25+
std::string resources_path;
26+
if (const char* env = std::getenv("RESOURCES_PATH")) {
27+
resources_path = env;
28+
}
29+
model_path_ = resources_path + "/add.pte";
30+
linear_path_ = resources_path + "/linear.pte";
31+
linear_data_path_ = resources_path + "/linear.ptd";
2632
}
2733

28-
static std::string model_path_;
34+
static inline std::string model_path_;
35+
static inline std::string linear_path_;
36+
static inline std::string linear_data_path_;
2937
};
3038

31-
std::string ModuleTest::model_path_;
32-
3339
TEST_F(ModuleTest, TestLoad) {
3440
Module module(model_path_);
3541

@@ -435,3 +441,15 @@ TEST_F(ModuleTest, TestSetOutputInvalidType) {
435441

436442
EXPECT_NE(module.set_output(EValue()), Error::Ok);
437443
}
444+
445+
TEST_F(ModuleTest, TestPTD) {
446+
Module module(linear_path_, linear_data_path_);
447+
448+
ASSERT_EQ(
449+
module.load_method("forward"), Error::Ok);
450+
451+
auto tensor1 =
452+
make_tensor_ptr({3, 3}, {2.f, 3.f, 4.f, 2.f, 3.f, 4.f, 2.f, 3.f, 4.f});
453+
454+
ASSERT_EQ(module.forward(tensor1).error(), Error::Ok);
455+
}

0 commit comments

Comments
 (0)