Skip to content

Commit c105dda

Browse files
Util to load a named_data_map into a std::map
Differential Revision: D70186215 Pull Request resolved: #8737
1 parent 5b32a80 commit c105dda

File tree

5 files changed

+265
-0
lines changed

5 files changed

+265
-0
lines changed
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/training/module/state_dict_util.h>
10+
11+
namespace executorch {
12+
namespace extension {
13+
namespace training {
14+
15+
runtime::Result<std::map<std::string, executorch::extension::TensorPtr>>
16+
load_state_dict(const runtime::NamedDataMap& data_map) {
17+
std::map<std::string, executorch::extension::TensorPtr> state_dict;
18+
auto num_key_res = data_map.get_num_keys();
19+
if (!num_key_res.ok()) {
20+
return num_key_res.error();
21+
}
22+
for (size_t i = 0; i < num_key_res.get(); i++) {
23+
// get the key
24+
auto key_res = data_map.get_key(i);
25+
if (!key_res.ok()) {
26+
return key_res.error();
27+
}
28+
29+
// get the metadata
30+
auto metadata_res = data_map.get_metadata(key_res.get());
31+
if (!metadata_res.ok()) {
32+
return metadata_res.error();
33+
}
34+
35+
// get data blob
36+
void* data = nullptr;
37+
static constexpr size_t kMallocAlignment = alignof(std::max_align_t);
38+
if constexpr (kMallocAlignment < 8) {
39+
// Skip manually aligning the memory since PyTorch doesn't have dtypes >
40+
// 8 bytes wide, and I don't expect to ever encounter a platform where
41+
// malloc aligns to less than 8.
42+
ET_LOG(
43+
Error,
44+
"kMallocAlignment is too small: %zu. Cannot safely create buffer to load tensor. Please open an issue on https://github.com/pytorch/executorch/issues",
45+
kMallocAlignment);
46+
return runtime::Error::NotSupported;
47+
}
48+
49+
data = malloc(metadata_res->nbytes());
50+
if (data == nullptr && metadata_res->nbytes() != 0) {
51+
ET_LOG(Error, "Failed to allocate memory for tensor, malloc failed");
52+
return runtime::Error::MemoryAllocationFailed;
53+
}
54+
auto load_into_error =
55+
data_map.load_data_into(key_res.get(), data, metadata_res->nbytes());
56+
if (load_into_error != runtime::Error::Ok) {
57+
ET_LOG(
58+
Error,
59+
"Failed to load data into tensor, likely a malformed .ptd 0x%" PRIx32,
60+
static_cast<uint32_t>(load_into_error));
61+
return load_into_error;
62+
}
63+
64+
// Get metadata
65+
std::vector<executorch::aten::SizesType> sizes;
66+
for (auto x : metadata_res->sizes()) {
67+
sizes.push_back(x);
68+
}
69+
std::vector<executorch::aten::DimOrderType> dim_order;
70+
for (auto x : metadata_res->dim_order()) {
71+
dim_order.push_back(x);
72+
}
73+
std::vector<executorch::aten::StridesType> strides;
74+
for (auto stride_index = 0; stride_index < metadata_res->sizes().size();
75+
stride_index++) {
76+
if (stride_index == 0) {
77+
strides.push_back(1);
78+
} else {
79+
strides.insert(
80+
strides.begin(),
81+
sizes.at(stride_index) * strides.at(stride_index - 1));
82+
}
83+
}
84+
85+
// create tensor
86+
auto tensor = make_tensor_ptr(
87+
sizes,
88+
data,
89+
dim_order,
90+
strides,
91+
metadata_res->scalar_type(),
92+
exec_aten::TensorShapeDynamism::STATIC,
93+
[](void* ptr) {
94+
free(ptr);
95+
ptr = nullptr;
96+
});
97+
98+
// add to state dict
99+
state_dict.insert({std::string(key_res.get()), std::move(tensor)});
100+
}
101+
102+
return state_dict;
103+
}
104+
105+
} // namespace training
106+
} // namespace extension
107+
} // namespace executorch
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/extension/tensor/tensor.h>
12+
#include <executorch/runtime/core/named_data_map.h>
13+
#include <executorch/runtime/platform/compiler.h>
14+
15+
#include <map>
16+
#include <string>
17+
18+
namespace executorch {
19+
namespace extension {
20+
namespace training {
21+
22+
/**
23+
* Generate a map of string to tensor.
24+
*
25+
* @param data The NamedDataMap to load the tensors and names from.
26+
* @return A result containing a map of tensor names to tensors if
27+
* successful, an error otherwise.
28+
*/
29+
ET_EXPERIMENTAL
30+
runtime::Result<std::map<std::string, executorch::extension::TensorPtr>>
31+
load_state_dict(const runtime::NamedDataMap& data);
32+
33+
} // namespace training
34+
} // namespace extension
35+
} // namespace executorch

extension/training/module/targets.bzl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,24 @@ def define_common_targets():
77
TARGETS and BUCK files that call this function.
88
"""
99

10+
runtime.cxx_library(
11+
name = "state_dict_util",
12+
srcs = [
13+
"state_dict_util.cpp",
14+
],
15+
exported_headers = [
16+
"state_dict_util.h",
17+
],
18+
visibility = [
19+
"@EXECUTORCH_CLIENTS",
20+
],
21+
exported_deps = [
22+
"//executorch/runtime/core:named_data_map",
23+
"//executorch/extension/tensor:tensor",
24+
"//executorch/runtime/core:core",
25+
],
26+
)
27+
1028
for aten_mode in get_aten_mode_options():
1129
aten_suffix = ("_aten" if aten_mode else "")
1230

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/data_loader/file_data_loader.h>
10+
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
11+
#include <executorch/extension/training/module/state_dict_util.h>
12+
13+
#include <executorch/runtime/core/error.h>
14+
#include <executorch/runtime/core/result.h>
15+
#include <executorch/runtime/platform/runtime.h>
16+
17+
#include <gtest/gtest.h>
18+
19+
using namespace ::testing;
20+
using executorch::extension::FlatTensorDataMap;
21+
using executorch::extension::FlatTensorHeader;
22+
using executorch::runtime::DataLoader;
23+
using executorch::runtime::Error;
24+
using executorch::runtime::FreeableBuffer;
25+
using executorch::runtime::Result;
26+
using executorch::runtime::TensorLayout;
27+
using torch::executor::util::FileDataLoader;
28+
29+
class LoadStateDictTest : public ::testing::Test {
30+
protected:
31+
void SetUp() override {
32+
// Since these tests cause ET_LOG to be called, the PAL must be initialized
33+
// first.
34+
executorch::runtime::runtime_init();
35+
36+
// Load data map.
37+
// The eager linear model is defined at:
38+
// //executorch/test/models/linear_model.py
39+
const char* path = std::getenv("ET_MODULE_LINEAR_DATA_PATH");
40+
Result<FileDataLoader> loader = FileDataLoader::from(path);
41+
ASSERT_EQ(loader.error(), Error::Ok);
42+
43+
Result<FreeableBuffer> header = loader->load(
44+
/*offset=*/0,
45+
FlatTensorHeader::kNumHeadBytes,
46+
/*segment_info=*/
47+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
48+
49+
ASSERT_EQ(header.error(), Error::Ok);
50+
51+
data_map_loader_ =
52+
std::make_unique<FileDataLoader>(std::move(loader.get()));
53+
}
54+
std::unique_ptr<FileDataLoader> data_map_loader_;
55+
};
56+
57+
TEST_F(LoadStateDictTest, LoadDataMap) {
58+
Result<FlatTensorDataMap> data_map =
59+
FlatTensorDataMap::load(data_map_loader_.get());
60+
EXPECT_EQ(data_map.error(), Error::Ok);
61+
62+
auto state_dict =
63+
executorch::extension::training::load_state_dict(data_map.get());
64+
ASSERT_TRUE(state_dict.ok());
65+
66+
EXPECT_EQ(state_dict->size(), 2);
67+
EXPECT_EQ(state_dict->at("a")->sizes().size(), 2);
68+
EXPECT_EQ(state_dict->at("a")->sizes()[0], 2);
69+
EXPECT_EQ(state_dict->at("a")->sizes()[1], 2);
70+
EXPECT_EQ(
71+
state_dict->at("a")->scalar_type(), torch::executor::ScalarType::Float);
72+
EXPECT_EQ(state_dict->at("a")->dim(), 2);
73+
EXPECT_EQ(state_dict->at("a")->const_data_ptr<float>()[0], 3.f);
74+
EXPECT_EQ(state_dict->at("a")->const_data_ptr<float>()[1], 3.f);
75+
EXPECT_EQ(state_dict->at("a")->const_data_ptr<float>()[2], 3.f);
76+
EXPECT_EQ(state_dict->at("a")->const_data_ptr<float>()[3], 3.f);
77+
78+
EXPECT_EQ(state_dict->size(), 2);
79+
EXPECT_EQ(state_dict->at("b")->sizes().size(), 2);
80+
EXPECT_EQ(state_dict->at("b")->sizes()[0], 2);
81+
EXPECT_EQ(state_dict->at("b")->sizes()[1], 2);
82+
EXPECT_EQ(
83+
state_dict->at("b")->scalar_type(), torch::executor::ScalarType::Float);
84+
EXPECT_EQ(state_dict->at("b")->dim(), 2);
85+
EXPECT_EQ(state_dict->at("b")->const_data_ptr<float>()[0], 2.f);
86+
EXPECT_EQ(state_dict->at("b")->const_data_ptr<float>()[1], 2.f);
87+
EXPECT_EQ(state_dict->at("b")->const_data_ptr<float>()[2], 2.f);
88+
EXPECT_EQ(state_dict->at("b")->const_data_ptr<float>()[3], 2.f);
89+
}

extension/training/module/test/targets.bzl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ def define_common_targets(is_fbcode = False):
1717
# intentionally don't work in xplat (since they're host-only tools).
1818
"ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
1919
"ET_MODULE_SIMPLE_TRAIN_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleSimpleTrain.pte])",
20+
"ET_MODULE_LINEAR_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.pte])",
21+
"ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])",
2022
}
2123

2224
runtime.cxx_test(
@@ -32,3 +34,17 @@ def define_common_targets(is_fbcode = False):
3234
],
3335
env = modules_env,
3436
)
37+
38+
runtime.cxx_test(
39+
name = "state_dict_util_test",
40+
srcs = [
41+
"state_dict_util_test.cpp",
42+
],
43+
deps = [
44+
"//executorch/extension/data_loader:file_data_loader",
45+
"//executorch/extension/flat_tensor:flat_tensor_data_map",
46+
"//executorch/extension/training/module:state_dict_util",
47+
"//executorch/runtime/core/exec_aten:lib",
48+
],
49+
env = modules_env,
50+
)

0 commit comments

Comments
 (0)