diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 43b3cd0f9b8..ad0859ab7e6 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -231,11 +231,6 @@ runtime::Error Module::load_method( ET_NODISCARD runtime::Result Module::method( const std::string& method_name) { ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name)); - ET_CHECK_OR_RETURN_ERROR( - methods_.count(method_name) > 0, - InvalidArgument, - "no such method in program: %s", - method_name.c_str()); return methods_[method_name].method.get(); } diff --git a/extension/training/module/test/training_module_test.cpp b/extension/training/module/test/training_module_test.cpp index 3ba46c6f653..16ff87bc022 100644 --- a/extension/training/module/test/training_module_test.cpp +++ b/extension/training/module/test/training_module_test.cpp @@ -88,6 +88,12 @@ TEST_F(TrainingModuleTest, JointGraphTest) { ASSERT_EQ(param.find("linear.weight")->second.dim(), 2); ASSERT_EQ(param.find("linear.bias")->second.sizes()[0], 3); ASSERT_EQ(param.find("linear.bias")->second.dim(), 1); + + // Test attributes for pte only model + auto attributes_res = mod.named_attributes("forward"); + ASSERT_EQ(attributes_res.error(), Error::Ok); + auto& attributes = attributes_res.get(); + ASSERT_EQ(attributes.size(), 0); } TEST_F(TrainingModuleTest, NonTrainingModuleTest) { @@ -153,3 +159,43 @@ TEST_F(TrainingModuleTest, SeperateDataTest) { ASSERT_EQ(res.error(), Error::Ok); ASSERT_EQ(res.get().size(), 1); } + +TEST_F(TrainingModuleTest, DataExternalConstantsTest) { + // Test the external constants are loaded correctly. + const char* ptd_path = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH"); + Result data_map_loader_res = FileDataLoader::from(ptd_path); + ASSERT_EQ(data_map_loader_res.error(), Error::Ok); + + auto data_map_loader = + std::make_unique( + std::move(data_map_loader_res.get())); + + const char* pte_path = std::getenv("ET_MODULE_ADD_MUL_PROGRAM_PATH"); + Result pte_loader_res = FileDataLoader::from(pte_path); + ASSERT_EQ(pte_loader_res.error(), Error::Ok); + + auto pte_loader = std::make_unique( + std::move(pte_loader_res.get())); + + auto mod = executorch::extension::training::TrainingModule( + std::move(pte_loader), + nullptr, + nullptr, + nullptr, + std::move(data_map_loader)); + + // Test Attributes for pte + ptd model containing external constants + auto attributes_res = mod.named_attributes("forward"); + ASSERT_EQ(attributes_res.error(), Error::Ok); + auto& attributes = attributes_res.get(); + ASSERT_EQ(attributes.size(), 2); + ASSERT_NE(attributes.find("a"), attributes.end()); + ASSERT_NE(attributes.find("b"), attributes.end()); + + ASSERT_EQ(attributes.find("a")->second.sizes()[0], 2); + ASSERT_EQ(attributes.find("a")->second.sizes()[1], 2); + ASSERT_EQ(attributes.find("a")->second.dim(), 2); + ASSERT_EQ(attributes.find("b")->second.sizes()[0], 2); + ASSERT_EQ(attributes.find("b")->second.sizes()[0], 2); + ASSERT_EQ(attributes.find("b")->second.dim(), 2); +} diff --git a/extension/training/module/training_module.cpp b/extension/training/module/training_module.cpp index 4dbaaf3fcfb..57514355f5e 100644 --- a/extension/training/module/training_module.cpp +++ b/extension/training/module/training_module.cpp @@ -154,6 +154,41 @@ TrainingModule::named_gradients(const std::string& method_name) { return method_named_gradients_.at(method_name); } +runtime::Result> +TrainingModule::named_attributes(const std::string& method_name) { + // If we haven't seen this method before, populate the dict. + if (method_named_attributes_.find(method_name) == + method_named_attributes_.end()) { + method_named_attributes_.insert({method_name, {}}); + + // get method metadata + auto meta_res = executorch::extension::Module::method_meta(method_name); + if (!meta_res.ok()) { + return meta_res.error(); + } + // get method + auto method_res = executorch::extension::Module::method(method_name); + if (!method_res.ok()) { + return method_res.error(); + } + // get tensor by name + for (int idx = 0; idx < meta_res->num_attributes(); idx++) { + const auto tensor_res = meta_res->attribute_tensor_meta(idx); + if (!tensor_res.ok()) { + return tensor_res.error(); + } + const auto tensorName = tensor_res.get().name(); + const auto attribute_res = (*method_res)->get_attribute(tensorName); + if (!attribute_res.ok()) { + return attribute_res.error(); + } + method_named_attributes_.at(method_name) + .insert({tensorName, attribute_res.get()}); + } + } + return method_named_attributes_.at(method_name); +} + } // namespace training } // namespace extension } // namespace executorch diff --git a/extension/training/module/training_module.h b/extension/training/module/training_module.h index d4050bea827..7dd380d2709 100644 --- a/extension/training/module/training_module.h +++ b/extension/training/module/training_module.h @@ -93,6 +93,19 @@ class ET_EXPERIMENTAL TrainingModule final runtime::Result> named_gradients(const std::string& method_name); + /** + * Retrieve the attributes for a method. + * + * @param[in] method_name The name of the method to get the + * attributes for. + * + * @returns A Result object containing a map of the fully qualified name to + * attribute tensor. + */ + ET_EXPERIMENTAL + runtime::Result> + named_attributes(const std::string& method_name); + private: std::unordered_map< std::string, @@ -103,6 +116,11 @@ class ET_EXPERIMENTAL TrainingModule final std::string, std::map> method_named_parameters_; + + std::unordered_map< + std::string, + std::map> + method_named_attributes_; }; } // namespace training