Skip to content

Commit 15c35a8

Browse files
committed
Introduce public MergedDataMap
Pull Request resolved: #14766 Add public merged data map. Module can use this to resolve multiple named data maps. Creating as a sep dependency rather than inside module/ so it can be used independently of module. (think there may be some other internal usages soon) Add support for BUCK and CMake. ghstack-source-id: 314023106 @exported-using-ghexport Differential Revision: [D83527299](https://our.internmc.facebook.com/intern/diff/D83527299/)
1 parent 70ea661 commit 15c35a8

File tree

21 files changed

+634
-1
lines changed

21 files changed

+634
-1
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,11 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE)
630630
list(APPEND _executorch_extensions extension_module_static)
631631
endif()
632632

633+
if(EXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP)
634+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/named_data_map)
635+
list(APPEND _executorch_extensions extension_named_data_map)
636+
endif()
637+
633638
if(EXECUTORCH_BUILD_EXTENSION_LLM)
634639
if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
635640
set(SUPPORT_REGEX_LOOKAHEAD ON)

backends/qualcomm/scripts/build.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ if [ "$BUILD_AARCH64" = true ]; then
8484
-DEXECUTORCH_BUILD_EXTENSION_LLM=ON \
8585
-DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \
8686
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
87+
-DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \
8788
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
8889
-DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \
8990
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Please this file formatted by running:
8+
# ~~~
9+
# cmake-format -i CMakeLists.txt
10+
# ~~~
11+
12+
cmake_minimum_required(VERSION 3.19)
13+
14+
# Source root directory for executorch.
15+
if(NOT EXECUTORCH_ROOT)
16+
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
17+
endif()
18+
19+
list(TRANSFORM _extension_named_data_map__srcs PREPEND "${EXECUTORCH_ROOT}/")
20+
# Create the library
21+
add_library(extension_named_data_map ${_extension_named_data_map__srcs})
22+
23+
# Link dependencies
24+
target_link_libraries(extension_named_data_map PUBLIC executorch_core)
25+
26+
target_include_directories(
27+
extension_named_data_map PUBLIC ${_common_include_directories}
28+
)
29+
30+
target_compile_options(
31+
extension_named_data_map PUBLIC ${_common_compile_options}
32+
)
33+
34+
# Install libraries
35+
install(
36+
TARGETS extension_named_data_map
37+
EXPORT ExecuTorchTargets
38+
DESTINATION lib
39+
INCLUDES
40+
DESTINATION ${_common_include_directories}
41+
)
42+
43+
# Add tests if testing is enabled
44+
if(BUILD_TESTING)
45+
add_subdirectory(test)
46+
endif()

extension/named_data_map/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
oncall("executorch")
7+
8+
define_common_targets()
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/named_data_map/merged_data_map.h>
10+
#include <executorch/runtime/core/data_loader.h>
11+
12+
#include <unordered_map>
13+
#include <vector>
14+
15+
using executorch::aten::string_view;
16+
using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap;
17+
using executorch::ET_RUNTIME_NAMESPACE::TensorLayout;
18+
using executorch::runtime::Error;
19+
using executorch::runtime::FreeableBuffer;
20+
using executorch::runtime::Result;
21+
using executorch::runtime::Span;
22+
23+
namespace executorch::extension {
24+
25+
/*static*/ Result<MergedDataMap> MergedDataMap::load(
26+
Span<const NamedDataMap*> named_data_maps) {
27+
std::vector<const NamedDataMap*> valid_data_maps;
28+
for (auto i : c10::irange(named_data_maps.size())) {
29+
if (named_data_maps[i] != nullptr &&
30+
named_data_maps[i]->get_num_keys().get() > 0) {
31+
valid_data_maps.push_back(named_data_maps[i]);
32+
}
33+
}
34+
ET_CHECK_OR_RETURN_ERROR(
35+
!valid_data_maps.empty(),
36+
InvalidArgument,
37+
"No non-empty named data maps provided to merge");
38+
39+
// Check for duplicate keys.
40+
std::unordered_map<std::string, uint32_t> key_to_map_index;
41+
for (uint32_t i = 0; i < valid_data_maps.size(); i++) {
42+
const auto cur_map = valid_data_maps[i];
43+
uint32_t num_keys = cur_map->get_num_keys().get();
44+
for (uint32_t j = 0; j < num_keys; ++j) {
45+
const auto cur_key = cur_map->get_key(j).get();
46+
const auto [it, inserted] = key_to_map_index.emplace(cur_key, i);
47+
ET_CHECK_OR_RETURN_ERROR(
48+
inserted,
49+
InvalidArgument,
50+
"Duplicate key %s in named data maps at index %u and %u",
51+
cur_key,
52+
it->second,
53+
i);
54+
}
55+
}
56+
return MergedDataMap(std::move(valid_data_maps), std::move(key_to_map_index));
57+
}
58+
59+
ET_NODISCARD Result<const TensorLayout> MergedDataMap::get_tensor_layout(
60+
string_view key) const {
61+
const auto it = key_to_map_index_.find(key.data());
62+
ET_CHECK_OR_RETURN_ERROR(
63+
it != key_to_map_index_.end(),
64+
NotFound,
65+
"Key %s not found in named data maps",
66+
key.data());
67+
68+
return named_data_maps_.at(it->second)->get_tensor_layout(key);
69+
}
70+
71+
ET_NODISCARD
72+
Result<FreeableBuffer> MergedDataMap::get_data(string_view key) const {
73+
const auto it = key_to_map_index_.find(key.data());
74+
ET_CHECK_OR_RETURN_ERROR(
75+
it != key_to_map_index_.end(),
76+
NotFound,
77+
"Key %s not found in named data maps",
78+
key.data());
79+
return named_data_maps_.at(it->second)->get_data(key);
80+
}
81+
82+
ET_NODISCARD Error MergedDataMap::load_data_into(
83+
string_view key,
84+
void* buffer,
85+
size_t size) const {
86+
const auto it = key_to_map_index_.find(key.data());
87+
ET_CHECK_OR_RETURN_ERROR(
88+
it != key_to_map_index_.end(),
89+
NotFound,
90+
"Key %s not found in named data maps",
91+
key.data());
92+
return named_data_maps_.at(it->second)->load_data_into(key, buffer, size);
93+
}
94+
95+
ET_NODISCARD Result<uint32_t> MergedDataMap::get_num_keys() const {
96+
return key_to_map_index_.size();
97+
}
98+
99+
ET_NODISCARD Result<const char*> MergedDataMap::get_key(uint32_t index) const {
100+
uint32_t total_num_keys = get_num_keys().get();
101+
ET_CHECK_OR_RETURN_ERROR(
102+
index < total_num_keys,
103+
InvalidArgument,
104+
"Index %u out of range of size %u",
105+
index,
106+
total_num_keys);
107+
for (auto i : c10::irange(named_data_maps_.size())) {
108+
auto num_keys = named_data_maps_[i]->get_num_keys().get();
109+
if (index < num_keys) {
110+
return named_data_maps_[i]->get_key(index);
111+
}
112+
index -= num_keys;
113+
}
114+
// Shouldn't reach here.
115+
return Error::Internal;
116+
}
117+
} // namespace executorch::extension
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/named_data_map.h>
12+
13+
#include <unordered_map>
14+
#include <vector>
15+
16+
namespace executorch {
17+
namespace extension {
18+
/**
19+
* A NamedDataMap implementation that wraps other NamedDataMaps.
20+
*/
21+
class MergedDataMap final
22+
: public executorch::ET_RUNTIME_NAMESPACE::NamedDataMap {
23+
public:
24+
/**
25+
* Creates a new NamedDataMap that takes in other data maps.
26+
*
27+
* @param[in] data_maps vector of NamedDataMap pointers to merge.
28+
* Note: the data maps must outlive the MergedDataMap instance.
29+
*/
30+
static executorch::runtime::Result<MergedDataMap>
31+
load(executorch::runtime::Span<
32+
const executorch::ET_RUNTIME_NAMESPACE::NamedDataMap*> named_data_maps);
33+
34+
/**
35+
* Retrieve the tensor_layout for the specified key.
36+
*
37+
* @param[in] key The name of the tensor to get metadata on.
38+
*
39+
* @return Error::NotFound if the key is not present.
40+
*/
41+
ET_NODISCARD
42+
executorch::runtime::Result<
43+
const executorch::ET_RUNTIME_NAMESPACE::TensorLayout>
44+
get_tensor_layout(executorch::aten::string_view key) const override;
45+
46+
/**
47+
* Retrieve read-only data for the specified key.
48+
*
49+
* @param[in] key The name of the tensor to get data on.
50+
*
51+
* @return error if the key is not present or data cannot be loaded.
52+
*/
53+
ET_NODISCARD
54+
executorch::runtime::Result<executorch::runtime::FreeableBuffer> get_data(
55+
executorch::aten::string_view key) const override;
56+
57+
/**
58+
* Loads the data of the specified tensor into the provided buffer.
59+
*
60+
* @param[in] key The name of the tensor to get the data of.
61+
* @param[in] buffer The buffer to load data into. Must point to at least
62+
* `size` bytes of memory.
63+
* @param[in] size The number of bytes to load.
64+
*
65+
* @returns an Error indicating if the load was successful.
66+
*/
67+
ET_NODISCARD executorch::runtime::Error load_data_into(
68+
executorch::aten::string_view key,
69+
void* buffer,
70+
size_t size) const override;
71+
72+
/**
73+
* @returns The number of keys in the map.
74+
*/
75+
ET_NODISCARD executorch::runtime::Result<uint32_t> get_num_keys()
76+
const override;
77+
/**
78+
* @returns The key at the specified index, error if index out of bounds.
79+
*/
80+
ET_NODISCARD executorch::runtime::Result<const char*> get_key(
81+
uint32_t index) const override;
82+
83+
MergedDataMap(MergedDataMap&&) noexcept = default;
84+
85+
~MergedDataMap() override = default;
86+
87+
private:
88+
MergedDataMap(
89+
std::vector<const executorch::ET_RUNTIME_NAMESPACE::NamedDataMap*>
90+
named_data_maps,
91+
std::unordered_map<std::string, uint32_t> key_to_map_index)
92+
: named_data_maps_(std::move(named_data_maps)),
93+
key_to_map_index_(std::move(key_to_map_index)) {}
94+
95+
// Not copyable or assignable.
96+
MergedDataMap(const MergedDataMap& rhs) = delete;
97+
MergedDataMap& operator=(MergedDataMap&& rhs) noexcept = delete;
98+
MergedDataMap& operator=(const MergedDataMap& rhs) = delete;
99+
100+
std::vector<const executorch::ET_RUNTIME_NAMESPACE::NamedDataMap*>
101+
named_data_maps_;
102+
103+
// Map from key to index in the named_data_maps_ vector.
104+
std::unordered_map<std::string, uint32_t> key_to_map_index_;
105+
};
106+
107+
} // namespace extension
108+
} // namespace executorch
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_aten_mode_options", "runtime")
2+
3+
def define_common_targets():
4+
for aten_mode in get_aten_mode_options():
5+
aten_suffix = "_aten" if aten_mode else ""
6+
runtime.cxx_library(
7+
name = "merged_data_map" + aten_suffix,
8+
srcs = [
9+
"merged_data_map.cpp",
10+
],
11+
exported_headers = [
12+
"merged_data_map.h",
13+
],
14+
visibility = [
15+
"@EXECUTORCH_CLIENTS",
16+
],
17+
deps = [
18+
"//executorch/runtime/core:named_data_map",
19+
"//executorch/runtime/core:core",
20+
],
21+
)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Please this file formatted by running:
8+
# ~~~
9+
# cmake-format -i CMakeLists.txt
10+
# ~~~
11+
12+
cmake_minimum_required(VERSION 3.19)
13+
14+
# Source root directory for executorch.
15+
if(NOT EXECUTORCH_ROOT)
16+
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
17+
endif()
18+
19+
include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
20+
21+
add_custom_command(
22+
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte"
23+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
24+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
25+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
26+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrainProgram.pte"
27+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrainProgram.ptd"
28+
COMMAND
29+
${PYTHON_EXECUTABLE} -m test.models.export_program --modules
30+
"ModuleAddMul,ModuleLinear,ModuleSimpleTrain" --external-constants --outdir
31+
"${CMAKE_CURRENT_BINARY_DIR}"
32+
WORKING_DIRECTORY ${EXECUTORCH_ROOT}
33+
)
34+
35+
add_custom_target(
36+
extension_named_data_map_test_resources
37+
DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte"
38+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
39+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
40+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
41+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrainProgram.pte"
42+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrainProgram.ptd"
43+
)
44+
45+
set(test_env
46+
"ET_MODULE_ADD_MUL_PROGRAM_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte"
47+
"ET_MODULE_ADD_MUL_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
48+
"ET_MODULE_LINEAR_PROGRAM_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
49+
"ET_MODULE_LINEAR_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
50+
"ET_MODULE_SIMPLE_TRAIN_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrainProgram.pte"
51+
"ET_MODULE_SIMPLE_TRAIN_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrainProgram.ptd"
52+
)
53+
54+
set(_test_srcs merged_data_map_test.cpp)
55+
56+
et_cxx_test(
57+
extension_named_data_map_test
58+
SOURCES
59+
${_test_srcs}
60+
EXTRA_LIBS
61+
extension_named_data_map
62+
extension_flat_tensor
63+
extension_data_loader
64+
)
65+
66+
add_dependencies(
67+
extension_named_data_map_test extension_named_data_map
68+
extension_named_data_map_test_resources
69+
)
70+
set_property(
71+
TEST extension_named_data_map_test PROPERTY ENVIRONMENT ${test_env}
72+
)

0 commit comments

Comments
 (0)