@@ -99,7 +99,7 @@ + (void)initialize {
99
99
IMP testForwardImplementation = imp_implementationWithBlock (^(
100
100
id _self) {
101
101
auto __block module = std::make_unique<Module>(modelPath.UTF8String );
102
- XCTAssertEqual (module ->load_method ( " forward " ), Error::Ok);
102
+ XCTAssertEqual (module ->load_forward ( ), Error::Ok);
103
103
104
104
const auto method_meta = module ->method_meta (" forward" );
105
105
XCTAssertEqual (method_meta.error (), Error::Ok);
@@ -109,8 +109,6 @@ + (void)initialize {
109
109
110
110
std::vector<TensorPtr> __block tensors;
111
111
tensors.reserve (num_inputs);
112
- std::vector<EValue> __block inputs;
113
- inputs.reserve (num_inputs);
114
112
115
113
for (auto index = 0 ; index < num_inputs; ++index) {
116
114
const auto input_tag = method_meta->input_tag (index);
@@ -124,7 +122,7 @@ + (void)initialize {
124
122
const auto sizes = tensor_meta->sizes ();
125
123
tensors.emplace_back (ones ({sizes.begin (), sizes.end ()},
126
124
tensor_meta->scalar_type ()));
127
- inputs. emplace_back (tensors.back ());
125
+ module . set_input (tensors.back (), index );
128
126
} break ;
129
127
default :
130
128
XCTFail (" Unsupported tag %i at input %d" , *input_tag, index);
0 commit comments