@@ -35,62 +35,62 @@ @implementation GenericTests
3535}
3636
3737+ (NSDictionary <NSString *, BOOL (^)(NSString *)> *)predicates {
38- return @{
39- @" model" : ^BOOL (NSString *filename){
38+ return @{@" model" : ^BOOL (NSString *filename){
4039 return [filename hasSuffix: @" .pte" ];
41- }
42- };
40+ }
41+ }
42+ ;
4343}
4444
4545+ (NSDictionary <NSString *, void (^)(XCTestCase *)> *)dynamicTestsForResources :
4646 (NSDictionary <NSString *, NSString *> *)resources {
4747 NSString *modelPath = resources[@" model" ];
48- return @{
49- @" load" : ^(XCTestCase *testCase){
48+ return @{@" load" : ^(XCTestCase *testCase){
5049 [testCase
5150 measureWithMetrics: @[ [XCTClockMetric new ], [XCTMemoryMetric new ] ]
5251 block: ^{
5352 XCTAssertEqual (
5453 Module (modelPath.UTF8String ).load_forward (),
5554 Error::Ok);
5655 }];
57- },
58- @" forward" : ^(XCTestCase *testCase) {
59- auto __block module = std::make_unique<Module>(modelPath.UTF8String );
60-
61- const auto method_meta = module ->method_meta (" forward" );
62- ASSERT_OK_OR_RETURN (method_meta);
63-
64- const auto num_inputs = method_meta->num_inputs ();
65- XCTAssertGreaterThan (num_inputs, 0 );
66-
67- std::vector<TensorPtr> tensors;
68- tensors.reserve (num_inputs);
69-
70- for (auto index = 0 ; index < num_inputs; ++index) {
71- const auto input_tag = method_meta->input_tag (index);
72- ASSERT_OK_OR_RETURN (input_tag);
73-
74- switch (*input_tag) {
75- case Tag::Tensor: {
76- const auto tensor_meta = method_meta->input_tensor_meta (index);
77- ASSERT_OK_OR_RETURN (tensor_meta);
78-
79- const auto sizes = tensor_meta->sizes ();
80- tensors.emplace_back (
81- ones ({sizes.begin (), sizes.end ()}, tensor_meta->scalar_type ()));
82- XCTAssertEqual (module ->set_input (tensors.back (), index), Error::Ok);
83- } break ;
84- default :
85- XCTFail (" Unsupported tag %i at input %d" , *input_tag, index);
86- }
87- }
88- [testCase measureWithMetrics: @[ [XCTClockMetric new ], [XCTMemoryMetric new ] ]
89- block: ^{
90- XCTAssertEqual (module ->forward ().error (), Error::Ok);
91- }];
56+ }
57+ , @" forward" : ^(XCTestCase *testCase) {
58+ auto __block module = std::make_unique<Module>(modelPath.UTF8String );
59+
60+ const auto method_meta = module ->method_meta (" forward" );
61+ ASSERT_OK_OR_RETURN (method_meta);
62+
63+ const auto num_inputs = method_meta->num_inputs ();
64+ XCTAssertGreaterThan (num_inputs, 0 );
65+
66+ std::vector<TensorPtr> tensors;
67+ tensors.reserve (num_inputs);
68+
69+ for (auto index = 0 ; index < num_inputs; ++index) {
70+ const auto input_tag = method_meta->input_tag (index);
71+ ASSERT_OK_OR_RETURN (input_tag);
72+
73+ switch (*input_tag) {
74+ case Tag::Tensor: {
75+ const auto tensor_meta = method_meta->input_tensor_meta (index);
76+ ASSERT_OK_OR_RETURN (tensor_meta);
77+
78+ const auto sizes = tensor_meta->sizes ();
79+ tensors.emplace_back (
80+ ones ({sizes.begin (), sizes.end ()}, tensor_meta->scalar_type ()));
81+ XCTAssertEqual (module ->set_input (tensors.back (), index), Error::Ok);
82+ } break ;
83+ default :
84+ XCTFail (" Unsupported tag %i at input %d" , *input_tag, index);
9285 }
93- };
86+ }
87+ [testCase measureWithMetrics: @[ [XCTClockMetric new ], [XCTMemoryMetric new ] ]
88+ block: ^{
89+ XCTAssertEqual (module ->forward ().error (), Error::Ok);
90+ }];
91+ }
92+ }
93+ ;
9494}
9595
9696@end
0 commit comments