Skip to content

Commit f2d403d

Browse files
committed
Introduce public MergedDataMap
Pull Request resolved: #14766 Add public merged data map. Module can use this to resolve multiple named data maps. Creating as a sep dependency rather than inside module/ so it can be used independently of module. (think there may be some other internal usages soon) Add support for BUCK and CMake. ghstack-source-id: 314009154 @exported-using-ghexport Differential Revision: [D83527299](https://our.internmc.facebook.com/intern/diff/D83527299/)
1 parent 70ea661 commit f2d403d

File tree

18 files changed

+620
-2
lines changed

18 files changed

+620
-2
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,11 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE)
630630
list(APPEND _executorch_extensions extension_module_static)
631631
endif()
632632

633+
if(EXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP)
634+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/named_data_map)
635+
list(APPEND _executorch_extensions extension_named_data_map)
636+
endif()
637+
633638
if(EXECUTORCH_BUILD_EXTENSION_LLM)
634639
if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
635640
set(SUPPORT_REGEX_LOOKAHEAD ON)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Please this file formatted by running:
8+
# ~~~
9+
# cmake-format -i CMakeLists.txt
10+
# ~~~
11+
12+
cmake_minimum_required(VERSION 3.19)
13+
14+
# Source root directory for executorch.
15+
if(NOT EXECUTORCH_ROOT)
16+
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
17+
endif()
18+
19+
list(TRANSFORM _extension_named_data_map__srcs PREPEND "${EXECUTORCH_ROOT}/")
20+
# Create the library
21+
add_library(extension_named_data_map ${_extension_named_data_map__srcs})
22+
23+
# Link dependencies
24+
target_link_libraries(
25+
extension_named_data_map
26+
PUBLIC
27+
executorch_core
28+
)
29+
30+
target_include_directories(
31+
extension_named_data_map PUBLIC ${_common_include_directories}
32+
)
33+
34+
target_compile_options(extension_named_data_map PUBLIC ${_common_compile_options})
35+
36+
# Install libraries
37+
install(
38+
TARGETS extension_named_data_map
39+
EXPORT ExecuTorchTargets
40+
DESTINATION lib
41+
INCLUDES
42+
DESTINATION ${_common_include_directories}
43+
)
44+
45+
# Add tests if testing is enabled
46+
if(BUILD_TESTING)
47+
add_subdirectory(test)
48+
endif()

extension/named_data_map/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
oncall("executorch")
7+
8+
define_common_targets()
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/named_data_map/merged_data_map.h>
10+
#include <executorch/runtime/core/data_loader.h>
11+
12+
#include <vector>
13+
14+
using executorch::aten::string_view;
15+
using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap;
16+
using executorch::ET_RUNTIME_NAMESPACE::TensorLayout;
17+
using executorch::runtime::Error;
18+
using executorch::runtime::FreeableBuffer;
19+
using executorch::runtime::Result;
20+
using executorch::runtime::Span;
21+
22+
namespace executorch::extension {
23+
24+
/*static*/ Result<MergedDataMap> MergedDataMap::load(
25+
Span<const NamedDataMap*> named_data_maps) {
26+
std::vector<const NamedDataMap*> valid_data_maps;
27+
for (auto i : c10::irange(named_data_maps.size())) {
28+
if (named_data_maps[i] != nullptr &&
29+
named_data_maps[i]->get_num_keys().get() > 0) {
30+
valid_data_maps.push_back(named_data_maps[i]);
31+
}
32+
}
33+
ET_CHECK_OR_RETURN_ERROR(
34+
!valid_data_maps.empty(),
35+
InvalidArgument,
36+
"No non-empty named data maps provided to merge");
37+
38+
// Check for duplicate keys.
39+
std::unordered_map<std::string, uint32_t> key_to_map_index;
40+
for (uint32_t i = 0; i < valid_data_maps.size(); i++) {
41+
const auto cur_map = valid_data_maps[i];
42+
uint32_t num_keys = cur_map->get_num_keys().get();
43+
for (uint32_t j = 0; j < num_keys; ++j) {
44+
const auto cur_key = cur_map->get_key(j).get();
45+
const auto [it, inserted] = key_to_map_index.emplace(cur_key, i);
46+
ET_CHECK_OR_RETURN_ERROR(
47+
inserted,
48+
InvalidArgument,
49+
"Duplicate key %s in named data maps at index %u and %u",
50+
cur_key,
51+
it->second,
52+
i);
53+
}
54+
}
55+
return MergedDataMap(std::move(valid_data_maps), std::move(key_to_map_index));
56+
}
57+
58+
ET_NODISCARD Result<const TensorLayout> MergedDataMap::get_tensor_layout(
59+
string_view key) const {
60+
const auto it = key_to_map_index_.find(key.data());
61+
ET_CHECK_OR_RETURN_ERROR(
62+
it != key_to_map_index_.end(),
63+
NotFound,
64+
"Key %s not found in named data maps",
65+
key.data());
66+
67+
return named_data_maps_.at(it->second)->get_tensor_layout(key);
68+
}
69+
70+
ET_NODISCARD
71+
Result<FreeableBuffer> MergedDataMap::get_data(string_view key) const {
72+
const auto it = key_to_map_index_.find(key.data());
73+
ET_CHECK_OR_RETURN_ERROR(
74+
it != key_to_map_index_.end(),
75+
NotFound,
76+
"Key %s not found in named data maps",
77+
key.data());
78+
return named_data_maps_.at(it->second)->get_data(key);
79+
}
80+
81+
ET_NODISCARD Error MergedDataMap::load_data_into(
82+
string_view key,
83+
void* buffer,
84+
size_t size) const {
85+
const auto it = key_to_map_index_.find(key.data());
86+
ET_CHECK_OR_RETURN_ERROR(
87+
it != key_to_map_index_.end(),
88+
NotFound,
89+
"Key %s not found in named data maps",
90+
key.data());
91+
return named_data_maps_.at(it->second)->load_data_into(key, buffer, size);
92+
}
93+
94+
ET_NODISCARD Result<uint32_t> MergedDataMap::get_num_keys() const {
95+
return key_to_map_index_.size();
96+
}
97+
98+
ET_NODISCARD Result<const char*> MergedDataMap::get_key(uint32_t index) const {
99+
uint32_t total_num_keys = get_num_keys().get();
100+
ET_CHECK_OR_RETURN_ERROR(
101+
index < total_num_keys,
102+
InvalidArgument,
103+
"Index %u out of range of size %u",
104+
index,
105+
total_num_keys);
106+
for (auto i : c10::irange(named_data_maps_.size())) {
107+
auto num_keys = named_data_maps_[i]->get_num_keys().get();
108+
if (index < num_keys) {
109+
return named_data_maps_[i]->get_key(index);
110+
}
111+
index -= num_keys;
112+
}
113+
// Shouldn't reach here.
114+
return Error::Internal;
115+
}
116+
} // namespace executorch::extension
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/named_data_map.h>
12+
13+
#include <vector>
14+
15+
namespace executorch {
16+
namespace extension {
17+
/**
18+
* A NamedDataMap implementation that wraps other NamedDataMaps.
19+
*/
20+
class MergedDataMap final
21+
: public executorch::ET_RUNTIME_NAMESPACE::NamedDataMap {
22+
public:
23+
/**
24+
* Creates a new NamedDataMap that takes in other data maps.
25+
*
26+
* @param[in] data_maps vector of NamedDataMap pointers to merge.
27+
* Note: the data maps must outlive the MergedDataMap instance.
28+
*/
29+
static executorch::runtime::Result<MergedDataMap>
30+
load(executorch::runtime::Span<
31+
const executorch::ET_RUNTIME_NAMESPACE::NamedDataMap*> named_data_maps);
32+
33+
/**
34+
* Retrieve the tensor_layout for the specified key.
35+
*
36+
* @param[in] key The name of the tensor to get metadata on.
37+
*
38+
* @return Error::NotFound if the key is not present.
39+
*/
40+
ET_NODISCARD
41+
executorch::runtime::Result<
42+
const executorch::ET_RUNTIME_NAMESPACE::TensorLayout>
43+
get_tensor_layout(executorch::aten::string_view key) const override;
44+
45+
/**
46+
* Retrieve read-only data for the specified key.
47+
*
48+
* @param[in] key The name of the tensor to get data on.
49+
*
50+
* @return error if the key is not present or data cannot be loaded.
51+
*/
52+
ET_NODISCARD
53+
executorch::runtime::Result<executorch::runtime::FreeableBuffer> get_data(
54+
executorch::aten::string_view key) const override;
55+
56+
/**
57+
* Loads the data of the specified tensor into the provided buffer.
58+
*
59+
* @param[in] key The name of the tensor to get the data of.
60+
* @param[in] buffer The buffer to load data into. Must point to at least
61+
* `size` bytes of memory.
62+
* @param[in] size The number of bytes to load.
63+
*
64+
* @returns an Error indicating if the load was successful.
65+
*/
66+
ET_NODISCARD executorch::runtime::Error load_data_into(
67+
executorch::aten::string_view key,
68+
void* buffer,
69+
size_t size) const override;
70+
71+
/**
72+
* @returns The number of keys in the map.
73+
*/
74+
ET_NODISCARD executorch::runtime::Result<uint32_t> get_num_keys()
75+
const override;
76+
/**
77+
* @returns The key at the specified index, error if index out of bounds.
78+
*/
79+
ET_NODISCARD executorch::runtime::Result<const char*> get_key(
80+
uint32_t index) const override;
81+
82+
MergedDataMap(MergedDataMap&&) noexcept = default;
83+
84+
~MergedDataMap() override = default;
85+
86+
private:
87+
MergedDataMap(
88+
std::vector<const executorch::ET_RUNTIME_NAMESPACE::NamedDataMap*>
89+
named_data_maps,
90+
std::unordered_map<std::string, uint32_t> key_to_map_index)
91+
: named_data_maps_(std::move(named_data_maps)),
92+
key_to_map_index_(std::move(key_to_map_index)) {}
93+
94+
// Not copyable or assignable.
95+
MergedDataMap(const MergedDataMap& rhs) = delete;
96+
MergedDataMap& operator=(MergedDataMap&& rhs) noexcept = delete;
97+
MergedDataMap& operator=(const MergedDataMap& rhs) = delete;
98+
99+
std::vector<const executorch::ET_RUNTIME_NAMESPACE::NamedDataMap*>
100+
named_data_maps_;
101+
102+
// Map from key to index in the named_data_maps_ vector.
103+
std::unordered_map<std::string, uint32_t> key_to_map_index_;
104+
};
105+
106+
} // namespace extension
107+
} // namespace executorch
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_aten_mode_options", "runtime")
2+
3+
def define_common_targets():
4+
for aten_mode in get_aten_mode_options():
5+
aten_suffix = "_aten" if aten_mode else ""
6+
runtime.cxx_library(
7+
name = "merged_data_map" + aten_suffix,
8+
srcs = [
9+
"merged_data_map.cpp",
10+
],
11+
exported_headers = [
12+
"merged_data_map.h",
13+
],
14+
visibility = [
15+
"@EXECUTORCH_CLIENTS",
16+
],
17+
deps = [
18+
"//executorch/runtime/core:named_data_map",
19+
"//executorch/runtime/core:core",
20+
],
21+
)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Please this file formatted by running:
8+
# ~~~
9+
# cmake-format -i CMakeLists.txt
10+
# ~~~
11+
12+
cmake_minimum_required(VERSION 3.19)
13+
14+
# Source root directory for executorch.
15+
if(NOT EXECUTORCH_ROOT)
16+
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
17+
endif()
18+
19+
include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
20+
21+
add_custom_command(
22+
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte"
23+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
24+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
25+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
26+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrainProgram.pte"
27+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrainProgram.ptd"
28+
COMMAND
29+
${PYTHON_EXECUTABLE} -m test.models.export_program --modules "ModuleAddMul,ModuleLinear,ModuleSimpleTrain"
30+
--external-constants --outdir "${CMAKE_CURRENT_BINARY_DIR}"
31+
WORKING_DIRECTORY ${EXECUTORCH_ROOT}
32+
)
33+
34+
add_custom_target(
35+
extension_named_data_map_test_resources
36+
DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte"
37+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
38+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
39+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
40+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrainProgram.pte"
41+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrainProgram.ptd"
42+
)
43+
44+
set(test_env
45+
"ET_MODULE_ADD_MUL_PROGRAM_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte"
46+
"ET_MODULE_ADD_MUL_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
47+
"ET_MODULE_LINEAR_PROGRAM_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
48+
"ET_MODULE_LINEAR_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
49+
"ET_MODULE_SIMPLE_TRAIN_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrainProgram.pte"
50+
"ET_MODULE_SIMPLE_TRAIN_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleSimpleTrainProgram.ptd"
51+
)
52+
53+
set(_test_srcs merged_data_map_test.cpp)
54+
55+
et_cxx_test(
56+
extension_named_data_map_test SOURCES ${_test_srcs} EXTRA_LIBS
57+
extension_named_data_map extension_flat_tensor extension_data_loader
58+
)
59+
60+
add_dependencies(
61+
extension_named_data_map_test extension_named_data_map extension_named_data_map_test_resources
62+
)
63+
set_property(TEST extension_named_data_map_test PROPERTY ENVIRONMENT ${test_env})
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
oncall("executorch")
7+
8+
define_common_targets(is_fbcode=True)

0 commit comments

Comments
 (0)