Skip to content

Commit 4aee11d

Browse files
silverguofacebook-github-bot
authored andcommitted
Add support in training_module to get attributes
Summary: As title Differential Revision: D78907316
1 parent ef10a35 commit 4aee11d

File tree

4 files changed

+74
-5
lines changed

4 files changed

+74
-5
lines changed

extension/module/module.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,11 +230,7 @@ runtime::Error Module::load_method(
230230

231231
ET_NODISCARD runtime::Result<Method*> Module::method(
232232
const std::string& method_name) {
233-
ET_CHECK_OR_RETURN_ERROR(
234-
methods_.count(method_name) > 0,
235-
InvalidArgument,
236-
"no such method in program: %s",
237-
method_name.c_str());
233+
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
238234
return methods_[method_name].method.get();
239235
}
240236

extension/training/module/test/training_module_test.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ TEST_F(TrainingModuleTest, JointGraphTest) {
8888
ASSERT_EQ(param.find("linear.weight")->second.dim(), 2);
8989
ASSERT_EQ(param.find("linear.bias")->second.sizes()[0], 3);
9090
ASSERT_EQ(param.find("linear.bias")->second.dim(), 1);
91+
92+
// Test attributes for pte only model
93+
auto attributes_res = mod.named_attributes("forward");
94+
ASSERT_EQ(attributes_res.error(), Error::Ok);
95+
auto& attributes = attributes_res.get();
96+
ASSERT_EQ(attributes.size(), 0);
9197
}
9298

9399
TEST_F(TrainingModuleTest, NonTrainingModuleTest) {
@@ -152,4 +158,18 @@ TEST_F(TrainingModuleTest, SeperateDataTest) {
152158
auto res = mod.execute_forward_backward("forward", inputs);
153159
ASSERT_EQ(res.error(), Error::Ok);
154160
ASSERT_EQ(res.get().size(), 1);
161+
162+
// Test Attributes for pte + ptd model
163+
auto attributes_res = mod.named_attributes("forward");
164+
ASSERT_EQ(attributes_res.error(), Error::Ok);
165+
auto& attributes = attributes_res.get();
166+
ASSERT_EQ(attributes.size(), 2);
167+
ASSERT_NE(attributes.find("linear.weight"), attributes.end());
168+
ASSERT_NE(attributes.find("linear.bias"), attributes.end());
169+
170+
ASSERT_EQ(attributes.find("linear.weight")->second.sizes()[0], 3);
171+
ASSERT_EQ(attributes.find("linear.weight")->second.sizes()[1], 3);
172+
ASSERT_EQ(attributes.find("linear.weight")->second.dim(), 2);
173+
ASSERT_EQ(attributes.find("linear.bias")->second.sizes()[0], 3);
174+
ASSERT_EQ(attributes.find("linear.bias")->second.dim(), 1);
155175
}

extension/training/module/training_module.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,41 @@ TrainingModule::named_gradients(const std::string& method_name) {
154154
return method_named_gradients_.at(method_name);
155155
}
156156

157+
runtime::Result<const std::map<std::string_view, executorch::aten::Tensor>>
158+
TrainingModule::named_attributes(const std::string& method_name) {
159+
// If we haven't seen this method before, populate the dict.
160+
if (method_named_attributes_.find(method_name) ==
161+
method_named_attributes_.end()) {
162+
method_named_attributes_.insert({method_name, {}});
163+
164+
// get method metadata
165+
auto meta_res = executorch::extension::Module::method_meta(method_name);
166+
if (!meta_res.ok()) {
167+
return meta_res.error();
168+
}
169+
// get method
170+
auto method_res = executorch::extension::Module::method(method_name);
171+
if (!method_res.ok()) {
172+
return method_res.error();
173+
}
174+
// get tensor by name
175+
for (int idx = 0; idx < meta_res->num_attributes(); idx++) {
176+
const auto tensor_res = meta_res->attribute_tensor_meta(idx);
177+
if (!tensor_res.ok()) {
178+
return tensor_res.error();
179+
}
180+
const auto tensorName = tensor_res.get().name();
181+
const auto attribute_res = (*method_res)->get_attribute(tensorName);
182+
if (!attribute_res.ok()) {
183+
return attribute_res.error();
184+
}
185+
method_named_attributes_.at(method_name)
186+
.insert({tensorName, attribute_res.get()});
187+
}
188+
}
189+
return method_named_attributes_.at(method_name);
190+
}
191+
157192
} // namespace training
158193
} // namespace extension
159194
} // namespace executorch

extension/training/module/training_module.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,19 @@ class ET_EXPERIMENTAL TrainingModule final
9393
runtime::Result<const std::map<std::string_view, executorch::aten::Tensor>>
9494
named_gradients(const std::string& method_name);
9595

96+
/**
97+
* Retrieve the attributes for a method.
98+
*
99+
* @param[in] method_name The name of the method to get the
100+
* attributes for.
101+
*
102+
* @returns A Result object containing a map of the fully qualified name to
103+
* attribute tensor.
104+
*/
105+
ET_EXPERIMENTAL
106+
runtime::Result<const std::map<std::string_view, executorch::aten::Tensor>>
107+
named_attributes(const std::string& method_name);
108+
96109
private:
97110
std::unordered_map<
98111
std::string,
@@ -103,6 +116,11 @@ class ET_EXPERIMENTAL TrainingModule final
103116
std::string,
104117
std::map<std::string_view, executorch::aten::Tensor>>
105118
method_named_parameters_;
119+
120+
std::unordered_map<
121+
std::string,
122+
std::map<std::string_view, executorch::aten::Tensor>>
123+
method_named_attributes_;
106124
};
107125

108126
} // namespace training

0 commit comments

Comments
 (0)