Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions extension/training/module/state_dict_util.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/training/module/state_dict_util.h>

#include <cinttypes>
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>

namespace executorch {
namespace extension {
namespace training {

runtime::Result<std::map<std::string, executorch::extension::TensorPtr>>
load_state_dict(const runtime::NamedDataMap& data_map) {
std::map<std::string, executorch::extension::TensorPtr> state_dict;
auto num_key_res = data_map.get_num_keys();
if (!num_key_res.ok()) {
return num_key_res.error();
}
for (size_t i = 0; i < num_key_res.get(); i++) {
// get the key
auto key_res = data_map.get_key(i);
if (!key_res.ok()) {
return key_res.error();
}

// get the metadata
auto metadata_res = data_map.get_metadata(key_res.get());
if (!metadata_res.ok()) {
return metadata_res.error();
}

// get data blob
void* data = nullptr;
static constexpr size_t kMallocAlignment = alignof(std::max_align_t);
if constexpr (kMallocAlignment < 8) {
// Skip manually aligning the memory since PyTorch doesn't have dtypes >
// 8 bytes wide, and I don't expect to ever encounter a platform where
// malloc aligns to less than 8.
ET_LOG(
Error,
"kMallocAlignment is too small: %zu. Cannot safely create buffer to load tensor. Please open an issue on https://github.com/pytorch/executorch/issues",
kMallocAlignment);
return runtime::Error::NotSupported;
}

data = malloc(metadata_res->nbytes());
if (data == nullptr && metadata_res->nbytes() != 0) {
ET_LOG(Error, "Failed to allocate memory for tensor, malloc failed");
return runtime::Error::MemoryAllocationFailed;
}
auto load_into_error =
data_map.load_data_into(key_res.get(), data, metadata_res->nbytes());
if (load_into_error != runtime::Error::Ok) {
ET_LOG(
Error,
"Failed to load data into tensor, likely a malformed .ptd 0x%" PRIx32,
static_cast<uint32_t>(load_into_error));
return load_into_error;
}

// Get metadata
std::vector<executorch::aten::SizesType> sizes;
for (auto x : metadata_res->sizes()) {
sizes.push_back(x);
}
std::vector<executorch::aten::DimOrderType> dim_order;
for (auto x : metadata_res->dim_order()) {
dim_order.push_back(x);
}
std::vector<executorch::aten::StridesType> strides;
for (auto stride_index = 0; stride_index < metadata_res->sizes().size();
stride_index++) {
if (stride_index == 0) {
strides.push_back(1);
} else {
strides.insert(
strides.begin(),
sizes.at(stride_index) * strides.at(stride_index - 1));
}
}

// create tensor
auto tensor = make_tensor_ptr(
sizes,
data,
dim_order,
strides,
metadata_res->scalar_type(),
exec_aten::TensorShapeDynamism::STATIC,
[](void* ptr) {
free(ptr);
ptr = nullptr;
});

// add to state dict
state_dict.insert({std::string(key_res.get()), std::move(tensor)});
}

return state_dict;
}

} // namespace training
} // namespace extension
} // namespace executorch
35 changes: 35 additions & 0 deletions extension/training/module/state_dict_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/named_data_map.h>
#include <executorch/runtime/platform/compiler.h>

#include <map>
#include <string>

namespace executorch {
namespace extension {
namespace training {

/**
 * Generate a map of string to tensor from the entries of a NamedDataMap.
 *
 * Every key reported by the map is visited. For each one a buffer of the
 * tensor's byte size is allocated with malloc, the tensor data is copied into
 * it, and the buffer is wrapped in a TensorPtr with static shape dynamism.
 * The buffer is released with free() by the tensor's deleter, so the returned
 * tensors do not reference memory held by the data map.
 *
 * @param data The NamedDataMap to load the tensors and names from.
 * @return A result containing a map of tensor names to tensors if
 * successful, an error otherwise. Errors reported by the data map are
 * propagated; allocation failure yields Error::MemoryAllocationFailed.
 */
ET_EXPERIMENTAL
runtime::Result<std::map<std::string, executorch::extension::TensorPtr>>
load_state_dict(const runtime::NamedDataMap& data);

} // namespace training
} // namespace extension
} // namespace executorch
18 changes: 18 additions & 0 deletions extension/training/module/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,24 @@ def define_common_targets():
TARGETS and BUCK files that call this function.
"""

# Library target for state_dict_util.cpp / state_dict_util.h.
# The deps are exported because the public header itself includes the tensor
# extension and named_data_map headers.
runtime.cxx_library(
name = "state_dict_util",
srcs = [
"state_dict_util.cpp",
],
exported_headers = [
"state_dict_util.h",
],
visibility = [
"@EXECUTORCH_CLIENTS",
],
exported_deps = [
"//executorch/runtime/core:named_data_map",
"//executorch/extension/tensor:tensor",
"//executorch/runtime/core:core",
],
)

for aten_mode in get_aten_mode_options():
aten_suffix = ("_aten" if aten_mode else "")

Expand Down
89 changes: 89 additions & 0 deletions extension/training/module/test/state_dict_util_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
#include <executorch/extension/training/module/state_dict_util.h>

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/platform/runtime.h>

#include <gtest/gtest.h>

using namespace ::testing;
using executorch::extension::FlatTensorDataMap;
using executorch::extension::FlatTensorHeader;
using executorch::runtime::DataLoader;
using executorch::runtime::Error;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::Result;
using executorch::runtime::TensorLayout;
using torch::executor::util::FileDataLoader;

/**
 * Fixture that opens the .ptd file named by the ET_MODULE_LINEAR_DATA_PATH
 * environment variable and keeps a FileDataLoader for it in
 * `data_map_loader_`.
 */
class LoadStateDictTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // Since these tests cause ET_LOG to be called, the PAL must be initialized
    // first.
    executorch::runtime::runtime_init();

    // Load data map.
    // The eager linear model is defined at:
    // //executorch/test/models/linear_model.py
    const char* path = std::getenv("ET_MODULE_LINEAR_DATA_PATH");
    // getenv returns nullptr when the variable is unset; fail with a clear
    // message instead of handing a null path to the loader.
    ASSERT_TRUE(path != nullptr)
        << "ET_MODULE_LINEAR_DATA_PATH is not set";
    Result<FileDataLoader> loader = FileDataLoader::from(path);
    ASSERT_EQ(loader.error(), Error::Ok);

    // Sanity check: the header bytes of the file must be readable.
    Result<FreeableBuffer> header = loader->load(
        /*offset=*/0,
        FlatTensorHeader::kNumHeadBytes,
        /*segment_info=*/
        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));

    ASSERT_EQ(header.error(), Error::Ok);

    data_map_loader_ =
        std::make_unique<FileDataLoader>(std::move(loader.get()));
  }

  // Loader over the .ptd file; consumed by FlatTensorDataMap::load in tests.
  std::unique_ptr<FileDataLoader> data_map_loader_;
};

// Loads the linear model's .ptd through FlatTensorDataMap and checks that
// load_state_dict yields the two expected 2x2 float tensors, "a" (all 3.0)
// and "b" (all 2.0).
TEST_F(LoadStateDictTest, LoadDataMap) {
  Result<FlatTensorDataMap> data_map =
      FlatTensorDataMap::load(data_map_loader_.get());
  // Must be fatal: data_map.get() below is only valid on success.
  ASSERT_EQ(data_map.error(), Error::Ok);

  auto state_dict =
      executorch::extension::training::load_state_dict(data_map.get());
  ASSERT_TRUE(state_dict.ok());

  EXPECT_EQ(state_dict->size(), 2u);

  // Verifies that `name` is a 2x2 float tensor with every element == value.
  const auto expect_2x2_filled_with = [&state_dict](
                                          const char* name, float value) {
    // Guard the lookup so a missing key fails the test instead of throwing.
    ASSERT_EQ(state_dict->count(name), 1u) << "missing tensor: " << name;
    const auto& tensor = state_dict->at(name);

    ASSERT_EQ(tensor->sizes().size(), 2u);
    EXPECT_EQ(tensor->sizes()[0], 2);
    EXPECT_EQ(tensor->sizes()[1], 2);
    EXPECT_EQ(tensor->scalar_type(), torch::executor::ScalarType::Float);
    EXPECT_EQ(tensor->dim(), 2);

    const float* values = tensor->const_data_ptr<float>();
    for (int idx = 0; idx < 4; ++idx) {
      EXPECT_EQ(values[idx], value) << name << "[" << idx << "]";
    }
  };

  expect_2x2_filled_with("a", 3.f);
  expect_2x2_filled_with("b", 2.f);
}
16 changes: 16 additions & 0 deletions extension/training/module/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def define_common_targets(is_fbcode = False):
# intentionally don't work in xplat (since they're host-only tools).
"ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
"ET_MODULE_SIMPLE_TRAIN_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleSimpleTrain.pte])",
"ET_MODULE_LINEAR_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.pte])",
"ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])",
}

runtime.cxx_test(
Expand All @@ -32,3 +34,17 @@ def define_common_targets(is_fbcode = False):
],
env = modules_env,
)

# Unit test for state_dict_util.cpp. The test reads the .ptd path from the
# ET_MODULE_LINEAR_DATA_PATH environment variable, which is supplied through
# `modules_env`.
runtime.cxx_test(
name = "state_dict_util_test",
srcs = [
"state_dict_util_test.cpp",
],
deps = [
"//executorch/extension/data_loader:file_data_loader",
"//executorch/extension/flat_tensor:flat_tensor_data_map",
"//executorch/extension/training/module:state_dict_util",
"//executorch/runtime/core/exec_aten:lib",
],
env = modules_env,
)
Loading