Commit e13b8e4

pytorchbot and lucylq authored
Introduce MergedDataMap (#12304)
This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #12087 by @lucylq
^ Please use this as the source of truth for the PR details, comments, and reviews

ghstack PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/87/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/87/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/87/orig

@diff-train-skip-merge

Co-authored-by: lucylq <[email protected]>
1 parent e1ac7ea commit e13b8e4

5 files changed: +323 −2 lines changed

extension/flat_tensor/test/flat_tensor_data_map_test.cpp

Lines changed: 2 additions & 2 deletions

@@ -33,8 +33,8 @@ class FlatTensorDataMapTest : public ::testing::Test {
     // first.
     executorch::runtime::runtime_init();

-    // Load data map. The eager linear model is defined at:
-    // //executorch/test/models/linear_model.py
+    // Load data map. The eager addmul model is defined at:
+    // //executorch/test/models/export_program.py
     const char* path = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH");
     Result<FileDataLoader> loader = FileDataLoader::from(path);
     ASSERT_EQ(loader.error(), Error::Ok);

runtime/executor/merged_data_map.h

Lines changed: 149 additions & 0 deletions (new file)

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/core/named_data_map.h>

namespace executorch {
namespace ET_RUNTIME_NAMESPACE {
namespace internal {

/**
 * A NamedDataMap implementation that wraps other NamedDataMaps.
 */
class MergedDataMap final : public NamedDataMap {
 public:
  /**
   * Creates a new NamedDataMap that wraps two other data maps.
   *
   * @param[in] first The first NamedDataMap to merge.
   * @param[in] second The second NamedDataMap to merge.
   * Note: the data maps must outlive the MergedDataMap instance.
   */
  static Result<MergedDataMap> load(
      const NamedDataMap* first,
      const NamedDataMap* second) {
    ET_CHECK_OR_RETURN_ERROR(
        first != nullptr && second != nullptr,
        InvalidArgument,
        "Input data map is null.");

    // Check for duplicate keys.
    for (uint32_t k = 0; k < first->get_num_keys().get(); k++) {
      const auto key = first->get_key(k).get();
      ET_CHECK_OR_RETURN_ERROR(
          second->get_tensor_layout(key).error() == Error::NotFound,
          InvalidArgument,
          "Duplicate key %s.",
          key);
    }
    return MergedDataMap(first, second);
  }

  /**
   * Retrieve the tensor_layout for the specified key.
   *
   * @param[in] key The name of the tensor to get metadata on.
   *
   * @return Error::NotFound if the key is not present.
   */
  ET_NODISCARD
  Result<const TensorLayout> get_tensor_layout(
      executorch::aten::string_view key) const override {
    auto layout = first_->get_tensor_layout(key);
    if (layout.ok()) {
      return layout.get();
    }
    if (layout.error() != Error::NotFound) {
      return layout.error();
    }
    return second_->get_tensor_layout(key);
  }

  /**
   * Retrieve read-only data for the specified key.
   *
   * @param[in] key The name of the tensor to get data on.
   *
   * @return error if the key is not present or data cannot be loaded.
   */
  ET_NODISCARD
  Result<FreeableBuffer> get_data(
      executorch::aten::string_view key) const override {
    auto data = first_->get_data(key);
    if (data.error() != Error::NotFound) {
      return data;
    }
    return second_->get_data(key);
  }

  /**
   * Loads the data of the specified tensor into the provided buffer.
   * Not used in the MergedDataMap.
   *
   * @param[in] key The name of the tensor to get the data of.
   * @param[in] buffer The buffer to load data into. Must point to at least
   * `size` bytes of memory.
   * @param[in] size The number of bytes to load.
   *
   * @returns an Error indicating if the load was successful.
   */
  ET_NODISCARD Error load_data_into(
      ET_UNUSED executorch::aten::string_view key,
      ET_UNUSED void* buffer,
      ET_UNUSED size_t size) const override {
    return Error::NotImplemented;
  }

  /**
   * @returns The number of keys in the map.
   */
  ET_NODISCARD Result<uint32_t> get_num_keys() const override {
    return first_->get_num_keys().get() + second_->get_num_keys().get();
  }

  /**
   * @returns The key at the specified index, error if index out of bounds.
   */
  ET_NODISCARD Result<const char*> get_key(uint32_t index) const override {
    uint32_t total_num_keys = get_num_keys().get();
    ET_CHECK_OR_RETURN_ERROR(
        index >= 0 && index < total_num_keys,
        InvalidArgument,
        "Index %u out of range of size %u",
        index,
        total_num_keys);

    if (index < first_->get_num_keys().get()) {
      return first_->get_key(index);
    } else {
      return second_->get_key(index - first_->get_num_keys().get());
    }
  }

  MergedDataMap(MergedDataMap&&) noexcept = default;

  ~MergedDataMap() override = default;

 private:
  MergedDataMap(const NamedDataMap* first, const NamedDataMap* second)
      : first_{first}, second_{second} {}

  // Not copyable or assignable.
  MergedDataMap(const MergedDataMap& rhs) = delete;
  MergedDataMap& operator=(MergedDataMap&& rhs) noexcept = delete;
  MergedDataMap& operator=(const MergedDataMap& rhs) = delete;

  const NamedDataMap* first_;
  const NamedDataMap* second_;
};

} // namespace internal
} // namespace ET_RUNTIME_NAMESPACE
} // namespace executorch
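
For context, a minimal caller-side sketch of the new API, not part of this commit: it merges two already-loaded NamedDataMaps and looks one tensor up through the result. The function name, parameter names, and the key "linear.weight" are hypothetical; only MergedDataMap::load, get_tensor_layout, and get_data come from the header above.

// Hypothetical usage sketch (assumed caller code, not included in this commit).
#include <executorch/runtime/core/named_data_map.h>
#include <executorch/runtime/executor/merged_data_map.h>

using executorch::runtime::NamedDataMap;
using executorch::runtime::Result;
using executorch::runtime::TensorLayout;
using executorch::runtime::internal::MergedDataMap;

Result<MergedDataMap> merge_and_query(
    const NamedDataMap* weights_map,   // hypothetical first map
    const NamedDataMap* adapter_map) { // hypothetical second map
  // load() returns InvalidArgument if either pointer is null or the key sets
  // overlap; both inputs must outlive the returned MergedDataMap.
  Result<MergedDataMap> merged = MergedDataMap::load(weights_map, adapter_map);
  if (!merged.ok()) {
    return merged.error();
  }
  // Lookups consult the first map and fall back to the second on NotFound.
  Result<const TensorLayout> layout =
      merged->get_tensor_layout("linear.weight"); // hypothetical key
  (void)layout; // a real caller would inspect the layout or call get_data().
  return merged;
}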

runtime/executor/targets.bzl

Lines changed: 10 additions & 0 deletions

@@ -69,6 +69,16 @@ def define_common_targets():
         exported_preprocessor_flags = [] if runtime.is_oss else ["-DEXECUTORCH_INTERNAL_FLATBUFFERS=1"],
     )

+    runtime.cxx_library(
+        name = "merged_data_map" + aten_suffix,
+        exported_headers = [
+            "merged_data_map.h",
+        ],
+        exported_deps = [
+            "//executorch/runtime/core:named_data_map" + aten_suffix,
+        ],
+    )
+
     runtime.cxx_library(
         name = "program" + aten_suffix,
         exported_deps = [

runtime/executor/test/merged_data_map_test.cpp

Lines changed: 148 additions & 0 deletions (new file)

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/executor/merged_data_map.h>
#include <executorch/runtime/platform/runtime.h>

#include <gtest/gtest.h>

using namespace ::testing;
using executorch::extension::FileDataLoader;
using executorch::extension::FlatTensorDataMap;
using executorch::runtime::DataLoader;
using executorch::runtime::Error;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::NamedDataMap;
using executorch::runtime::Result;
using executorch::runtime::TensorLayout;
using executorch::runtime::internal::MergedDataMap;

class MergedDataMapTest : public ::testing::Test {
 protected:
  void load_flat_tensor_data_map(const char* path, const char* module_name) {
    Result<FileDataLoader> loader = FileDataLoader::from(path);
    ASSERT_EQ(loader.error(), Error::Ok);
    loaders_.insert(
        {module_name,
         std::make_unique<FileDataLoader>(std::move(loader.get()))});

    Result<FlatTensorDataMap> data_map =
        FlatTensorDataMap::load(loaders_[module_name].get());
    EXPECT_EQ(data_map.error(), Error::Ok);

    data_maps_.insert(
        {module_name,
         std::make_unique<FlatTensorDataMap>(std::move(data_map.get()))});
  }

  void SetUp() override {
    // Since these tests cause ET_LOG to be called, the PAL must be initialized
    // first.
    executorch::runtime::runtime_init();

    // Load FlatTensor data maps.
    // The eager addmul and linear models are defined at:
    // //executorch/test/models/export_program.py
    load_flat_tensor_data_map(
        std::getenv("ET_MODULE_ADD_MUL_DATA_PATH"), "addmul");
    load_flat_tensor_data_map(
        std::getenv("ET_MODULE_LINEAR_DATA_PATH"), "linear");
  }

 private:
  // Must outlive data_maps_, but tests shouldn't need to touch it.
  std::unordered_map<std::string, std::unique_ptr<FileDataLoader>> loaders_;

 protected:
  std::unordered_map<std::string, std::unique_ptr<NamedDataMap>> data_maps_;
};

// Check that two tensor layouts are equivalent.
void check_tensor_layout(TensorLayout& layout1, TensorLayout& layout2) {
  EXPECT_EQ(layout1.scalar_type(), layout2.scalar_type());
  EXPECT_EQ(layout1.nbytes(), layout2.nbytes());
  EXPECT_EQ(layout1.sizes().size(), layout2.sizes().size());
  for (size_t i = 0; i < layout1.sizes().size(); i++) {
    EXPECT_EQ(layout1.sizes()[i], layout2.sizes()[i]);
  }
  EXPECT_EQ(layout1.dim_order().size(), layout2.dim_order().size());
  for (size_t i = 0; i < layout1.dim_order().size(); i++) {
    EXPECT_EQ(layout1.dim_order()[i], layout2.dim_order()[i]);
  }
}

// Given that ndm is part of merged, check that all the API calls on ndm produce
// the same results as merged.
void compare_ndm_api_calls(
    const NamedDataMap* ndm,
    const NamedDataMap* merged) {
  uint32_t num_keys = ndm->get_num_keys().get();
  for (uint32_t i = 0; i < num_keys; i++) {
    auto key = ndm->get_key(i).get();

    // Compare get_tensor_layout.
    auto ndm_meta = ndm->get_tensor_layout(key).get();
    auto merged_meta = merged->get_tensor_layout(key).get();
    check_tensor_layout(ndm_meta, merged_meta);

    // Compare get_data.
    auto ndm_data = ndm->get_data(key);
    auto merged_data = merged->get_data(key);
    EXPECT_EQ(ndm_data.get().size(), merged_data.get().size());
    for (size_t j = 0; j < ndm_meta.nbytes(); j++) {
      EXPECT_EQ(
          ((uint8_t*)ndm_data.get().data())[j],
          ((uint8_t*)merged_data.get().data())[j]);
    }
    ndm_data->Free();
    merged_data->Free();
  }
}

TEST_F(MergedDataMapTest, LoadNullDataMap) {
  Result<MergedDataMap> merged_map = MergedDataMap::load(nullptr, nullptr);
  EXPECT_EQ(merged_map.error(), Error::InvalidArgument);
}

TEST_F(MergedDataMapTest, LoadMultipleDataMaps) {
  Result<MergedDataMap> merged_map = MergedDataMap::load(
      data_maps_["addmul"].get(), data_maps_["linear"].get());
  EXPECT_EQ(merged_map.error(), Error::Ok);
}

TEST_F(MergedDataMapTest, LoadDuplicateDataMapsFail) {
  Result<MergedDataMap> merged_map = MergedDataMap::load(
      data_maps_["addmul"].get(), data_maps_["addmul"].get());
  EXPECT_EQ(merged_map.error(), Error::InvalidArgument);
}

TEST_F(MergedDataMapTest, CheckDataMapContents) {
  Result<MergedDataMap> merged_map = MergedDataMap::load(
      data_maps_["addmul"].get(), data_maps_["linear"].get());
  EXPECT_EQ(merged_map.error(), Error::Ok);

  // Num keys.
  size_t addmul_num_keys = data_maps_["addmul"]->get_num_keys().get();
  size_t linear_num_keys = data_maps_["linear"]->get_num_keys().get();
  EXPECT_EQ(
      merged_map->get_num_keys().get(), addmul_num_keys + linear_num_keys);

  // Load data into is not implemented for the merged data map.
  void* memory_block = malloc(10);
  ASSERT_EQ(
      Error::NotImplemented, merged_map->load_data_into("a", memory_block, 10));
  free(memory_block);

  // API calls produce equivalent results.
  compare_ndm_api_calls(data_maps_["addmul"].get(), &merged_map.get());
  compare_ndm_api_calls(data_maps_["linear"].get(), &merged_map.get());
}
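
As a complement to the tests above, a small hypothetical helper (not part of this commit) that walks a merged map through the index-based NamedDataMap interface; the function name and its purpose are assumptions, while the get_num_keys/get_key/get_tensor_layout calls come from the header introduced in this PR.

// Hypothetical sketch: sum the sizes of every tensor reachable from a map.
// For a MergedDataMap, indices below first.get_num_keys() resolve against the
// first map and the remainder are forwarded to the second with the offset
// subtracted, per get_key() above.
#include <cstddef>
#include <cstdint>

#include <executorch/runtime/core/named_data_map.h>

using executorch::runtime::NamedDataMap;

size_t count_total_bytes(const NamedDataMap& map) {
  size_t total = 0;
  const uint32_t num_keys = map.get_num_keys().get();
  for (uint32_t i = 0; i < num_keys; ++i) {
    const char* key = map.get_key(i).get();
    auto layout = map.get_tensor_layout(key);
    if (layout.ok()) {
      total += layout->nbytes();
    }
  }
  return total;
}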

runtime/executor/test/targets.bzl

Lines changed: 14 additions & 0 deletions

@@ -125,6 +125,7 @@ def define_common_targets(is_fbcode = False):
             "ET_MODULE_STATEFUL_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleStateful.pte])",
             "ET_MODULE_ADD_MUL_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.pte])",
             "ET_MODULE_ADD_MUL_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.ptd])",
+            "ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])",
         }

     runtime.cxx_test(

@@ -142,6 +143,19 @@ def define_common_targets(is_fbcode = False):
         env = modules_env,
     )

+    runtime.cxx_test(
+        name = "merged_data_map_test",
+        srcs = [
+            "merged_data_map_test.cpp",
+        ],
+        deps = [
+            "//executorch/extension/data_loader:file_data_loader",
+            "//executorch/extension/flat_tensor:flat_tensor_data_map",
+            "//executorch/runtime/executor:merged_data_map",
+        ],
+        env = modules_env,
+    )
+
     runtime.cxx_test(
         name = "method_test",
         srcs = [
