@@ -75,9 +75,7 @@ Module::Module(
7575      load_mode_ (load_mode),
7676      memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
7777      temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
78-       event_tracer_(std::move(event_tracer)),
79-       data_map_loader_(nullptr ),
80-       data_map_(nullptr ) {
78+       event_tracer_(std::move(event_tracer)) {
8179  runtime::runtime_init ();
8280}
8381
@@ -87,13 +85,27 @@ Module::Module(
8785    const  LoadMode load_mode,
8886    std::unique_ptr<runtime::EventTracer> event_tracer)
8987    : file_path_(file_path),
90-       data_map_path_(data_map_path),
9188      load_mode_(load_mode),
9289      memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
9390      temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
94-       event_tracer_(std::move(event_tracer)),
95-       data_map_loader_(nullptr ),
96-       data_map_(nullptr ) {
91+       event_tracer_(std::move(event_tracer)) {
92+   if  (!data_map_path.empty ()) {
93+     data_files_.push_back (data_map_path);
94+   }
95+   runtime::runtime_init ();
96+ }
97+ 
98+ Module::Module (
99+     const  std::string& file_path,
100+     std::vector<std::string> data_files,
101+     const  LoadMode load_mode,
102+     std::unique_ptr<runtime::EventTracer> event_tracer)
103+     : file_path_(file_path),
104+       data_files_(std::move(data_files)),
105+       load_mode_(load_mode),
106+       memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
107+       temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
108+       event_tracer_(std::move(event_tracer)) {
97109  runtime::runtime_init ();
98110}
99111
@@ -110,9 +122,10 @@ Module::Module(
110122      temp_allocator_(
111123          temp_allocator ? std::move(temp_allocator)
112124                         : std::make_unique<MallocMemoryAllocator>()),
113-       event_tracer_(std::move(event_tracer)),
114-       data_map_loader_(std::move(data_map_loader)),
115-       data_map_(nullptr ) {
125+       event_tracer_(std::move(event_tracer)) {
126+   if  (data_map_loader) {
127+     data_map_loaders_.push_back (std::move (data_map_loader));
128+   }
116129  runtime::runtime_init ();
117130}
118131
@@ -129,9 +142,10 @@ Module::Module(
129142      temp_allocator_(
130143          temp_allocator ? std::move(temp_allocator)
131144                         : std::make_unique<MallocMemoryAllocator>()),
132-       event_tracer_(std::move(event_tracer)),
133-       data_map_loader_(std::move(data_map_loader)),
134-       data_map_(nullptr ) {
145+       event_tracer_(std::move(event_tracer)) {
146+   if  (data_map_loader) {
147+     data_map_loaders_.push_back (std::move (data_map_loader));
148+   }
135149  runtime::runtime_init ();
136150}
137151
@@ -140,14 +154,27 @@ runtime::Error Module::load(const Program::Verification verification) {
140154    if  (!data_loader_) {
141155      data_loader_ = ET_UNWRAP (make_data_loader (file_path_, load_mode_));
142156    }
143-     if  (!data_map_path_.empty ()) {
144-       data_map_loader_ =
145-           ET_UNWRAP (make_data_loader (data_map_path_, load_mode_));
157+     if  (data_files_.size () > 0 ) {
158+       ET_CHECK_OR_RETURN_ERROR (
159+           data_files_.size () == 1 ,
160+           NotImplemented,
161+           " Multiple named data map paths are not supported yet."  );
162+       for  (const  auto & data_file : data_files_) {
163+         data_map_loaders_.push_back (
164+             ET_UNWRAP (make_data_loader (data_file, load_mode_)));
165+       }
146166    }
147-     if  (data_map_loader_) {
148-       data_map_ =
149-           ET_UNWRAP_UNIQUE (FlatTensorDataMap::load (data_map_loader_.get ()));
167+ 
168+     if  (data_map_loaders_.size () > 0 ) {
169+       ET_CHECK_OR_RETURN_ERROR (
170+           data_map_loaders_.size () == 1  && merged_data_map_ == nullptr ,
171+           NotImplemented,
172+           " Multiple named data map loaders are not supported yet."  );
173+       //  TODO(lfq): support multiple named data map loaders.
174+       merged_data_map_ =
175+           ET_UNWRAP_UNIQUE (FlatTensorDataMap::load (data_map_loaders_[0 ].get ()));
150176    }
177+ 
151178    auto  program =
152179        ET_UNWRAP_UNIQUE (Program::load (data_loader_.get (), verification));
153180    program_ = std::shared_ptr<Program>(
@@ -209,7 +236,7 @@ runtime::Error Module::load_method(
209236        method_name.c_str (),
210237        method_holder.memory_manager .get (),
211238        event_tracer ? event_tracer : this ->event_tracer (),
212-         data_map_ .get ()));
239+         merged_data_map_ .get ()));
213240    methods_.emplace (method_name, std::move (method_holder));
214241  }
215242  return  runtime::Error::Ok;
0 commit comments