@@ -26,26 +26,35 @@ using torch::executor::util::FileDataLoader;
2626
2727class MethodMetaTest : public ::testing::Test {
2828 protected:
29- void SetUp () override {
30- // Create a loader for the serialized ModuleAdd program.
31- const char * path = std::getenv (" ET_MODULE_ADD_PATH" );
29+ void load_program (const char * path, const char * module_name) {
30+ // Create a loader for the serialized program.
3231 Result<FileDataLoader> loader = FileDataLoader::from (path);
3332 ASSERT_EQ (loader.error (), Error::Ok);
34- loader_ = std::make_unique<FileDataLoader>(std::move (loader.get ()));
33+ loaders_.insert (
34+ {module_name,
35+ std::make_unique<FileDataLoader>(std::move (loader.get ()))});
3536
3637 // Use it to load the program.
3738 Result<Program> program = Program::load (
38- loader_.get (), Program::Verification::InternalConsistency);
39+ loaders_[module_name].get (),
40+ Program::Verification::InternalConsistency);
3941 ASSERT_EQ (program.error (), Error::Ok);
40- program_ = std::make_unique<Program>(std::move (program.get ()));
42+ programs_.insert (
43+ {module_name, std::make_unique<Program>(std::move (program.get ()))});
44+ }
45+
46+ void SetUp () override {
47+
48+ load_program (std::getenv (" ET_MODULE_ADD_PATH" ), " add" );
49+ load_program (std::getenv (" ET_MODULE_STATEFUL_PATH" ), " stateful" );
4150 }
4251
4352 private:
4453 // Must outlive program_, but tests shouldn't need to touch it.
45- std::unique_ptr<FileDataLoader> loader_ ;
54+ std::unordered_map<std::string, std:: unique_ptr<FileDataLoader>> loaders_ ;
4655
4756 protected:
48- std::unique_ptr<Program> program_ ;
57+ std::unordered_map<std::string, std:: unique_ptr<Program>> programs_ ;
4958};
5059
5160namespace {
@@ -67,7 +76,7 @@ void check_tensor(const TensorInfo& tensor_info) {
6776} // namespace
6877
6978TEST_F (MethodMetaTest, MethodMetaApi) {
70- Result<MethodMeta> method_meta = program_ ->method_meta (" forward" );
79+ Result<MethodMeta> method_meta = programs_[ " add " ] ->method_meta (" forward" );
7180 ASSERT_EQ (method_meta.error (), Error::Ok);
7281
7382 // Appropriate amount of inputs
@@ -97,11 +106,11 @@ TEST_F(MethodMetaTest, MethodMetaApi) {
97106
98107 // Missing method fails
99108 EXPECT_EQ (
100- program_ ->method_meta (" not_a_method" ).error (), Error::InvalidArgument);
109+ programs_[ " add " ] ->method_meta (" not_a_method" ).error (), Error::InvalidArgument);
101110}
102111
103112TEST_F (MethodMetaTest, TensorInfoApi) {
104- Result<MethodMeta> method_meta = program_ ->method_meta (" forward" );
113+ Result<MethodMeta> method_meta = programs_[ " add " ] ->method_meta (" forward" );
105114 ASSERT_EQ (method_meta.error (), Error::Ok);
106115
107116 // Input 1
@@ -138,3 +147,18 @@ TEST_F(MethodMetaTest, TensorInfoApi) {
138147 EXPECT_EQ (
139148 method_meta->output_tensor_meta (-1 ).error (), Error::InvalidArgument);
140149}
150+
151+ TEST_F (MethodMetaTest, MethodMetaAttribute) {
152+ Result<MethodMeta> method_meta = programs_[" stateful" ]->method_meta (" forward" );
153+ ASSERT_EQ (method_meta.error (), Error::Ok);
154+
155+ ASSERT_EQ (method_meta->num_attributes (), 1 );
156+ auto state = method_meta->attribute_tensor_meta (0 );
157+ ASSERT_TRUE (state.ok ());
158+
159+ ASSERT_EQ (state->name (), " state" );
160+ ASSERT_FALSE (state->is_memory_planned ());
161+
162+ auto bad_access = method_meta->attribute_tensor_meta (1 );
163+ ASSERT_EQ (bad_access.error (), Error::InvalidArgument);
164+ }
0 commit comments