Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions extension/module/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,6 @@ runtime::Error Module::load_method(
ET_NODISCARD runtime::Result<Method*> Module::method(
const std::string& method_name) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
ET_CHECK_OR_RETURN_ERROR(
methods_.count(method_name) > 0,
InvalidArgument,
"no such method in program: %s",
method_name.c_str());
return methods_[method_name].method.get();
}

Expand Down
46 changes: 46 additions & 0 deletions extension/training/module/test/training_module_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ TEST_F(TrainingModuleTest, JointGraphTest) {
ASSERT_EQ(param.find("linear.weight")->second.dim(), 2);
ASSERT_EQ(param.find("linear.bias")->second.sizes()[0], 3);
ASSERT_EQ(param.find("linear.bias")->second.dim(), 1);

// Test attributes for pte only model
auto attributes_res = mod.named_attributes("forward");
ASSERT_EQ(attributes_res.error(), Error::Ok);
auto& attributes = attributes_res.get();
ASSERT_EQ(attributes.size(), 0);
}

TEST_F(TrainingModuleTest, NonTrainingModuleTest) {
Expand Down Expand Up @@ -153,3 +159,43 @@ TEST_F(TrainingModuleTest, SeperateDataTest) {
ASSERT_EQ(res.error(), Error::Ok);
ASSERT_EQ(res.get().size(), 1);
}

TEST_F(TrainingModuleTest, DataExternalConstantsTest) {
// Test the external constants are loaded correctly.
const char* ptd_path = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH");
Result<FileDataLoader> data_map_loader_res = FileDataLoader::from(ptd_path);
ASSERT_EQ(data_map_loader_res.error(), Error::Ok);

auto data_map_loader =
std::make_unique<torch::executor::util::FileDataLoader>(
std::move(data_map_loader_res.get()));

const char* pte_path = std::getenv("ET_MODULE_ADD_MUL_PROGRAM_PATH");
Result<FileDataLoader> pte_loader_res = FileDataLoader::from(pte_path);
ASSERT_EQ(pte_loader_res.error(), Error::Ok);

auto pte_loader = std::make_unique<torch::executor::util::FileDataLoader>(
std::move(pte_loader_res.get()));

auto mod = executorch::extension::training::TrainingModule(
std::move(pte_loader),
nullptr,
nullptr,
nullptr,
std::move(data_map_loader));

// Test Attributes for pte + ptd model containing external constants
auto attributes_res = mod.named_attributes("forward");
ASSERT_EQ(attributes_res.error(), Error::Ok);
auto& attributes = attributes_res.get();
ASSERT_EQ(attributes.size(), 2);
ASSERT_NE(attributes.find("a"), attributes.end());
ASSERT_NE(attributes.find("b"), attributes.end());

ASSERT_EQ(attributes.find("a")->second.sizes()[0], 2);
ASSERT_EQ(attributes.find("a")->second.sizes()[1], 2);
ASSERT_EQ(attributes.find("a")->second.dim(), 2);
ASSERT_EQ(attributes.find("b")->second.sizes()[0], 2);
ASSERT_EQ(attributes.find("b")->second.sizes()[0], 2);
ASSERT_EQ(attributes.find("b")->second.dim(), 2);
}
35 changes: 35 additions & 0 deletions extension/training/module/training_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,41 @@ TrainingModule::named_gradients(const std::string& method_name) {
return method_named_gradients_.at(method_name);
}

runtime::Result<const std::map<std::string_view, executorch::aten::Tensor>>
TrainingModule::named_attributes(const std::string& method_name) {
// If we haven't seen this method before, populate the dict.
if (method_named_attributes_.find(method_name) ==
method_named_attributes_.end()) {
method_named_attributes_.insert({method_name, {}});

// get method metadata
auto meta_res = executorch::extension::Module::method_meta(method_name);
if (!meta_res.ok()) {
return meta_res.error();
}
// get method
auto method_res = executorch::extension::Module::method(method_name);
if (!method_res.ok()) {
return method_res.error();
}
// get tensor by name
for (int idx = 0; idx < meta_res->num_attributes(); idx++) {
const auto tensor_res = meta_res->attribute_tensor_meta(idx);
if (!tensor_res.ok()) {
return tensor_res.error();
}
const auto tensorName = tensor_res.get().name();
const auto attribute_res = (*method_res)->get_attribute(tensorName);
if (!attribute_res.ok()) {
return attribute_res.error();
}
method_named_attributes_.at(method_name)
.insert({tensorName, attribute_res.get()});
}
}
return method_named_attributes_.at(method_name);
}

} // namespace training
} // namespace extension
} // namespace executorch
18 changes: 18 additions & 0 deletions extension/training/module/training_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,19 @@ class ET_EXPERIMENTAL TrainingModule final
runtime::Result<const std::map<std::string_view, executorch::aten::Tensor>>
named_gradients(const std::string& method_name);

/**
* Retrieve the attributes for a method.
*
* @param[in] method_name The name of the method to get the
* attributes for.
*
* @returns A Result object containing a map of the fully qualified name to
* attribute tensor.
*/
ET_EXPERIMENTAL
runtime::Result<const std::map<std::string_view, executorch::aten::Tensor>>
named_attributes(const std::string& method_name);

private:
std::unordered_map<
std::string,
Expand All @@ -103,6 +116,11 @@ class ET_EXPERIMENTAL TrainingModule final
std::string,
std::map<std::string_view, executorch::aten::Tensor>>
method_named_parameters_;

std::unordered_map<
std::string,
std::map<std::string_view, executorch::aten::Tensor>>
method_named_attributes_;
};

} // namespace training
Expand Down
Loading