@@ -35,62 +35,62 @@ @implementation GenericTests
3535}
3636
3737+ (NSDictionary <NSString *, BOOL (^)(NSString *)> *)predicates {
38- return @{@" model" : ^BOOL (NSString *filename){
38+ return @{
39+ @" model" : ^BOOL (NSString *filename){
3940 return [filename hasSuffix: @" .pte" ];
40- }
41- }
42- ;
41+ },
42+ };
4343}
4444
4545+ (NSDictionary <NSString *, void (^)(XCTestCase *)> *)dynamicTestsForResources :
4646 (NSDictionary <NSString *, NSString *> *)resources {
4747 NSString *modelPath = resources[@" model" ];
48- return @{@" load" : ^(XCTestCase *testCase){
48+ return @{
49+ @" load" : ^(XCTestCase *testCase){
4950 [testCase
5051 measureWithMetrics: @[ [XCTClockMetric new ], [XCTMemoryMetric new ] ]
5152 block: ^{
5253 XCTAssertEqual (
5354 Module (modelPath.UTF8String ).load_forward (),
5455 Error::Ok);
5556 }];
56- }
57- , @" forward" : ^(XCTestCase *testCase) {
58- auto __block module = std::make_unique<Module>(modelPath.UTF8String );
59-
60- const auto method_meta = module ->method_meta (" forward" );
61- ASSERT_OK_OR_RETURN (method_meta);
62-
63- const auto num_inputs = method_meta->num_inputs ();
64- XCTAssertGreaterThan (num_inputs, 0 );
65-
66- std::vector<TensorPtr> tensors;
67- tensors.reserve (num_inputs);
68-
69- for (auto index = 0 ; index < num_inputs; ++index) {
70- const auto input_tag = method_meta->input_tag (index);
71- ASSERT_OK_OR_RETURN (input_tag);
72-
73- switch (*input_tag) {
74- case Tag::Tensor: {
75- const auto tensor_meta = method_meta->input_tensor_meta (index);
76- ASSERT_OK_OR_RETURN (tensor_meta);
77-
78- const auto sizes = tensor_meta->sizes ();
79- tensors.emplace_back (
80- ones ({sizes.begin (), sizes.end ()}, tensor_meta->scalar_type ()));
81- XCTAssertEqual (module ->set_input (tensors.back (), index), Error::Ok);
82- } break ;
83- default :
84- XCTFail (" Unsupported tag %i at input %d" , *input_tag, index);
85- }
86- }
87- [testCase measureWithMetrics: @[ [XCTClockMetric new ], [XCTMemoryMetric new ] ]
88- block: ^{
89- XCTAssertEqual (module ->forward ().error (), Error::Ok);
90- }];
91- }
92- }
93- ;
57+ },
58+ @" forward" : ^(XCTestCase *testCase) {
59+ auto __block module = std::make_unique<Module>(modelPath.UTF8String );
60+
61+ const auto method_meta = module ->method_meta (" forward" );
62+ ASSERT_OK_OR_RETURN (method_meta);
63+
64+ const auto num_inputs = method_meta->num_inputs ();
65+ XCTAssertGreaterThan (num_inputs, 0 );
66+
67+ std::vector<TensorPtr> tensors;
68+ tensors.reserve (num_inputs);
69+
70+ for (auto index = 0 ; index < num_inputs; ++index) {
71+ const auto input_tag = method_meta->input_tag (index);
72+ ASSERT_OK_OR_RETURN (input_tag);
73+
74+ switch (*input_tag) {
75+ case Tag::Tensor: {
76+ const auto tensor_meta = method_meta->input_tensor_meta (index);
77+ ASSERT_OK_OR_RETURN (tensor_meta);
78+
79+ const auto sizes = tensor_meta->sizes ();
80+ tensors.emplace_back (
81+ ones ({sizes.begin (), sizes.end ()}, tensor_meta->scalar_type ()));
82+ XCTAssertEqual (module ->set_input (tensors.back (), index), Error::Ok);
83+ } break ;
84+ default :
85+ XCTFail (" Unsupported tag %i at input %d" , *input_tag, index);
86+ }
87+ }
88+ [testCase measureWithMetrics: @[ [XCTClockMetric new ], [XCTMemoryMetric new ] ]
89+ block: ^{
90+ XCTAssertEqual (module ->forward ().error (), Error::Ok);
91+ }];
92+ },
93+ };
9494}
9595
9696@end
0 commit comments