Skip to content

Commit d118a63

Browse files
authored
Add support in training_module to get attributes
Differential Revision: D78907316 Pull Request resolved: pytorch#12818
1 parent d3cbbf3 commit d118a63

File tree

4 files changed

+99
-5
lines changed

4 files changed

+99
-5
lines changed

extension/module/module.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,6 @@ runtime::Error Module::load_method(
231231
ET_NODISCARD runtime::Result<Method*> Module::method(
232232
const std::string& method_name) {
233233
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
234-
ET_CHECK_OR_RETURN_ERROR(
235-
methods_.count(method_name) > 0,
236-
InvalidArgument,
237-
"no such method in program: %s",
238-
method_name.c_str());
239234
return methods_[method_name].method.get();
240235
}
241236

extension/training/module/test/training_module_test.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ TEST_F(TrainingModuleTest, JointGraphTest) {
8888
ASSERT_EQ(param.find("linear.weight")->second.dim(), 2);
8989
ASSERT_EQ(param.find("linear.bias")->second.sizes()[0], 3);
9090
ASSERT_EQ(param.find("linear.bias")->second.dim(), 1);
91+
92+
// Test attributes for pte only model
93+
auto attributes_res = mod.named_attributes("forward");
94+
ASSERT_EQ(attributes_res.error(), Error::Ok);
95+
auto& attributes = attributes_res.get();
96+
ASSERT_EQ(attributes.size(), 0);
9197
}
9298

9399
TEST_F(TrainingModuleTest, NonTrainingModuleTest) {
@@ -153,3 +159,43 @@ TEST_F(TrainingModuleTest, SeperateDataTest) {
153159
ASSERT_EQ(res.error(), Error::Ok);
154160
ASSERT_EQ(res.get().size(), 1);
155161
}
162+
163+
TEST_F(TrainingModuleTest, DataExternalConstantsTest) {
164+
// Test the external constants are loaded correctly.
165+
const char* ptd_path = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH");
166+
Result<FileDataLoader> data_map_loader_res = FileDataLoader::from(ptd_path);
167+
ASSERT_EQ(data_map_loader_res.error(), Error::Ok);
168+
169+
auto data_map_loader =
170+
std::make_unique<torch::executor::util::FileDataLoader>(
171+
std::move(data_map_loader_res.get()));
172+
173+
const char* pte_path = std::getenv("ET_MODULE_ADD_MUL_PROGRAM_PATH");
174+
Result<FileDataLoader> pte_loader_res = FileDataLoader::from(pte_path);
175+
ASSERT_EQ(pte_loader_res.error(), Error::Ok);
176+
177+
auto pte_loader = std::make_unique<torch::executor::util::FileDataLoader>(
178+
std::move(pte_loader_res.get()));
179+
180+
auto mod = executorch::extension::training::TrainingModule(
181+
std::move(pte_loader),
182+
nullptr,
183+
nullptr,
184+
nullptr,
185+
std::move(data_map_loader));
186+
187+
// Test Attributes for pte + ptd model containing external constants
188+
auto attributes_res = mod.named_attributes("forward");
189+
ASSERT_EQ(attributes_res.error(), Error::Ok);
190+
auto& attributes = attributes_res.get();
191+
ASSERT_EQ(attributes.size(), 2);
192+
ASSERT_NE(attributes.find("a"), attributes.end());
193+
ASSERT_NE(attributes.find("b"), attributes.end());
194+
195+
ASSERT_EQ(attributes.find("a")->second.sizes()[0], 2);
196+
ASSERT_EQ(attributes.find("a")->second.sizes()[1], 2);
197+
ASSERT_EQ(attributes.find("a")->second.dim(), 2);
198+
ASSERT_EQ(attributes.find("b")->second.sizes()[0], 2);
199+
ASSERT_EQ(attributes.find("b")->second.sizes()[0], 2);
200+
ASSERT_EQ(attributes.find("b")->second.dim(), 2);
201+
}

extension/training/module/training_module.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,41 @@ TrainingModule::named_gradients(const std::string& method_name) {
154154
return method_named_gradients_.at(method_name);
155155
}
156156

157+
runtime::Result<const std::map<std::string_view, executorch::aten::Tensor>>
158+
TrainingModule::named_attributes(const std::string& method_name) {
159+
// If we haven't seen this method before, populate the dict.
160+
if (method_named_attributes_.find(method_name) ==
161+
method_named_attributes_.end()) {
162+
method_named_attributes_.insert({method_name, {}});
163+
164+
// get method metadata
165+
auto meta_res = executorch::extension::Module::method_meta(method_name);
166+
if (!meta_res.ok()) {
167+
return meta_res.error();
168+
}
169+
// get method
170+
auto method_res = executorch::extension::Module::method(method_name);
171+
if (!method_res.ok()) {
172+
return method_res.error();
173+
}
174+
// get tensor by name
175+
for (int idx = 0; idx < meta_res->num_attributes(); idx++) {
176+
const auto tensor_res = meta_res->attribute_tensor_meta(idx);
177+
if (!tensor_res.ok()) {
178+
return tensor_res.error();
179+
}
180+
const auto tensorName = tensor_res.get().name();
181+
const auto attribute_res = (*method_res)->get_attribute(tensorName);
182+
if (!attribute_res.ok()) {
183+
return attribute_res.error();
184+
}
185+
method_named_attributes_.at(method_name)
186+
.insert({tensorName, attribute_res.get()});
187+
}
188+
}
189+
return method_named_attributes_.at(method_name);
190+
}
191+
157192
} // namespace training
158193
} // namespace extension
159194
} // namespace executorch

extension/training/module/training_module.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,19 @@ class ET_EXPERIMENTAL TrainingModule final
9393
runtime::Result<const std::map<std::string_view, executorch::aten::Tensor>>
9494
named_gradients(const std::string& method_name);
9595

96+
/**
97+
* Retrieve the attributes for a method.
98+
*
99+
* @param[in] method_name The name of the method to get the
100+
* attributes for.
101+
*
102+
* @returns A Result object containing a map of the fully qualified name to
103+
* attribute tensor.
104+
*/
105+
ET_EXPERIMENTAL
106+
runtime::Result<const std::map<std::string_view, executorch::aten::Tensor>>
107+
named_attributes(const std::string& method_name);
108+
96109
private:
97110
std::unordered_map<
98111
std::string,
@@ -103,6 +116,11 @@ class ET_EXPERIMENTAL TrainingModule final
103116
std::string,
104117
std::map<std::string_view, executorch::aten::Tensor>>
105118
method_named_parameters_;
119+
120+
std::unordered_map<
121+
std::string,
122+
std::map<std::string_view, executorch::aten::Tensor>>
123+
method_named_attributes_;
106124
};
107125

108126
} // namespace training

0 commit comments

Comments
 (0)