Skip to content

Commit dc1206d

Browse files
authored
Access Method directly from TrainingModule. (#13602)
Summary: . Differential Revision: D80821085
1 parent 2603f74 commit dc1206d

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

extension/training/module/training_module.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,23 +162,24 @@ TrainingModule::named_attributes(const std::string& method_name) {
162162
method_named_attributes_.insert({method_name, {}});
163163

164164
// get method metadata
165-
auto meta_res = executorch::extension::Module::method_meta(method_name);
165+
auto meta_res = method_meta(method_name);
166166
if (!meta_res.ok()) {
167167
return meta_res.error();
168168
}
169169
// get method
170-
auto method_res = executorch::extension::Module::method(method_name);
171-
if (!method_res.ok()) {
172-
return method_res.error();
170+
auto e = load_method(method_name);
171+
if (e != runtime::Error::Ok) {
172+
return e;
173173
}
174+
auto& method = methods_.at(method_name).method;
174175
// get tensor by name
175176
for (int idx = 0; idx < meta_res->num_attributes(); idx++) {
176177
const auto tensor_res = meta_res->attribute_tensor_meta(idx);
177178
if (!tensor_res.ok()) {
178179
return tensor_res.error();
179180
}
180181
const auto tensorName = tensor_res.get().name();
181-
const auto attribute_res = (*method_res)->get_attribute(tensorName);
182+
const auto attribute_res = method->get_attribute(tensorName);
182183
if (!attribute_res.ok()) {
183184
return attribute_res.error();
184185
}

0 commit comments

Comments
 (0)