@@ -75,9 +75,7 @@ Module::Module(
7575 load_mode_ (load_mode),
7676 memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
7777 temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
78- event_tracer_(std::move(event_tracer)),
79- data_map_loader_(nullptr ),
80- data_map_(nullptr ) {
78+ event_tracer_(std::move(event_tracer)) {
8179 runtime::runtime_init ();
8280}
8381
@@ -87,13 +85,27 @@ Module::Module(
8785 const LoadMode load_mode,
8886 std::unique_ptr<runtime::EventTracer> event_tracer)
8987 : file_path_(file_path),
90- data_map_path_(data_map_path),
9188 load_mode_(load_mode),
9289 memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
9390 temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
94- event_tracer_(std::move(event_tracer)),
95- data_map_loader_(nullptr ),
96- data_map_(nullptr ) {
91+ event_tracer_(std::move(event_tracer)) {
92+ if (!data_map_path.empty ()) {
93+ data_files_.push_back (data_map_path);
94+ }
95+ runtime::runtime_init ();
96+ }
97+
98+ Module::Module (
99+ const std::string& file_path,
100+ std::vector<std::string> data_files,
101+ const LoadMode load_mode,
102+ std::unique_ptr<runtime::EventTracer> event_tracer)
103+ : file_path_(file_path),
104+ data_files_(std::move(data_files)),
105+ load_mode_(load_mode),
106+ memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
107+ temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
108+ event_tracer_(std::move(event_tracer)) {
97109 runtime::runtime_init ();
98110}
99111
@@ -110,9 +122,10 @@ Module::Module(
110122 temp_allocator_(
111123 temp_allocator ? std::move(temp_allocator)
112124 : std::make_unique<MallocMemoryAllocator>()),
113- event_tracer_(std::move(event_tracer)),
114- data_map_loader_(std::move(data_map_loader)),
115- data_map_(nullptr ) {
125+ event_tracer_(std::move(event_tracer)) {
126+ if (data_map_loader) {
127+ data_map_loaders_.push_back (std::move (data_map_loader));
128+ }
116129 runtime::runtime_init ();
117130}
118131
@@ -129,9 +142,10 @@ Module::Module(
129142 temp_allocator_(
130143 temp_allocator ? std::move(temp_allocator)
131144 : std::make_unique<MallocMemoryAllocator>()),
132- event_tracer_(std::move(event_tracer)),
133- data_map_loader_(std::move(data_map_loader)),
134- data_map_(nullptr ) {
145+ event_tracer_(std::move(event_tracer)) {
146+ if (data_map_loader) {
147+ data_map_loaders_.push_back (std::move (data_map_loader));
148+ }
135149 runtime::runtime_init ();
136150}
137151
@@ -140,14 +154,27 @@ runtime::Error Module::load(const Program::Verification verification) {
140154 if (!data_loader_) {
141155 data_loader_ = ET_UNWRAP (make_data_loader (file_path_, load_mode_));
142156 }
143- if (!data_map_path_.empty ()) {
144- data_map_loader_ =
145- ET_UNWRAP (make_data_loader (data_map_path_, load_mode_));
157+ if (data_files_.size () > 0 ) {
158+ ET_CHECK_OR_RETURN_ERROR (
159+ data_files_.size () == 1 ,
160+ NotImplemented,
161+ " Multiple named data map paths are not supported yet." );
162+ for (const auto & data_file : data_files_) {
163+ data_map_loaders_.push_back (
164+ ET_UNWRAP (make_data_loader (data_file, load_mode_)));
165+ }
146166 }
147- if (data_map_loader_) {
148- data_map_ =
149- ET_UNWRAP_UNIQUE (FlatTensorDataMap::load (data_map_loader_.get ()));
167+
168+ if (data_map_loaders_.size () > 0 ) {
169+ ET_CHECK_OR_RETURN_ERROR (
170+ data_map_loaders_.size () == 1 && merged_data_map_ == nullptr ,
171+ NotImplemented,
172+ " Multiple named data map loaders are not supported yet." );
173+ // TODO(lfq): support multiple named data map loaders.
174+ merged_data_map_ =
175+ ET_UNWRAP_UNIQUE (FlatTensorDataMap::load (data_map_loaders_[0 ].get ()));
150176 }
177+
151178 auto program =
152179 ET_UNWRAP_UNIQUE (Program::load (data_loader_.get (), verification));
153180 program_ = std::shared_ptr<Program>(
@@ -209,7 +236,7 @@ runtime::Error Module::load_method(
209236 method_name.c_str (),
210237 method_holder.memory_manager .get (),
211238 event_tracer ? event_tracer : this ->event_tracer (),
212- data_map_ .get ()));
239+ merged_data_map_ .get ()));
213240 methods_.emplace (method_name, std::move (method_holder));
214241 }
215242 return runtime::Error::Ok;
0 commit comments