@@ -75,9 +75,7 @@ Module::Module(
75
75
load_mode_ (load_mode),
76
76
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
77
77
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
78
- event_tracer_(std::move(event_tracer)),
79
- data_map_loader_(nullptr ),
80
- data_map_(nullptr ) {
78
+ event_tracer_(std::move(event_tracer)) {
81
79
runtime::runtime_init ();
82
80
}
83
81
@@ -87,13 +85,27 @@ Module::Module(
87
85
const LoadMode load_mode,
88
86
std::unique_ptr<runtime::EventTracer> event_tracer)
89
87
: file_path_(file_path),
90
- data_map_path_(data_map_path),
91
88
load_mode_(load_mode),
92
89
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
93
90
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
94
- event_tracer_(std::move(event_tracer)),
95
- data_map_loader_(nullptr ),
96
- data_map_(nullptr ) {
91
+ event_tracer_(std::move(event_tracer)) {
92
+ if (!data_map_path.empty ()) {
93
+ data_files_.push_back (data_map_path);
94
+ }
95
+ runtime::runtime_init ();
96
+ }
97
+
98
+ Module::Module (
99
+ const std::string& file_path,
100
+ std::vector<std::string> data_files,
101
+ const LoadMode load_mode,
102
+ std::unique_ptr<runtime::EventTracer> event_tracer)
103
+ : file_path_(file_path),
104
+ data_files_(std::move(data_files)),
105
+ load_mode_(load_mode),
106
+ memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
107
+ temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
108
+ event_tracer_(std::move(event_tracer)) {
97
109
runtime::runtime_init ();
98
110
}
99
111
@@ -110,9 +122,10 @@ Module::Module(
110
122
temp_allocator_(
111
123
temp_allocator ? std::move(temp_allocator)
112
124
: std::make_unique<MallocMemoryAllocator>()),
113
- event_tracer_(std::move(event_tracer)),
114
- data_map_loader_(std::move(data_map_loader)),
115
- data_map_(nullptr ) {
125
+ event_tracer_(std::move(event_tracer)) {
126
+ if (data_map_loader) {
127
+ data_map_loaders_.push_back (std::move (data_map_loader));
128
+ }
116
129
runtime::runtime_init ();
117
130
}
118
131
@@ -129,9 +142,10 @@ Module::Module(
129
142
temp_allocator_(
130
143
temp_allocator ? std::move(temp_allocator)
131
144
: std::make_unique<MallocMemoryAllocator>()),
132
- event_tracer_(std::move(event_tracer)),
133
- data_map_loader_(std::move(data_map_loader)),
134
- data_map_(nullptr ) {
145
+ event_tracer_(std::move(event_tracer)) {
146
+ if (data_map_loader) {
147
+ data_map_loaders_.push_back (std::move (data_map_loader));
148
+ }
135
149
runtime::runtime_init ();
136
150
}
137
151
@@ -140,14 +154,27 @@ runtime::Error Module::load(const Program::Verification verification) {
140
154
if (!data_loader_) {
141
155
data_loader_ = ET_UNWRAP (make_data_loader (file_path_, load_mode_));
142
156
}
143
- if (!data_map_path_.empty ()) {
144
- data_map_loader_ =
145
- ET_UNWRAP (make_data_loader (data_map_path_, load_mode_));
157
+ if (data_files_.size () > 0 ) {
158
+ ET_CHECK_OR_RETURN_ERROR (
159
+ data_files_.size () == 1 ,
160
+ NotImplemented,
161
+ " Multiple named data map paths are not supported yet." );
162
+ for (const auto & data_file : data_files_) {
163
+ data_map_loaders_.push_back (
164
+ ET_UNWRAP (make_data_loader (data_file, load_mode_)));
165
+ }
146
166
}
147
- if (data_map_loader_) {
148
- data_map_ =
149
- ET_UNWRAP_UNIQUE (FlatTensorDataMap::load (data_map_loader_.get ()));
167
+
168
+ if (data_map_loaders_.size () > 0 ) {
169
+ ET_CHECK_OR_RETURN_ERROR (
170
+ data_map_loaders_.size () == 1 && merged_data_map_ == nullptr ,
171
+ NotImplemented,
172
+ " Multiple named data map loaders are not supported yet." );
173
+ // TODO(lfq): support multiple named data map loaders.
174
+ merged_data_map_ =
175
+ ET_UNWRAP_UNIQUE (FlatTensorDataMap::load (data_map_loaders_[0 ].get ()));
150
176
}
177
+
151
178
auto program =
152
179
ET_UNWRAP_UNIQUE (Program::load (data_loader_.get (), verification));
153
180
program_ = std::shared_ptr<Program>(
@@ -209,7 +236,7 @@ runtime::Error Module::load_method(
209
236
method_name.c_str (),
210
237
method_holder.memory_manager .get (),
211
238
event_tracer ? event_tracer : this ->event_tracer (),
212
- data_map_ .get ()));
239
+ merged_data_map_ .get ()));
213
240
methods_.emplace (method_name, std::move (method_holder));
214
241
}
215
242
return runtime::Error::Ok;
0 commit comments