@@ -99,7 +99,7 @@ + (void)initialize {
9999 IMP testForwardImplementation = imp_implementationWithBlock (^(
100100 id _self) {
101101 auto __block module = std::make_unique<Module>(modelPath.UTF8String );
102- XCTAssertEqual (module ->load_method ( " forward " ), Error::Ok);
102+ XCTAssertEqual (module ->load_forward ( ), Error::Ok);
103103
104104 const auto method_meta = module ->method_meta (" forward" );
105105 XCTAssertEqual (method_meta.error (), Error::Ok);
@@ -109,8 +109,6 @@ + (void)initialize {
109109
110110 std::vector<TensorPtr> __block tensors;
111111 tensors.reserve (num_inputs);
112- std::vector<EValue> __block inputs;
113- inputs.reserve (num_inputs);
114112
115113 for (auto index = 0 ; index < num_inputs; ++index) {
116114 const auto input_tag = method_meta->input_tag (index);
@@ -124,7 +122,7 @@ + (void)initialize {
124122 const auto sizes = tensor_meta->sizes ();
125123 tensors.emplace_back (ones ({sizes.begin (), sizes.end ()},
126124 tensor_meta->scalar_type ()));
127- inputs. emplace_back (tensors.back ());
125+ module . set_input (tensors.back (), index );
128126 } break ;
129127 default :
130128 XCTFail (" Unsupported tag %i at input %d" , *input_tag, index);
0 commit comments