@@ -26,26 +26,34 @@ using torch::executor::util::FileDataLoader;
2626
2727class MethodMetaTest : public ::testing::Test {
2828 protected:
29- void SetUp () override {
30- // Create a loader for the serialized ModuleAdd program.
31- const char * path = std::getenv (" ET_MODULE_ADD_PATH" );
29+ void load_program (const char * path, const char * module_name) {
30+ // Create a loader for the serialized program.
3231 Result<FileDataLoader> loader = FileDataLoader::from (path);
3332 ASSERT_EQ (loader.error (), Error::Ok);
34- loader_ = std::make_unique<FileDataLoader>(std::move (loader.get ()));
33+ loaders_.insert (
34+ {module_name,
35+ std::make_unique<FileDataLoader>(std::move (loader.get ()))});
3536
3637 // Use it to load the program.
3738 Result<Program> program = Program::load (
38- loader_.get (), Program::Verification::InternalConsistency);
39+ loaders_[module_name].get (),
40+ Program::Verification::InternalConsistency);
3941 ASSERT_EQ (program.error (), Error::Ok);
40- program_ = std::make_unique<Program>(std::move (program.get ()));
42+ programs_.insert (
43+ {module_name, std::make_unique<Program>(std::move (program.get ()))});
44+ }
45+
46+ void SetUp () override {
47+ load_program (std::getenv (" ET_MODULE_ADD_PATH" ), " add" );
48+ load_program (std::getenv (" ET_MODULE_STATEFUL_PATH" ), " stateful" );
4149 }
4250
4351 private:
4452 // Must outlive program_, but tests shouldn't need to touch it.
45- std::unique_ptr<FileDataLoader> loader_ ;
53+ std::unordered_map<std::string, std:: unique_ptr<FileDataLoader>> loaders_ ;
4654
4755 protected:
48- std::unique_ptr<Program> program_ ;
56+ std::unordered_map<std::string, std:: unique_ptr<Program>> programs_ ;
4957};
5058
5159namespace {
@@ -67,7 +75,7 @@ void check_tensor(const TensorInfo& tensor_info) {
6775} // namespace
6876
6977TEST_F (MethodMetaTest, MethodMetaApi) {
70- Result<MethodMeta> method_meta = program_ ->method_meta (" forward" );
78+ Result<MethodMeta> method_meta = programs_[ " add " ] ->method_meta (" forward" );
7179 ASSERT_EQ (method_meta.error (), Error::Ok);
7280
7381 // Appropriate amount of inputs
@@ -97,11 +105,12 @@ TEST_F(MethodMetaTest, MethodMetaApi) {
97105
98106 // Missing method fails
99107 EXPECT_EQ (
100- program_->method_meta (" not_a_method" ).error (), Error::InvalidArgument);
108+ programs_[" add" ]->method_meta (" not_a_method" ).error (),
109+ Error::InvalidArgument);
101110}
102111
103112TEST_F (MethodMetaTest, TensorInfoApi) {
104- Result<MethodMeta> method_meta = program_ ->method_meta (" forward" );
113+ Result<MethodMeta> method_meta = programs_[ " add " ] ->method_meta (" forward" );
105114 ASSERT_EQ (method_meta.error (), Error::Ok);
106115
107116 // Input 1
@@ -138,3 +147,19 @@ TEST_F(MethodMetaTest, TensorInfoApi) {
138147 EXPECT_EQ (
139148 method_meta->output_tensor_meta (-1 ).error (), Error::InvalidArgument);
140149}
150+
151+ TEST_F (MethodMetaTest, MethodMetaAttribute) {
152+ Result<MethodMeta> method_meta =
153+ programs_[" stateful" ]->method_meta (" forward" );
154+ ASSERT_EQ (method_meta.error (), Error::Ok);
155+
156+ ASSERT_EQ (method_meta->num_attributes (), 1 );
157+ auto state = method_meta->attribute_tensor_meta (0 );
158+ ASSERT_TRUE (state.ok ());
159+
160+ ASSERT_EQ (state->name (), " state" );
161+ ASSERT_FALSE (state->is_memory_planned ());
162+
163+ auto bad_access = method_meta->attribute_tensor_meta (1 );
164+ ASSERT_EQ (bad_access.error (), Error::InvalidArgument);
165+ }
0 commit comments