1111#import < CoreML/CoreML.h>
1212
1313static MLMultiArray *DummyMultiArrayForFeature (MLFeatureDescription *feature, NSError **error) {
14- MLMultiArray *array = [[MLMultiArray alloc ] initWithShape: feature.multiArrayConstraint.shape
15- dataType: feature.multiArrayConstraint.dataType == MLMultiArrayDataTypeInt32 ? MLMultiArrayDataTypeInt32 : MLMultiArrayDataTypeDouble
16- error: error];
17- for (auto index = 0 ; index < array.count ; ++index) {
18- array[index] = feature.multiArrayConstraint .dataType == MLMultiArrayDataTypeInt32 ? @1 : @1.0 ;
19- }
20- return array;
14+ MLMultiArray *array = [[MLMultiArray alloc ] initWithShape: feature.multiArrayConstraint.shape
15+ dataType: feature.multiArrayConstraint.dataType == MLMultiArrayDataTypeInt32 ? MLMultiArrayDataTypeInt32 : MLMultiArrayDataTypeDouble
16+ error: error];
17+ for (auto index = 0 ; index < array.count ; ++index) {
18+ array[index] = feature.multiArrayConstraint .dataType == MLMultiArrayDataTypeInt32 ? @1 : @1.0 ;
19+ }
20+ return array;
2121}
2222
2323static NSMutableDictionary *DummyInputsForModel (MLModel *model, NSError **error) {
24- NSMutableDictionary *inputs = [NSMutableDictionary dictionary ];
25- NSDictionary <NSString *, MLFeatureDescription *> *inputDescriptions = model.modelDescription .inputDescriptionsByName ;
26-
27- for (NSString *inputName in inputDescriptions) {
28- MLFeatureDescription *feature = inputDescriptions[inputName];
29-
30- switch (feature.type ) {
31- case MLFeatureTypeMultiArray: {
32- MLMultiArray *array = DummyMultiArrayForFeature (feature, error);
33- inputs[inputName] = [MLFeatureValue featureValueWithMultiArray: array];
34- break ;
35- }
36- case MLFeatureTypeInt64:
37- inputs[inputName] = [MLFeatureValue featureValueWithInt64: 1 ];
38- break ;
39- case MLFeatureTypeDouble:
40- inputs[inputName] = [MLFeatureValue featureValueWithDouble: 1.0 ];
41- break ;
42- case MLFeatureTypeString:
43- inputs[inputName] = [MLFeatureValue featureValueWithString: @" 1" ];
44- break ;
45- default :
46- break ;
24+ NSMutableDictionary *inputs = [NSMutableDictionary dictionary ];
25+ NSDictionary <NSString *, MLFeatureDescription *> *inputDescriptions = model.modelDescription .inputDescriptionsByName ;
26+
27+ for (NSString *inputName in inputDescriptions) {
28+ MLFeatureDescription *feature = inputDescriptions[inputName];
29+
30+ switch (feature.type ) {
31+ case MLFeatureTypeMultiArray: {
32+ MLMultiArray *array = DummyMultiArrayForFeature (feature, error);
33+ inputs[inputName] = [MLFeatureValue featureValueWithMultiArray: array];
34+ break ;
35+ }
36+ case MLFeatureTypeInt64:
37+ inputs[inputName] = [MLFeatureValue featureValueWithInt64: 1 ];
38+ break ;
39+ case MLFeatureTypeDouble:
40+ inputs[inputName] = [MLFeatureValue featureValueWithDouble: 1.0 ];
41+ break ;
42+ case MLFeatureTypeString:
43+ inputs[inputName] = [MLFeatureValue featureValueWithString: @" 1" ];
44+ break ;
45+ default :
46+ break ;
47+ }
4748 }
48- }
49- return inputs;
49+ return inputs;
5050}
5151
5252@interface CoreMLTests : ResourceTestCase
@@ -55,51 +55,93 @@ @interface CoreMLTests : ResourceTestCase
5555@implementation CoreMLTests
5656
5757+ (NSArray <NSString *> *)directories {
58- return @[@" Resources" ];
58+ return @[@" Resources" ];
5959}
6060
6161+ (NSDictionary <NSString *, BOOL (^)(NSString *)> *)predicates {
62- return @{ @" model" : ^BOOL (NSString *filename) {
63- return [filename hasSuffix: @" .mlpackage" ];
64- }};
62+ return @{ @" model" : ^BOOL (NSString *filename) {
63+ return [filename hasSuffix: @" combined .mlpackage" ];
64+ }};
6565}
6666
6767+ (NSDictionary <NSString *, void (^)(XCTestCase *)> *)dynamicTestsForResources : (NSDictionary <NSString *, NSString *> *)resources {
68- NSString *modelPath = resources[@" model" ];
69-
70- return @{
71- @" prediction" : ^(XCTestCase *testCase) {
72- NSError *error = nil ;
73- NSURL *compiledModelURL = [MLModel compileModelAtURL: [NSURL fileURLWithPath: modelPath] error: &error];
74- if (error || !compiledModelURL) {
75- XCTFail (@" Failed to compile model: %@ " , error.localizedDescription );
76- return ;
77- }
78- MLModel *model = [MLModel modelWithContentsOfURL: compiledModelURL error: &error];
79- if (error || !model) {
80- XCTFail (@" Failed to load model: %@ " , error.localizedDescription );
81- return ;
82- }
83- NSMutableDictionary *inputs = DummyInputsForModel (model, &error);
84- if (error || !inputs) {
85- XCTFail (@" Failed to prepare inputs: %@ " , error.localizedDescription );
86- return ;
87- }
88- MLDictionaryFeatureProvider *featureProvider = [[MLDictionaryFeatureProvider alloc ] initWithDictionary: inputs error: &error];
89- if (error || !featureProvider) {
90- XCTFail (@" Failed to create input provider: %@ " , error.localizedDescription );
91- return ;
92- }
93- [testCase measureWithMetrics: @[[XCTClockMetric new ], [XCTMemoryMetric new ]]
94- block: ^{
95- NSError *error = nil ;
96- id <MLFeatureProvider> prediction = [model predictionFromFeatures: featureProvider error: &error];
97- if (error || !prediction) {
98- XCTFail (@" Prediction failed: %@ " , error.localizedDescription );
68+ NSString *modelPath = resources[@" model" ];
69+
70+ return @{
71+ @" prediction" : ^(XCTestCase *testCase) {
72+
73+ NSError *error = nil ;
74+ NSURL *compiledModelURL = [MLModel compileModelAtURL: [NSURL fileURLWithPath: modelPath] error: &error];
75+ if (error || !compiledModelURL) {
76+ XCTFail (@" Failed to compile model: %@ " , error.localizedDescription );
77+ return ;
78+ }
79+
80+ // Model1
81+ MLModelConfiguration *config = [[MLModelConfiguration alloc ] init ];
82+ config.computeUnits = MLComputeUnitsCPUAndNeuralEngine;
83+ config.functionName = @" model1" ;
84+ MLModel *model = [MLModel modelWithContentsOfURL: compiledModelURL configuration: config error: &error];
85+ if (error || !model) {
86+ XCTFail (@" Failed to load model: %@ " , error.localizedDescription );
87+ return ;
88+ }
89+ NSMutableDictionary *inputs = DummyInputsForModel (model, &error);
90+ if (error || !inputs) {
91+ XCTFail (@" Failed to prepare inputs: %@ " , error.localizedDescription );
92+ return ;
93+ }
94+ MLDictionaryFeatureProvider *featureProvider = [[MLDictionaryFeatureProvider alloc ] initWithDictionary: inputs error: &error];
95+ if (error || !featureProvider) {
96+ XCTFail (@" Failed to create input provider: %@ " , error.localizedDescription );
97+ return ;
98+ }
99+
100+
101+ // Model2
102+ MLModelConfiguration *config2 = [[MLModelConfiguration alloc ] init ];
103+ config2.computeUnits = MLComputeUnitsCPUOnly;
104+ config2.functionName = @" model2" ;
105+ MLModel *model2 = [MLModel modelWithContentsOfURL: compiledModelURL configuration: config2 error: &error];
106+ if (error || !model2) {
107+ XCTFail (@" Failed to load model: %@ " , error.localizedDescription );
108+ return ;
109+ }
110+ NSMutableDictionary *inputs2 = DummyInputsForModel (model2, &error);
111+ if (error || !inputs2) {
112+ XCTFail (@" Failed to prepare inputs: %@ " , error.localizedDescription );
113+ return ;
114+ }
115+ MLDictionaryFeatureProvider *featureProvider2 = [[MLDictionaryFeatureProvider alloc ] initWithDictionary: inputs2 error: &error];
116+ if (error || !featureProvider) {
117+ XCTFail (@" Failed to create input provider: %@ " , error.localizedDescription );
118+ return ;
119+ }
120+
121+
122+ MLState *state = [model2 newState ];
123+ [testCase measureWithMetrics: @[[XCTClockMetric new ], [XCTMemoryMetric new ]]
124+ block: ^{
125+ NSError *error = nil ;
126+
127+ // Prefill
128+ id <MLFeatureProvider> prediction;
129+ prediction = [model predictionFromFeatures: featureProvider error: &error];
130+ if (error) {
131+ XCTFail (@" Prediction failed: %@ " , error.localizedDescription );
132+ }
133+
134+ // Decode
135+ id <MLFeatureProvider> prediction2;
136+ for (int i = 0 ; i < 128 ; i++) {
137+ prediction2 = [model2 predictionFromFeatures: featureProvider2 usingState: state error: &error];
138+ if (error) {
139+ XCTFail (@" Prediction failed: %@ " , error.localizedDescription );
140+ }
141+ }
142+ }];
99143 }
100- }];
101- }
102- };
144+ };
103145}
104146
105147@end
0 commit comments