
Commit 59bb3df

Introduce MergedDataMap
Differential Revision: [D76529405](https://our.internmc.facebook.com/intern/diff/D76529405/) [ghstack-poisoned]
1 parent 17e3693 commit 59bb3df

File tree

5 files changed: +381 -2 lines changed

extension/flat_tensor/test/flat_tensor_data_map_test.cpp

Lines changed: 2 additions & 2 deletions
@@ -33,8 +33,8 @@ class FlatTensorDataMapTest : public ::testing::Test {
     // first.
     executorch::runtime::runtime_init();
 
-    // Load data map. The eager linear model is defined at:
-    // //executorch/test/models/linear_model.py
+    // Load data map. The eager addmul model is defined at:
+    // //executorch/test/models/export_program.py
     const char* path = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH");
     Result<FileDataLoader> loader = FileDataLoader::from(path);
     ASSERT_EQ(loader.error(), Error::Ok);

runtime/executor/merged_data_map.h

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <array>

#include <executorch/runtime/core/named_data_map.h>

namespace executorch {
namespace runtime {

/**
 * A NamedDataMap implementation that wraps other NamedDataMaps.
 */
template <size_t N>
class MergedDataMap final
    : public executorch::ET_RUNTIME_NAMESPACE::NamedDataMap {
 public:
  /**
   * Creates a new NamedDataMap that takes in other data maps.
   *
   * @param[in] data_maps Array of NamedDataMap pointers to merge.
   * Note: the data maps must outlive the MergedDataMap instance.
   */
  static executorch::runtime::Result<MergedDataMap> load(
      const std::array<const NamedDataMap*, N>& data_maps) {
    // Value-initialize so unused slots are nullptr rather than indeterminate.
    std::array<const NamedDataMap*, N> valid_data_maps{};
    size_t num_data_maps = 0;
    for (size_t i = 0; i < data_maps.size(); i++) {
      if (data_maps[i] != nullptr) {
        valid_data_maps[num_data_maps++] = data_maps[i];
      }
    }
    ET_CHECK_OR_RETURN_ERROR(
        num_data_maps > 0, InvalidArgument, "All provided data maps are null");

    // Check for duplicate keys.
    for (size_t i = 0; i < num_data_maps; i++) {
      for (size_t j = i + 1; j < num_data_maps; j++) {
        for (uint32_t k = 0; k < valid_data_maps[i]->get_num_keys().get();
             k++) {
          const auto key = valid_data_maps[i]->get_key(k).get();
          ET_CHECK_OR_RETURN_ERROR(
              valid_data_maps[j]->get_tensor_layout(key).error() ==
                  executorch::runtime::Error::NotFound,
              InvalidArgument,
              "Duplicate key %s in data maps at index %zu and %zu",
              key,
              i,
              j);
        }
      }
    }
    return MergedDataMap<N>(std::move(valid_data_maps), num_data_maps);
  }

  /**
   * Retrieve the tensor_layout for the specified key.
   *
   * @param[in] key The name of the tensor to get metadata on.
   *
   * @return Error::NotFound if the key is not present.
   */
  ET_NODISCARD
  executorch::runtime::Result<
      const executorch::ET_RUNTIME_NAMESPACE::TensorLayout>
  get_tensor_layout(executorch::aten::string_view key) const override {
    for (size_t i = 0; i < num_data_maps_; i++) {
      auto layout = data_maps_[i]->get_tensor_layout(key);
      if (layout.ok()) {
        return layout.get();
      }
      if (layout.error() != executorch::runtime::Error::NotFound) {
        return layout.error();
      }
    }
    return executorch::runtime::Error::NotFound;
  }

  /**
   * Retrieve read-only data for the specified key.
   *
   * @param[in] key The name of the tensor to get data on.
   *
   * @return error if the key is not present or data cannot be loaded.
   */
  ET_NODISCARD
  executorch::runtime::Result<executorch::runtime::FreeableBuffer> get_data(
      executorch::aten::string_view key) const override {
    for (size_t i = 0; i < num_data_maps_; i++) {
      auto data = data_maps_[i]->get_data(key);
      if (data.error() != executorch::runtime::Error::NotFound) {
        return data;
      }
    }
    return executorch::runtime::Error::NotFound;
  }

  /**
   * Loads the data of the specified tensor into the provided buffer.
   *
   * @param[in] key The name of the tensor to get the data of.
   * @param[in] buffer The buffer to load data into. Must point to at least
   * `size` bytes of memory.
   * @param[in] size The number of bytes to load.
   *
   * @returns an Error indicating if the load was successful.
   */
  ET_NODISCARD executorch::runtime::Error load_data_into(
      executorch::aten::string_view key,
      void* buffer,
      size_t size) const override {
    for (size_t i = 0; i < num_data_maps_; i++) {
      auto error = data_maps_[i]->load_data_into(key, buffer, size);
      if (error != executorch::runtime::Error::NotFound) {
        return error;
      }
    }
    return executorch::runtime::Error::NotFound;
  }

  /**
   * @returns The number of keys in the map.
   */
  ET_NODISCARD executorch::runtime::Result<uint32_t> get_num_keys()
      const override {
    uint32_t num_keys = 0;
    for (size_t i = 0; i < num_data_maps_; i++) {
      num_keys += data_maps_[i]->get_num_keys().get();
    }
    return num_keys;
  }

  /**
   * @returns The key at the specified index, error if index out of bounds.
   */
  ET_NODISCARD executorch::runtime::Result<const char*> get_key(
      uint32_t index) const override {
    uint32_t total_num_keys = get_num_keys().get();
    ET_CHECK_OR_RETURN_ERROR(
        index < total_num_keys,
        InvalidArgument,
        "Index %u out of range of size %u",
        index,
        total_num_keys);
    for (size_t i = 0; i < num_data_maps_; i++) {
      auto num_keys = data_maps_[i]->get_num_keys().get();
      if (index < num_keys) {
        return data_maps_[i]->get_key(index);
      }
      index -= num_keys;
    }
    // Shouldn't reach here.
    return executorch::runtime::Error::Internal;
  }

  MergedDataMap(MergedDataMap&&) noexcept = default;

  ~MergedDataMap() override = default;

 private:
  MergedDataMap(
      const std::array<const NamedDataMap*, N>& data_maps,
      size_t num_data_maps)
      : data_maps_(data_maps), num_data_maps_(num_data_maps) {}

  // Not copyable or assignable.
  MergedDataMap(const MergedDataMap& rhs) = delete;
  MergedDataMap& operator=(MergedDataMap&& rhs) noexcept = delete;
  MergedDataMap& operator=(const MergedDataMap& rhs) = delete;

  const std::array<const NamedDataMap*, N> data_maps_;
  const size_t num_data_maps_;
};

} // namespace runtime
} // namespace executorch
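
For orientation, here is a minimal usage sketch built only from APIs that appear in this diff (FileDataLoader, FlatTensorDataMap, and the new MergedDataMap). The .ptd file paths are placeholders and most error handling is elided, so treat it as an illustration rather than shipped code.

// Usage sketch (not part of this commit). Paths are placeholders.
#include <array>
#include <cstdint>
#include <cstdio>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
#include <executorch/runtime/executor/merged_data_map.h>
#include <executorch/runtime/platform/runtime.h>

using executorch::extension::FileDataLoader;
using executorch::extension::FlatTensorDataMap;
using executorch::runtime::MergedDataMap;
using executorch::runtime::NamedDataMap;

int main() {
  executorch::runtime::runtime_init();

  // The loaders and the wrapped maps must outlive the merged map.
  // Error checks on the loaders/maps are elided for brevity.
  auto addmul_loader = FileDataLoader::from("addmul.ptd");
  auto linear_loader = FileDataLoader::from("linear.ptd");
  auto addmul_map = FlatTensorDataMap::load(&addmul_loader.get());
  auto linear_map = FlatTensorDataMap::load(&linear_loader.get());

  // Null entries are skipped; duplicate keys cause InvalidArgument.
  const std::array<const NamedDataMap*, 2> maps = {
      &addmul_map.get(), &linear_map.get()};
  auto merged = MergedDataMap<2>::load(maps);
  if (!merged.ok()) {
    return 1;
  }

  // The merged map exposes the union of the wrapped maps' keys.
  for (uint32_t i = 0; i < merged->get_num_keys().get(); i++) {
    const char* key = merged->get_key(i).get();
    auto layout = merged->get_tensor_layout(key).get();
    printf("%s: %zu bytes\n", key, layout.nbytes());
  }
  return 0;
}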

runtime/executor/targets.bzl

Lines changed: 10 additions & 0 deletions
@@ -69,6 +69,16 @@ def define_common_targets():
         exported_preprocessor_flags = [] if runtime.is_oss else ["-DEXECUTORCH_INTERNAL_FLATBUFFERS=1"],
     )
 
+    runtime.cxx_library(
+        name = "merged_data_map" + aten_suffix,
+        exported_headers = [
+            "merged_data_map.h",
+        ],
+        exported_deps = [
+            "//executorch/runtime/core:named_data_map" + aten_suffix,
+        ],
+    )
+
     runtime.cxx_library(
         name = "program" + aten_suffix,
         exported_deps = [
Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/executor/merged_data_map.h>
#include <executorch/runtime/platform/runtime.h>

#include <gtest/gtest.h>

#include <array>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>
#include <unordered_map>

using namespace ::testing;
using executorch::extension::FileDataLoader;
using executorch::extension::FlatTensorDataMap;
using executorch::runtime::DataLoader;
using executorch::runtime::Error;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::MergedDataMap;
using executorch::runtime::NamedDataMap;
using executorch::runtime::Result;
using executorch::runtime::TensorLayout;

class MergedDataMapTest : public ::testing::Test {
 protected:
  void load_flat_tensor_data_map(const char* path, const char* module_name) {
    Result<FileDataLoader> loader = FileDataLoader::from(path);
    ASSERT_EQ(loader.error(), Error::Ok);
    loaders_.insert(
        {module_name,
         std::make_unique<FileDataLoader>(std::move(loader.get()))});

    Result<FlatTensorDataMap> data_map =
        FlatTensorDataMap::load(loaders_[module_name].get());
    EXPECT_EQ(data_map.error(), Error::Ok);

    data_maps_.insert(
        {module_name,
         std::make_unique<FlatTensorDataMap>(std::move(data_map.get()))});
  }

  void SetUp() override {
    // Since these tests cause ET_LOG to be called, the PAL must be initialized
    // first.
    executorch::runtime::runtime_init();

    // Load FlatTensor data maps.
    // The eager addmul and linear models are defined at:
    // //executorch/test/models/export_program.py
    load_flat_tensor_data_map(
        std::getenv("ET_MODULE_ADD_MUL_DATA_PATH"), "addmul");
    load_flat_tensor_data_map(
        std::getenv("ET_MODULE_LINEAR_DATA_PATH"), "linear");
  }

 private:
  // Must outlive data_maps_, but tests shouldn't need to touch it.
  std::unordered_map<std::string, std::unique_ptr<FileDataLoader>> loaders_;

 protected:
  std::unordered_map<std::string, std::unique_ptr<NamedDataMap>> data_maps_;
};

// Check that two tensor layouts are equivalent.
void check_tensor_layout(TensorLayout& layout1, TensorLayout& layout2) {
  EXPECT_EQ(layout1.scalar_type(), layout2.scalar_type());
  EXPECT_EQ(layout1.nbytes(), layout2.nbytes());
  EXPECT_EQ(layout1.sizes().size(), layout2.sizes().size());
  for (size_t i = 0; i < layout1.sizes().size(); i++) {
    EXPECT_EQ(layout1.sizes()[i], layout2.sizes()[i]);
  }
  EXPECT_EQ(layout1.dim_order().size(), layout2.dim_order().size());
  for (size_t i = 0; i < layout1.dim_order().size(); i++) {
    EXPECT_EQ(layout1.dim_order()[i], layout2.dim_order()[i]);
  }
}

// Given that ndm is part of merged, check that all the API calls on ndm produce
// the same results as merged.
void compare_ndm_api_calls(
    const NamedDataMap* ndm,
    const NamedDataMap* merged) {
  size_t num_keys = ndm->get_num_keys().get();
  for (size_t i = 0; i < num_keys; i++) {
    auto key = ndm->get_key(i).get();

    // Compare get_tensor_layout.
    auto ndm_meta = ndm->get_tensor_layout(key).get();
    auto merged_meta = merged->get_tensor_layout(key).get();
    check_tensor_layout(ndm_meta, merged_meta);

    // Compare get_data.
    auto ndm_data = ndm->get_data(key);
    auto merged_data = merged->get_data(key);
    EXPECT_EQ(ndm_data.get().size(), merged_data.get().size());
    for (size_t j = 0; j < ndm_meta.nbytes(); j++) {
      EXPECT_EQ(
          ((uint8_t*)ndm_data.get().data())[j],
          ((uint8_t*)merged_data.get().data())[j]);
    }
    ndm_data->Free();
    merged_data->Free();

    // Compare load_data_into.
    void* ndm_load_into = malloc(ndm_meta.nbytes());
    ASSERT_EQ(
        Error::Ok, ndm->load_data_into(key, ndm_load_into, ndm_meta.nbytes()));

    void* merged_load_into = malloc(merged_meta.nbytes());
    ASSERT_EQ(
        Error::Ok,
        merged->load_data_into(key, merged_load_into, merged_meta.nbytes()));

    for (size_t j = 0; j < ndm_meta.nbytes(); j++) {
      EXPECT_EQ(((uint8_t*)ndm_load_into)[j], ((uint8_t*)merged_load_into)[j]);
    }
    free(ndm_load_into);
    free(merged_load_into);
  }
}

TEST_F(MergedDataMapTest, LoadSingleDataMap) {
  const std::array<const NamedDataMap*, 1> data_map = {
      data_maps_["addmul"].get()};
  Result<MergedDataMap<1>> merged_map = MergedDataMap<1>::load(data_map);
  EXPECT_EQ(merged_map.error(), Error::Ok);

  // Load one data map into a merged one with storage for up to 5 data maps.
  const std::array<const NamedDataMap*, 5> data_maps = {
      data_maps_["addmul"].get(), nullptr, nullptr, nullptr, nullptr};
  Result<MergedDataMap<5>> merged_map2 = MergedDataMap<5>::load(data_maps);
  EXPECT_EQ(merged_map2.error(), Error::Ok);
}

TEST_F(MergedDataMapTest, LoadNullDataMap) {
  const std::array<const NamedDataMap*, 2> data_maps = {nullptr, nullptr};
  Result<MergedDataMap<2>> merged_map = MergedDataMap<2>::load(data_maps);
  EXPECT_EQ(merged_map.error(), Error::InvalidArgument);
}

TEST_F(MergedDataMapTest, LoadMultipleDataMaps) {
  // Add pte data map here.
  const std::array<const NamedDataMap*, 2> data_maps = {
      data_maps_["addmul"].get(), data_maps_["linear"].get()};
  Result<MergedDataMap<2>> merged_map = MergedDataMap<2>::load(data_maps);
  EXPECT_EQ(merged_map.error(), Error::Ok);
}

TEST_F(MergedDataMapTest, LoadDuplicateDataMapsFail) {
  const std::array<const NamedDataMap*, 2> data_maps = {
      data_maps_["addmul"].get(), data_maps_["addmul"].get()};
  Result<MergedDataMap<2>> merged_map = MergedDataMap<2>::load(data_maps);
  EXPECT_EQ(merged_map.error(), Error::InvalidArgument);
}

TEST_F(MergedDataMapTest, CheckDataMapContents) {
  const std::array<const NamedDataMap*, 2> data_maps = {
      data_maps_["addmul"].get(), data_maps_["linear"].get()};
  Result<MergedDataMap<2>> merged_map = MergedDataMap<2>::load(data_maps);
  EXPECT_EQ(merged_map.error(), Error::Ok);

  // Num keys.
  size_t addmul_num_keys = data_maps_["addmul"]->get_num_keys().get();
  size_t linear_num_keys = data_maps_["linear"]->get_num_keys().get();
  EXPECT_EQ(
      merged_map->get_num_keys().get(), addmul_num_keys + linear_num_keys);

  // API calls produce equivalent results.
  compare_ndm_api_calls(data_maps_["addmul"].get(), &merged_map.get());
  compare_ndm_api_calls(data_maps_["linear"].get(), &merged_map.get());
}
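
As a follow-up, the sketch below shows the intended end-to-end shape: collapsing several weight sources into the single NamedDataMap pointer that method loading consumes. It assumes a Program::load_method overload that takes a trailing const NamedDataMap* (that overload is not part of this diff, so the exact parameter list here is an assumption, not the author's method).

// Hypothetical wiring sketch (not part of this commit). The load_method
// signature with a trailing NamedDataMap* is assumed, not shown in this diff.
#include <array>

#include <executorch/runtime/executor/merged_data_map.h>
#include <executorch/runtime/executor/program.h>

using executorch::runtime::Error;
using executorch::runtime::MemoryManager;
using executorch::runtime::MergedDataMap;
using executorch::runtime::NamedDataMap;
using executorch::runtime::Program;

Error run_with_merged_weights(
    Program& program,
    MemoryManager& memory_manager,
    const NamedDataMap* weights_a, // e.g. one FlatTensorDataMap
    const NamedDataMap* weights_b) { // e.g. another FlatTensorDataMap
  // Present both weight sources behind one NamedDataMap interface.
  const std::array<const NamedDataMap*, 2> maps = {weights_a, weights_b};
  auto merged = MergedDataMap<2>::load(maps);
  if (!merged.ok()) {
    return merged.error();
  }

  // Assumed call shape; event tracer omitted.
  auto method = program.load_method(
      "forward", &memory_manager, /*event_tracer=*/nullptr, &merged.get());
  if (!method.ok()) {
    return method.error();
  }
  return method->execute();
}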
