Skip to content

Commit 4d944dc

Browse files
committed
Module support for multiple ptd files
Pull Request resolved: #14158 Support multiple PTD files in Module. Context: https://docs.google.com/document/d/19RLLdWNHQoRi8Ufz4oE-gGjOz0IShjN_NZi5jlgMBZI/edit?tab=t.0 This change updates the following private variables in Module: ``` std::string data_path --> std::unordered_set<std::string> data_files_ std::unique_ptr<DataLoader> data_map_loader --> std::vectror<std::unique_ptr<DataLoader>> data_map_loaders_ std::unique_ptr<NamedDataMap> data_map --> std::vector<std::unique_ptr<NamedDataMap> named_data_maps_ ``` And introduces a new private variable. When we have multiple NamedDataMaps, they need to be merged into one, for use in method, etc. This is not implemented yet. ``` std::unique_ptr<NamedDataMap> merged_data_map_ ``` The process of using a PTD file is: ``` std::string file --> wrapped in DataLoader --> wrapped in NamedDataMap. ``` At each stage we can have multiple. This diff also introduces a new Module constructor that takes in `std::unordered_set<std::string> named_data_map_paths_` TODO: add a MergedDataMap to extension/module that can merge all the data maps together. ghstack-source-id: 313188117 @exported-using-ghexport Differential Revision: [D82059808](https://our.internmc.facebook.com/intern/diff/D82059808/)
1 parent f7c009e commit 4d944dc

File tree

3 files changed

+82
-24
lines changed

3 files changed

+82
-24
lines changed

extension/module/module.cpp

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,7 @@ Module::Module(
7575
load_mode_(load_mode),
7676
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
7777
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
78-
event_tracer_(std::move(event_tracer)),
79-
data_map_loader_(nullptr),
80-
data_map_(nullptr) {
78+
event_tracer_(std::move(event_tracer)) {
8179
runtime::runtime_init();
8280
}
8381

@@ -87,13 +85,27 @@ Module::Module(
8785
const LoadMode load_mode,
8886
std::unique_ptr<runtime::EventTracer> event_tracer)
8987
: file_path_(file_path),
90-
data_map_path_(data_map_path),
9188
load_mode_(load_mode),
9289
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
9390
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
94-
event_tracer_(std::move(event_tracer)),
95-
data_map_loader_(nullptr),
96-
data_map_(nullptr) {
91+
event_tracer_(std::move(event_tracer)) {
92+
if (!data_map_path.empty()) {
93+
data_files_.push_back(data_map_path);
94+
}
95+
runtime::runtime_init();
96+
}
97+
98+
Module::Module(
99+
const std::string& file_path,
100+
std::vector<std::string> data_files,
101+
const LoadMode load_mode,
102+
std::unique_ptr<runtime::EventTracer> event_tracer)
103+
: file_path_(file_path),
104+
data_files_(std::move(data_files)),
105+
load_mode_(load_mode),
106+
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
107+
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
108+
event_tracer_(std::move(event_tracer)) {
97109
runtime::runtime_init();
98110
}
99111

@@ -110,9 +122,10 @@ Module::Module(
110122
temp_allocator_(
111123
temp_allocator ? std::move(temp_allocator)
112124
: std::make_unique<MallocMemoryAllocator>()),
113-
event_tracer_(std::move(event_tracer)),
114-
data_map_loader_(std::move(data_map_loader)),
115-
data_map_(nullptr) {
125+
event_tracer_(std::move(event_tracer)) {
126+
if (data_map_loader) {
127+
data_map_loaders_.push_back(std::move(data_map_loader));
128+
}
116129
runtime::runtime_init();
117130
}
118131

@@ -129,9 +142,10 @@ Module::Module(
129142
temp_allocator_(
130143
temp_allocator ? std::move(temp_allocator)
131144
: std::make_unique<MallocMemoryAllocator>()),
132-
event_tracer_(std::move(event_tracer)),
133-
data_map_loader_(std::move(data_map_loader)),
134-
data_map_(nullptr) {
145+
event_tracer_(std::move(event_tracer)) {
146+
if (data_map_loader) {
147+
data_map_loaders_.push_back(std::move(data_map_loader));
148+
}
135149
runtime::runtime_init();
136150
}
137151

@@ -140,14 +154,27 @@ runtime::Error Module::load(const Program::Verification verification) {
140154
if (!data_loader_) {
141155
data_loader_ = ET_UNWRAP(make_data_loader(file_path_, load_mode_));
142156
}
143-
if (!data_map_path_.empty()) {
144-
data_map_loader_ =
145-
ET_UNWRAP(make_data_loader(data_map_path_, load_mode_));
157+
if (data_files_.size() > 0) {
158+
ET_CHECK_OR_RETURN_ERROR(
159+
data_files_.size() == 1,
160+
NotImplemented,
161+
"Multiple named data map paths are not supported yet.");
162+
for (const auto& data_file : data_files_) {
163+
data_map_loaders_.push_back(
164+
ET_UNWRAP(make_data_loader(data_file, load_mode_)));
165+
}
146166
}
147-
if (data_map_loader_) {
148-
data_map_ =
149-
ET_UNWRAP_UNIQUE(FlatTensorDataMap::load(data_map_loader_.get()));
167+
168+
if (data_map_loaders_.size() > 0) {
169+
ET_CHECK_OR_RETURN_ERROR(
170+
data_map_loaders_.size() == 1 && merged_data_map_ == nullptr,
171+
NotImplemented,
172+
"Multiple named data map loaders are not supported yet.");
173+
// TODO(lfq): support multiple named data map loaders.
174+
merged_data_map_ =
175+
ET_UNWRAP_UNIQUE(FlatTensorDataMap::load(data_map_loaders_[0].get()));
150176
}
177+
151178
auto program =
152179
ET_UNWRAP_UNIQUE(Program::load(data_loader_.get(), verification));
153180
program_ = std::shared_ptr<Program>(
@@ -209,7 +236,7 @@ runtime::Error Module::load_method(
209236
method_name.c_str(),
210237
method_holder.memory_manager.get(),
211238
event_tracer ? event_tracer : this->event_tracer(),
212-
data_map_.get()));
239+
merged_data_map_.get()));
213240
methods_.emplace(method_name, std::move(method_holder));
214241
}
215242
return runtime::Error::Ok;

extension/module/module.h

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class Module {
7070
* memory locking behavior.
7171
*
7272
* @param[in] file_path The path to the ExecuTorch program file to load.
73-
* @param[in] data_map_path The path to a .ptd file
73+
* @param[in] data_map_path The path to a .ptd file.
7474
* @param[in] load_mode The loading mode to use.
7575
* @param[in] event_tracer A EventTracer used for tracking and logging events.
7676
*/
@@ -80,6 +80,21 @@ class Module {
8080
const LoadMode load_mode = LoadMode::File,
8181
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
8282

83+
/**
84+
* Constructs an instance by loading a program from a file with specified
85+
* memory locking behavior.
86+
*
87+
* @param[in] file_path The path to the ExecuTorch program file to load.
88+
* @param[in] data_files The path to one or more .ptd file/s.
89+
* @param[in] load_mode The loading mode to use.
90+
* @param[in] event_tracer A EventTracer used for tracking and logging events.
91+
*/
92+
explicit Module(
93+
const std::string& file_path,
94+
std::vector<std::string> data_files,
95+
const LoadMode load_mode = LoadMode::File,
96+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
97+
8398
/**
8499
* Constructs an instance with the provided data loader and memory allocator.
85100
*
@@ -614,15 +629,16 @@ class Module {
614629
};
615630

616631
std::string file_path_;
617-
std::string data_map_path_;
632+
std::vector<std::string> data_files_;
618633
LoadMode load_mode_{LoadMode::File};
619634
std::shared_ptr<Program> program_;
620635
std::unique_ptr<runtime::DataLoader> data_loader_;
621636
std::unique_ptr<runtime::MemoryAllocator> memory_allocator_;
622637
std::unique_ptr<runtime::MemoryAllocator> temp_allocator_;
623638
std::unique_ptr<runtime::EventTracer> event_tracer_;
624-
std::unique_ptr<runtime::DataLoader> data_map_loader_;
625-
std::unique_ptr<NamedDataMap> data_map_;
639+
std::vector<std::unique_ptr<runtime::DataLoader>> data_map_loaders_;
640+
std::vector<std::unique_ptr<NamedDataMap>> named_data_maps_;
641+
std::unique_ptr<NamedDataMap> merged_data_map_;
626642
ET_DEPRECATED std::vector<uint8_t> debug_buffer_;
627643

628644
protected:

extension/module/test/module_test.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,3 +530,18 @@ TEST_F(ModuleTest, TestPTD) {
530530
auto tensor = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 2.f});
531531
ASSERT_EQ(module.forward(tensor).error(), Error::Ok);
532532
}
533+
534+
TEST_F(ModuleTest, TestPTD_Multiple) {
535+
std::vector<std::string> data_files = {add_mul_data_path_};
536+
Module module(add_mul_path_, data_files);
537+
538+
ASSERT_EQ(module.load_method("forward"), Error::Ok);
539+
540+
auto tensor = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 2.f});
541+
ASSERT_EQ(module.forward(tensor).error(), Error::Ok);
542+
543+
// Confirm that the data_file is not std::move'd away.
544+
ASSERT_EQ(std::strcmp(data_files[0].c_str(), add_mul_data_path_.c_str()), 0);
545+
546+
// TODO(lfq): add test when merge capability is supported.
547+
}

0 commit comments

Comments
 (0)