77 */
88
99#import " ResourceTestCase.h"
10-
11- #import < executorch/extension/module/module.h>
12- #import < executorch/extension/tensor/tensor.h>
13-
14- using namespace ::executorch::extension;
15- using namespace ::executorch::runtime;
10+ #import " test_function.h"
1611
1712#define ASSERT_OK_OR_RETURN (value__ ) \
1813 ({ \
@@ -37,7 +32,7 @@ @implementation GenericTests
3732+ (NSDictionary <NSString *, BOOL (^)(NSString *)> *)predicates {
3833 return @{
3934 @" model" : ^BOOL (NSString *filename){
40- return [filename hasSuffix: @" .pte " ];
35+ return [filename hasSuffix: @" .mlpackage " ];
4136 },
4237 };
4338}
@@ -50,46 +45,9 @@ @implementation GenericTests
5045 [testCase
5146 measureWithMetrics: @[ [XCTClockMetric new ], [XCTMemoryMetric new ] ]
5247 block: ^{
53- XCTAssertEqual (
54- Module (modelPath.UTF8String ).load_forward (),
55- Error::Ok);
48+ XCTAssertEqual (prefill_no_kv_cache (), false );
5649 }];
57- },
58- @" forward" : ^(XCTestCase *testCase) {
59- auto __block module = std::make_unique<Module>(modelPath.UTF8String );
60-
61- const auto method_meta = module ->method_meta (" forward" );
62- ASSERT_OK_OR_RETURN (method_meta);
63-
64- const auto num_inputs = method_meta->num_inputs ();
65- XCTAssertGreaterThan (num_inputs, 0 );
66-
67- std::vector<TensorPtr> tensors;
68- tensors.reserve (num_inputs);
69-
70- for (auto index = 0 ; index < num_inputs; ++index) {
71- const auto input_tag = method_meta->input_tag (index);
72- ASSERT_OK_OR_RETURN (input_tag);
73-
74- switch (*input_tag) {
75- case Tag::Tensor: {
76- const auto tensor_meta = method_meta->input_tensor_meta (index);
77- ASSERT_OK_OR_RETURN (tensor_meta);
78-
79- const auto sizes = tensor_meta->sizes ();
80- tensors.emplace_back (
81- ones ({sizes.begin (), sizes.end ()}, tensor_meta->scalar_type ()));
82- XCTAssertEqual (module ->set_input (tensors.back (), index), Error::Ok);
83- } break ;
84- default :
85- XCTFail (" Unsupported tag %i at input %d" , *input_tag, index);
86- }
87- }
88- [testCase measureWithMetrics: @[ [XCTClockMetric new ], [XCTMemoryMetric new ] ]
89- block: ^{
90- XCTAssertEqual (module ->forward ().error (), Error::Ok);
91- }];
92- },
50+ }
9351 };
9452}
9553
0 commit comments