@@ -26,11 +26,15 @@ class ModuleTest : public ::testing::Test {
26
26
model_path_ = std::getenv (" ET_MODULE_ADD_PATH" );
27
27
add_mul_path_ = std::getenv (" ET_MODULE_ADD_MUL_PROGRAM_PATH" );
28
28
add_mul_data_path_ = std::getenv (" ET_MODULE_ADD_MUL_DATA_PATH" );
29
+ linear_path_ = std::getenv (" ET_MODULE_LINEAR_PROGRAM_PATH" );
30
+ linear_data_path_ = std::getenv (" ET_MODULE_LINEAR_DATA_PATH" );
29
31
}
30
32
31
33
static inline std::string model_path_;
32
34
static inline std::string add_mul_path_;
33
35
static inline std::string add_mul_data_path_;
36
+ static inline std::string linear_path_;
37
+ static inline std::string linear_data_path_;
34
38
};
35
39
36
40
TEST_F (ModuleTest, TestLoad) {
@@ -532,16 +536,21 @@ TEST_F(ModuleTest, TestPTD) {
532
536
}
533
537
534
538
TEST_F (ModuleTest, TestPTD_Multiple) {
535
- std::vector<std::string> data_files = {add_mul_data_path_};
536
- Module module (add_mul_path_, data_files);
537
-
538
- ASSERT_EQ (module .load_method (" forward" ), Error::Ok);
539
+ std::vector<std::string> data_files = {add_mul_data_path_, linear_data_path_};
539
540
541
+ // Create module with add mul.
542
+ Module module_add_mul (add_mul_path_, data_files);
543
+ ASSERT_EQ (module_add_mul.load_method (" forward" ), Error::Ok);
540
544
auto tensor = make_tensor_ptr ({2 , 2 }, {2 .f , 3 .f , 4 .f , 2 .f });
541
- ASSERT_EQ (module .forward (tensor).error (), Error::Ok);
545
+ ASSERT_EQ (module_add_mul .forward (tensor).error (), Error::Ok);
542
546
543
547
// Confirm that the data_file is not std::move'd away.
544
548
ASSERT_EQ (std::strcmp (data_files[0 ].c_str (), add_mul_data_path_.c_str ()), 0 );
549
+ ASSERT_EQ (std::strcmp (data_files[1 ].c_str (), linear_data_path_.c_str ()), 0 );
545
550
546
- // TODO(lfq): add test when merge capability is supported.
551
+ // Create module with linear.
552
+ Module module_linear (linear_path_, data_files);
553
+ ASSERT_EQ (module_linear.load_method (" forward" ), Error::Ok);
554
+ auto tensor2 = make_tensor_ptr ({3 }, {2 .f , 3 .f , 4 .f });
555
+ ASSERT_EQ (module_linear.forward (tensor2).error (), Error::Ok);
547
556
}
0 commit comments