Skip to content

Commit e4ac33d

Browse files
committed
coreml test
1 parent 7cf264a commit e4ac33d

File tree

1 file changed

+112
-70
lines changed

1 file changed

+112
-70
lines changed

extension/benchmark/apple/Benchmark/Tests/CoreMLTests.mm

Lines changed: 112 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -11,42 +11,42 @@
1111
#import <CoreML/CoreML.h>
1212

1313
static MLMultiArray *DummyMultiArrayForFeature(MLFeatureDescription *feature, NSError **error) {
14-
MLMultiArray *array = [[MLMultiArray alloc] initWithShape:feature.multiArrayConstraint.shape
15-
dataType:feature.multiArrayConstraint.dataType == MLMultiArrayDataTypeInt32 ? MLMultiArrayDataTypeInt32 : MLMultiArrayDataTypeDouble
16-
error:error];
17-
for (auto index = 0; index < array.count; ++index) {
18-
array[index] = feature.multiArrayConstraint.dataType == MLMultiArrayDataTypeInt32 ? @1 : @1.0;
19-
}
20-
return array;
14+
MLMultiArray *array = [[MLMultiArray alloc] initWithShape:feature.multiArrayConstraint.shape
15+
dataType:feature.multiArrayConstraint.dataType == MLMultiArrayDataTypeInt32 ? MLMultiArrayDataTypeInt32 : MLMultiArrayDataTypeDouble
16+
error:error];
17+
for (auto index = 0; index < array.count; ++index) {
18+
array[index] = feature.multiArrayConstraint.dataType == MLMultiArrayDataTypeInt32 ? @1 : @1.0;
19+
}
20+
return array;
2121
}
2222

2323
static NSMutableDictionary *DummyInputsForModel(MLModel *model, NSError **error) {
24-
NSMutableDictionary *inputs = [NSMutableDictionary dictionary];
25-
NSDictionary<NSString *, MLFeatureDescription *> *inputDescriptions = model.modelDescription.inputDescriptionsByName;
26-
27-
for (NSString *inputName in inputDescriptions) {
28-
MLFeatureDescription *feature = inputDescriptions[inputName];
29-
30-
switch (feature.type) {
31-
case MLFeatureTypeMultiArray: {
32-
MLMultiArray *array = DummyMultiArrayForFeature(feature, error);
33-
inputs[inputName] = [MLFeatureValue featureValueWithMultiArray:array];
34-
break;
35-
}
36-
case MLFeatureTypeInt64:
37-
inputs[inputName] = [MLFeatureValue featureValueWithInt64:1];
38-
break;
39-
case MLFeatureTypeDouble:
40-
inputs[inputName] = [MLFeatureValue featureValueWithDouble:1.0];
41-
break;
42-
case MLFeatureTypeString:
43-
inputs[inputName] = [MLFeatureValue featureValueWithString:@"1"];
44-
break;
45-
default:
46-
break;
24+
NSMutableDictionary *inputs = [NSMutableDictionary dictionary];
25+
NSDictionary<NSString *, MLFeatureDescription *> *inputDescriptions = model.modelDescription.inputDescriptionsByName;
26+
27+
for (NSString *inputName in inputDescriptions) {
28+
MLFeatureDescription *feature = inputDescriptions[inputName];
29+
30+
switch (feature.type) {
31+
case MLFeatureTypeMultiArray: {
32+
MLMultiArray *array = DummyMultiArrayForFeature(feature, error);
33+
inputs[inputName] = [MLFeatureValue featureValueWithMultiArray:array];
34+
break;
35+
}
36+
case MLFeatureTypeInt64:
37+
inputs[inputName] = [MLFeatureValue featureValueWithInt64:1];
38+
break;
39+
case MLFeatureTypeDouble:
40+
inputs[inputName] = [MLFeatureValue featureValueWithDouble:1.0];
41+
break;
42+
case MLFeatureTypeString:
43+
inputs[inputName] = [MLFeatureValue featureValueWithString:@"1"];
44+
break;
45+
default:
46+
break;
47+
}
4748
}
48-
}
49-
return inputs;
49+
return inputs;
5050
}
5151

5252
@interface CoreMLTests : ResourceTestCase
@@ -55,51 +55,93 @@ @interface CoreMLTests : ResourceTestCase
5555
@implementation CoreMLTests
5656

5757
+ (NSArray<NSString *> *)directories {
58-
return @[@"Resources"];
58+
return @[@"Resources"];
5959
}
6060

6161
+ (NSDictionary<NSString *, BOOL (^)(NSString *)> *)predicates {
62-
return @{ @"model" : ^BOOL(NSString *filename) {
63-
return [filename hasSuffix:@".mlpackage"];
64-
}};
62+
return @{ @"model" : ^BOOL(NSString *filename) {
63+
return [filename hasSuffix:@"combined.mlpackage"];
64+
}};
6565
}
6666

6767
+ (NSDictionary<NSString *, void (^)(XCTestCase *)> *)dynamicTestsForResources:(NSDictionary<NSString *, NSString *> *)resources {
68-
NSString *modelPath = resources[@"model"];
69-
70-
return @{
71-
@"prediction" : ^(XCTestCase *testCase) {
72-
NSError *error = nil;
73-
NSURL *compiledModelURL = [MLModel compileModelAtURL:[NSURL fileURLWithPath:modelPath] error:&error];
74-
if (error || !compiledModelURL) {
75-
XCTFail(@"Failed to compile model: %@", error.localizedDescription);
76-
return;
77-
}
78-
MLModel *model = [MLModel modelWithContentsOfURL:compiledModelURL error:&error];
79-
if (error || !model) {
80-
XCTFail(@"Failed to load model: %@", error.localizedDescription);
81-
return;
82-
}
83-
NSMutableDictionary *inputs = DummyInputsForModel(model, &error);
84-
if (error || !inputs) {
85-
XCTFail(@"Failed to prepare inputs: %@", error.localizedDescription);
86-
return;
87-
}
88-
MLDictionaryFeatureProvider *featureProvider = [[MLDictionaryFeatureProvider alloc] initWithDictionary:inputs error:&error];
89-
if (error || !featureProvider) {
90-
XCTFail(@"Failed to create input provider: %@", error.localizedDescription);
91-
return;
92-
}
93-
[testCase measureWithMetrics:@[[XCTClockMetric new], [XCTMemoryMetric new]]
94-
block:^{
95-
NSError *error = nil;
96-
id<MLFeatureProvider> prediction = [model predictionFromFeatures:featureProvider error:&error];
97-
if (error || !prediction) {
98-
XCTFail(@"Prediction failed: %@", error.localizedDescription);
68+
NSString *modelPath = resources[@"model"];
69+
70+
return @{
71+
@"prediction" : ^(XCTestCase *testCase) {
72+
73+
NSError *error = nil;
74+
NSURL *compiledModelURL = [MLModel compileModelAtURL:[NSURL fileURLWithPath:modelPath] error:&error];
75+
if (error || !compiledModelURL) {
76+
XCTFail(@"Failed to compile model: %@", error.localizedDescription);
77+
return;
78+
}
79+
80+
// Model1
81+
MLModelConfiguration *config = [[MLModelConfiguration alloc] init];
82+
config.computeUnits = MLComputeUnitsCPUAndNeuralEngine;
83+
config.functionName = @"model1";
84+
MLModel *model = [MLModel modelWithContentsOfURL:compiledModelURL configuration:config error:&error];
85+
if (error || !model) {
86+
XCTFail(@"Failed to load model: %@", error.localizedDescription);
87+
return;
88+
}
89+
NSMutableDictionary *inputs = DummyInputsForModel(model, &error);
90+
if (error || !inputs) {
91+
XCTFail(@"Failed to prepare inputs: %@", error.localizedDescription);
92+
return;
93+
}
94+
MLDictionaryFeatureProvider *featureProvider = [[MLDictionaryFeatureProvider alloc] initWithDictionary:inputs error:&error];
95+
if (error || !featureProvider) {
96+
XCTFail(@"Failed to create input provider: %@", error.localizedDescription);
97+
return;
98+
}
99+
100+
101+
// Model2
102+
MLModelConfiguration *config2 = [[MLModelConfiguration alloc] init];
103+
config2.computeUnits = MLComputeUnitsCPUOnly;
104+
config2.functionName = @"model2";
105+
MLModel *model2 = [MLModel modelWithContentsOfURL:compiledModelURL configuration:config2 error:&error];
106+
if (error || !model2) {
107+
XCTFail(@"Failed to load model: %@", error.localizedDescription);
108+
return;
109+
}
110+
NSMutableDictionary *inputs2 = DummyInputsForModel(model2, &error);
111+
if (error || !inputs2) {
112+
XCTFail(@"Failed to prepare inputs: %@", error.localizedDescription);
113+
return;
114+
}
115+
MLDictionaryFeatureProvider *featureProvider2 = [[MLDictionaryFeatureProvider alloc] initWithDictionary:inputs2 error:&error];
116+
if (error || !featureProvider) {
117+
XCTFail(@"Failed to create input provider: %@", error.localizedDescription);
118+
return;
119+
}
120+
121+
122+
MLState *state = [model2 newState];
123+
[testCase measureWithMetrics:@[[XCTClockMetric new], [XCTMemoryMetric new]]
124+
block:^{
125+
NSError *error = nil;
126+
127+
// Prefill
128+
id<MLFeatureProvider> prediction;
129+
prediction = [model predictionFromFeatures:featureProvider error:&error];
130+
if (error) {
131+
XCTFail(@"Prediction failed: %@", error.localizedDescription);
132+
}
133+
134+
// Decode
135+
id<MLFeatureProvider> prediction2;
136+
for (int i = 0; i < 128; i++) {
137+
prediction2 = [model2 predictionFromFeatures:featureProvider2 usingState:state error:&error];
138+
if (error) {
139+
XCTFail(@"Prediction failed: %@", error.localizedDescription);
140+
}
141+
}
142+
}];
99143
}
100-
}];
101-
}
102-
};
144+
};
103145
}
104146

105147
@end

0 commit comments

Comments
 (0)