From 39315558c6216df557efa400a56910f5db94e4de Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Mon, 14 Oct 2024 15:17:50 -0700 Subject: [PATCH] Add CoreML tests. Summary: . Differential Revision: D64359459 --- .../Benchmark.xcodeproj/project.pbxproj | 4 + .../xcshareddata/xcschemes/Benchmark.xcscheme | 4 +- .../apple/Benchmark/Tests/CoreMLTests.mm | 105 ++++++++++++++++++ 3 files changed, 111 insertions(+), 2 deletions(-) create mode 100644 extension/benchmark/apple/Benchmark/Tests/CoreMLTests.mm diff --git a/extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/project.pbxproj b/extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/project.pbxproj index 0e128813388..fe25a173843 100644 --- a/extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/project.pbxproj +++ b/extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/project.pbxproj @@ -28,6 +28,7 @@ 03DD00B22C8FE44600FE4619 /* backend_mps.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00A22C8FE44600FE4619 /* backend_mps.xcframework */; }; 03DD00B32C8FE44600FE4619 /* executorch.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00A32C8FE44600FE4619 /* executorch.xcframework */; settings = {ATTRIBUTES = (Required, ); }; }; 03DD00B52C8FE44600FE4619 /* kernels_quantized.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00A52C8FE44600FE4619 /* kernels_quantized.xcframework */; }; + 03E7E6792CBDCAE900205E71 /* CoreMLTests.mm in Sources */ = {isa = PBXBuildFile; fileRef = 03E7E6782CBDC1C900205E71 /* CoreMLTests.mm */; }; 03ED6D0F2C8AAFE900F2D6EE /* libsqlite3.0.tbd in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D0E2C8AAFE900F2D6EE /* libsqlite3.0.tbd */; }; 03ED6D112C8AAFF200F2D6EE /* MetalPerformanceShadersGraph.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D102C8AAFF200F2D6EE /* MetalPerformanceShadersGraph.framework */; }; 03ED6D132C8AAFF700F2D6EE /* MetalPerformanceShaders.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D122C8AAFF700F2D6EE /* MetalPerformanceShaders.framework */; }; @@ -90,6 +91,7 @@ 03DD00A22C8FE44600FE4619 /* backend_mps.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = backend_mps.xcframework; path = Frameworks/backend_mps.xcframework; sourceTree = ""; }; 03DD00A32C8FE44600FE4619 /* executorch.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = executorch.xcframework; path = Frameworks/executorch.xcframework; sourceTree = ""; }; 03DD00A52C8FE44600FE4619 /* kernels_quantized.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = kernels_quantized.xcframework; path = Frameworks/kernels_quantized.xcframework; sourceTree = ""; }; + 03E7E6782CBDC1C900205E71 /* CoreMLTests.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = CoreMLTests.mm; sourceTree = ""; }; 03ED6D0E2C8AAFE900F2D6EE /* libsqlite3.0.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = libsqlite3.0.tbd; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS17.5.sdk/usr/lib/libsqlite3.0.tbd; sourceTree = DEVELOPER_DIR; }; 03ED6D102C8AAFF200F2D6EE /* MetalPerformanceShadersGraph.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShadersGraph.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS17.5.sdk/System/Library/Frameworks/MetalPerformanceShadersGraph.framework; sourceTree = DEVELOPER_DIR; }; 03ED6D122C8AAFF700F2D6EE /* MetalPerformanceShaders.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShaders.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS17.5.sdk/System/Library/Frameworks/MetalPerformanceShaders.framework; sourceTree = DEVELOPER_DIR; }; @@ -232,6 +234,7 @@ isa = PBXGroup; children = ( 032A73C92CAFBA8600932D36 /* LLaMA */, + 03E7E6782CBDC1C900205E71 /* CoreMLTests.mm */, 03B2D3792C8A515C0046936E /* GenericTests.mm */, 03B019502C8A80D30044D558 /* Tests.xcconfig */, 037C96A02C8A570B00B3DF38 /* Tests.xctestplan */, @@ -388,6 +391,7 @@ 032A741E2CAFBB7800932D36 /* tiktoken.cpp in Sources */, 032A741F2CAFBB7800932D36 /* sampler.cpp in Sources */, 03B011912CAD114E00054791 /* ResourceTestCase.m in Sources */, + 03E7E6792CBDCAE900205E71 /* CoreMLTests.mm in Sources */, 032A74232CAFC1B300932D36 /* runner.cpp in Sources */, 03B2D37A2C8A515C0046936E /* GenericTests.mm in Sources */, 032A73CA2CAFBA8600932D36 /* LLaMATests.mm in Sources */, diff --git a/extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/xcshareddata/xcschemes/Benchmark.xcscheme b/extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/xcshareddata/xcschemes/Benchmark.xcscheme index ebfe1e5fd35..c1e3aecd4cc 100644 --- a/extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/xcshareddata/xcschemes/Benchmark.xcscheme +++ b/extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/xcshareddata/xcschemes/Benchmark.xcscheme @@ -25,8 +25,8 @@ + +static MLMultiArray *DummyMultiArrayForFeature(MLFeatureDescription *feature, NSError **error) { + MLMultiArray *array = [[MLMultiArray alloc] initWithShape:feature.multiArrayConstraint.shape + dataType:feature.multiArrayConstraint.dataType == MLMultiArrayDataTypeInt32 ? MLMultiArrayDataTypeInt32 : MLMultiArrayDataTypeDouble + error:error]; + for (auto index = 0; index < array.count; ++index) { + array[index] = feature.multiArrayConstraint.dataType == MLMultiArrayDataTypeInt32 ? @1 : @1.0; + } + return array; +} + +static NSMutableDictionary *DummyInputsForModel(MLModel *model, NSError **error) { + NSMutableDictionary *inputs = [NSMutableDictionary dictionary]; + NSDictionary *inputDescriptions = model.modelDescription.inputDescriptionsByName; + + for (NSString *inputName in inputDescriptions) { + MLFeatureDescription *feature = inputDescriptions[inputName]; + + switch (feature.type) { + case MLFeatureTypeMultiArray: { + MLMultiArray *array = DummyMultiArrayForFeature(feature, error); + inputs[inputName] = [MLFeatureValue featureValueWithMultiArray:array]; + break; + } + case MLFeatureTypeInt64: + inputs[inputName] = [MLFeatureValue featureValueWithInt64:1]; + break; + case MLFeatureTypeDouble: + inputs[inputName] = [MLFeatureValue featureValueWithDouble:1.0]; + break; + case MLFeatureTypeString: + inputs[inputName] = [MLFeatureValue featureValueWithString:@"1"]; + break; + default: + break; + } + } + return inputs; +} + +@interface CoreMLTests : ResourceTestCase +@end + +@implementation CoreMLTests + ++ (NSArray *)directories { + return @[@"Resources"]; +} + ++ (NSDictionary *)predicates { + return @{ @"model" : ^BOOL(NSString *filename) { + return [filename hasSuffix:@".mlpackage"]; + }}; +} + ++ (NSDictionary *)dynamicTestsForResources:(NSDictionary *)resources { + NSString *modelPath = resources[@"model"]; + + return @{ + @"prediction" : ^(XCTestCase *testCase) { + NSError *error = nil; + NSURL *compiledModelURL = [MLModel compileModelAtURL:[NSURL fileURLWithPath:modelPath] error:&error]; + if (error || !compiledModelURL) { + XCTFail(@"Failed to compile model: %@", error.localizedDescription); + return; + } + MLModel *model = [MLModel modelWithContentsOfURL:compiledModelURL error:&error]; + if (error || !model) { + XCTFail(@"Failed to load model: %@", error.localizedDescription); + return; + } + NSMutableDictionary *inputs = DummyInputsForModel(model, &error); + if (error || !inputs) { + XCTFail(@"Failed to prepare inputs: %@", error.localizedDescription); + return; + } + MLDictionaryFeatureProvider *featureProvider = [[MLDictionaryFeatureProvider alloc] initWithDictionary:inputs error:&error]; + if (error || !featureProvider) { + XCTFail(@"Failed to create input provider: %@", error.localizedDescription); + return; + } + [testCase measureWithMetrics:@[[XCTClockMetric new], [XCTMemoryMetric new]] + block:^{ + NSError *error = nil; + id prediction = [model predictionFromFeatures:featureProvider error:&error]; + if (error || !prediction) { + XCTFail(@"Prediction failed: %@", error.localizedDescription); + } + }]; + } + }; +} + +@end