Skip to content

Commit e859623

Browse files
committed
init
1 parent 27330f2 commit e859623

File tree

2 files changed

+34
-46
lines changed

2 files changed

+34
-46
lines changed

extension/benchmark/apple/Benchmark/Tests/GenericTests.mm

Lines changed: 4 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,7 @@
77
*/
88

99
#import "ResourceTestCase.h"
10-
11-
#import <executorch/extension/module/module.h>
12-
#import <executorch/extension/tensor/tensor.h>
13-
14-
using namespace ::executorch::extension;
15-
using namespace ::executorch::runtime;
10+
#import "test_function.h"
1611

1712
#define ASSERT_OK_OR_RETURN(value__) \
1813
({ \
@@ -37,7 +32,7 @@ @implementation GenericTests
3732
+ (NSDictionary<NSString *, BOOL (^)(NSString *)> *)predicates {
3833
return @{
3934
@"model" : ^BOOL(NSString *filename){
40-
return [filename hasSuffix:@".pte"];
35+
return [filename hasSuffix:@".mlpackage"];
4136
},
4237
};
4338
}
@@ -50,46 +45,9 @@ @implementation GenericTests
5045
[testCase
5146
measureWithMetrics:@[ [XCTClockMetric new], [XCTMemoryMetric new] ]
5247
block:^{
53-
XCTAssertEqual(
54-
Module(modelPath.UTF8String).load_forward(),
55-
Error::Ok);
48+
XCTAssertEqual(prefill_no_kv_cache(), false);
5649
}];
57-
},
58-
@"forward" : ^(XCTestCase *testCase) {
59-
auto __block module = std::make_unique<Module>(modelPath.UTF8String);
60-
61-
const auto method_meta = module->method_meta("forward");
62-
ASSERT_OK_OR_RETURN(method_meta);
63-
64-
const auto num_inputs = method_meta->num_inputs();
65-
XCTAssertGreaterThan(num_inputs, 0);
66-
67-
std::vector<TensorPtr> tensors;
68-
tensors.reserve(num_inputs);
69-
70-
for (auto index = 0; index < num_inputs; ++index) {
71-
const auto input_tag = method_meta->input_tag(index);
72-
ASSERT_OK_OR_RETURN(input_tag);
73-
74-
switch (*input_tag) {
75-
case Tag::Tensor: {
76-
const auto tensor_meta = method_meta->input_tensor_meta(index);
77-
ASSERT_OK_OR_RETURN(tensor_meta);
78-
79-
const auto sizes = tensor_meta->sizes();
80-
tensors.emplace_back(
81-
ones({sizes.begin(), sizes.end()}, tensor_meta->scalar_type()));
82-
XCTAssertEqual(module->set_input(tensors.back(), index), Error::Ok);
83-
} break;
84-
default:
85-
XCTFail("Unsupported tag %i at input %d", *input_tag, index);
86-
}
87-
}
88-
[testCase measureWithMetrics:@[ [XCTClockMetric new], [XCTMemoryMetric new] ]
89-
block:^{
90-
XCTAssertEqual(module->forward().error(), Error::Ok);
91-
}];
92-
},
50+
}
9351
};
9452
}
9553

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
10+
#import <CoreML/CoreML.h>
11+
#import "model_no_kv.h"
12+
13+
BOOL prefill_no_kv_cache(void) {
14+
model_no_kv *model = [[model_no_kv alloc] init];
15+
16+
MLMultiArray *tokens = [[MLMultiArray alloc] initWithShape:(NSArray*)(@[@1, @512])
17+
dataType:(MLMultiArrayDataType)MLMultiArrayDataTypeInt32
18+
error:nil] ;
19+
20+
for (int i = 0; i < 512; i++) {
21+
tokens[i] = @2;
22+
}
23+
model_no_kvInput *inputs = [[model_no_kvInput alloc] initWithTokens:tokens];
24+
25+
for (int i = 0; i < 100; i++) {
26+
model_no_kvOutput *output = [model predictionFromFeatures:inputs error:nil];
27+
}
28+
29+
return YES;
30+
}

0 commit comments

Comments
 (0)