#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Create multifunction CoreML models from a pair of prefill/decode PTE files.

The CoreML models embedded in each PTE file are extracted, paired up by
position, and every pair is merged into one multifunction model exposing a
"prefill" and a "decode" function ("decode" is the default function).
"""

import argparse
import os
import re
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Optional

import coremltools as ct

# Sub-directory of the extraction working directory that is searched for the
# extraction script's output.
_EXTRACTED_DIR_NAME = "extracted_coreml_models"


def _natural_key(path: Path) -> list:
    """Sort key ordering embedded integers numerically (model_2 < model_10)."""
    parts = re.split(r"(\d+)", path.name)
    # re.split with a capturing group alternates text / digits, so every odd
    # index holds a digit run; keys therefore compare str-to-str and int-to-int.
    return [int(part) if index % 2 else part for index, part in enumerate(parts)]


def _first_mlpackage(directory: Path) -> Optional[Path]:
    """Return the first ``.mlpackage`` entry of ``directory`` by name, or None."""
    candidates = sorted(p for p in directory.iterdir() if p.suffix == ".mlpackage")
    return candidates[0] if candidates else None


def _locate_model(model_dir: Path) -> Optional[Path]:
    """
    Find the CoreML package belonging to one extracted model directory.

    Search order: an ``.mlpackage`` directly in ``model_dir``, an ``.mlpackage``
    in ``model_dir/lowered_module``, and finally ``lowered_module/model.mlmodel``,
    which is loaded and re-saved as ``<model_dir>/<model_dir name>.mlpackage``.
    Returns None when none of these exist.
    """
    package = _first_mlpackage(model_dir)
    if package is not None:
        return package

    lowered_module = model_dir / "lowered_module"
    if not lowered_module.is_dir():
        return None

    package = _first_mlpackage(lowered_module)
    if package is not None:
        return package

    mlmodel_file = lowered_module / "model.mlmodel"
    if mlmodel_file.exists():
        # Load and save as mlpackage
        mlpackage_path = model_dir / f"{model_dir.name}.mlpackage"
        ct.models.MLModel(str(mlmodel_file)).save(str(mlpackage_path))
        return mlpackage_path
    return None


def extract_models(pte_path: str, output_dir: str) -> list[str]:
    """
    Extract CoreML models from a PTE file.

    Runs ``../scripts/extract_coreml_models.py`` (relative to this file) with
    ``output_dir`` as its working directory, then collects one model package
    per extracted model directory.

    Args:
        pte_path: Path to the PTE file; may be relative to the caller's cwd.
        output_dir: Directory to extract into; created if missing.

    Returns:
        Paths to the extracted .mlpackage files, ordered by model directory
        name with embedded numbers compared numerically.

    Exits the process with status 1 if the extraction script fails.
    """
    output_root = Path(output_dir)
    output_root.mkdir(parents=True, exist_ok=True)

    # Both paths are made absolute because the child process runs with a
    # different working directory; relative paths would resolve against it.
    script_path = (
        Path(__file__).resolve().parent.parent / "scripts" / "extract_coreml_models.py"
    )
    result = subprocess.run(
        [sys.executable, str(script_path), "-m", os.path.abspath(pte_path)],
        cwd=str(output_root),  # the extraction script writes into its cwd
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        print(f"Error extracting models: {result.stderr}", file=sys.stderr)
        sys.exit(1)
    print(result.stdout)

    extracted_dir = output_root / _EXTRACTED_DIR_NAME
    print(f"  Looking in: {extracted_dir}")
    if not extracted_dir.is_dir():
        print(
            f"  Extraction did not create a {_EXTRACTED_DIR_NAME} directory",
            file=sys.stderr,
        )
        return []

    model_paths = []
    for model_dir in sorted(extracted_dir.iterdir(), key=_natural_key):
        if not model_dir.is_dir():
            continue
        model_path = _locate_model(model_dir)
        if model_path is None:
            print(
                f"    {model_dir.name}/: no CoreML model found, skipping",
                file=sys.stderr,
            )
            continue
        print(f"    {model_dir.name}/ -> {model_path.name}")
        model_paths.append(str(model_path))

    return model_paths


def create_multifunction_model(
    prefill_mlpackage: str,
    decode_mlpackage: str,
    output_path: str,
    compile_model: bool,
) -> str:
    """
    Create a multifunction model combining prefill and decode.

    The "main" function of each input package is exposed as "prefill" and
    "decode" respectively; "decode" is the default function.

    Args:
        prefill_mlpackage: Path to the prefill .mlpackage.
        decode_mlpackage: Path to the decode .mlpackage.
        output_path: Output path without extension.
        compile_model: If True, write ``<output_path>.mlmodelc`` (replacing an
            existing one); otherwise write ``<output_path>.mlpackage``.

    Returns:
        The path to the output model.
    """
    desc = ct.utils.MultiFunctionDescriptor()
    desc.add_function(
        prefill_mlpackage,
        src_function_name="main",
        target_function_name="prefill",
    )
    desc.add_function(
        decode_mlpackage,
        src_function_name="main",
        target_function_name="decode",
    )
    desc.default_function_name = "decode"

    if not compile_model:
        mlpackage_path = output_path + ".mlpackage"
        ct.utils.save_multifunction(desc, mlpackage_path)
        print(f"Saved model to {mlpackage_path}")
        return mlpackage_path

    dest_path = output_path + ".mlmodelc"
    output_parent = os.path.dirname(os.path.abspath(output_path))
    # The intermediate package lives in a scratch directory so that it never
    # clobbers an existing <output_path>.mlpackage and is removed even when
    # saving or compiling fails.
    with tempfile.TemporaryDirectory(dir=output_parent, prefix=".multifunction_") as scratch:
        mlpackage_path = os.path.join(
            scratch, os.path.basename(output_path) + ".mlpackage"
        )
        ct.utils.save_multifunction(desc, mlpackage_path)
        compiled_path = ct.utils.compile_model(mlpackage_path)

        if os.path.exists(dest_path):
            shutil.rmtree(dest_path)
        shutil.move(compiled_path, dest_path)

    print(f"Saved compiled model to {dest_path}")
    return dest_path


def main():
    """Command-line entry point: extract, pair and merge prefill/decode models."""
    parser = argparse.ArgumentParser(
        description="Create multifunction CoreML models from prefill/decode PTE files"
    )
    parser.add_argument(
        "--prefill_model",
        required=True,
        help="Path to the prefill PTE file",
    )
    parser.add_argument(
        "--decode_model",
        required=True,
        help="Path to the decode PTE file",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        default=False,
        help="Compile the models to .mlmodelc format",
    )
    parser.add_argument(
        "--output_dir",
        default=".",
        help="Output directory for the multifunction models (default: current directory)",
    )

    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Create temp directories for extraction. Leftovers from an interrupted
    # run are removed first so stale models cannot be picked up.
    temp_dir = output_dir / "temp_extraction"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    prefill_extract_dir = temp_dir / "prefill"
    decode_extract_dir = temp_dir / "decode"

    try:
        print("Extracting prefill models...")
        prefill_models = extract_models(args.prefill_model, str(prefill_extract_dir))
        print(f"Found {len(prefill_models)} prefill models")

        print("Extracting decode models...")
        decode_models = extract_models(args.decode_model, str(decode_extract_dir))
        print(f"Found {len(decode_models)} decode models")

        if len(prefill_models) != len(decode_models):
            print(
                f"Error: Number of prefill models ({len(prefill_models)}) "
                f"does not match decode models ({len(decode_models)})",
                file=sys.stderr,
            )
            sys.exit(1)

        num_models = len(prefill_models)
        if num_models == 0:
            print("Error: No CoreML models were found in the PTE files", file=sys.stderr)
            sys.exit(1)

        print(f"\nCreating {num_models} multifunction models...")

        # Create multifunction models (mod1, mod2, mod3, ...)
        pairs = zip(prefill_models, decode_models)
        for model_num, (prefill_path, decode_path) in enumerate(pairs, start=1):
            output_path = str(output_dir / f"mod{model_num}")

            print(f"\nCreating mod{model_num}...")
            print(f"  Prefill: {prefill_path}")
            print(f"  Decode:  {decode_path}")

            create_multifunction_model(
                prefill_mlpackage=prefill_path,
                decode_mlpackage=decode_path,
                output_path=output_path,
                compile_model=args.compile,
            )
    finally:
        # Clean up temp directory, including on the error paths above
        # (sys.exit raises SystemExit, which still runs this clause).
        print("\nCleaning up temporary files...")
        try:
            shutil.rmtree(temp_dir)
        except OSError as e:
            print(f"Warning: Could not fully clean up temp directory: {e}")
            print(f"You may want to manually delete: {temp_dir}")

    print("\nDone!")


if __name__ == "__main__":
    main()
// Allocates a dummy input array using the shape and data type of `feature`'s
// multi-array constraint. Returns nil and reports through `error` when the
// allocation fails.
static MLMultiArray *DummyMultiArrayForFeature(MLFeatureDescription *feature, NSError **error) {
  MLMultiArrayConstraint *constraint = feature.multiArrayConstraint;
  MLMultiArray *array = [[MLMultiArray alloc] initWithShape:constraint.shape
                                                   dataType:constraint.dataType
                                                      error:error];
  // The initializer does not document the initial element values, so the
  // buffer is zeroed to give every benchmark run identical, well-defined
  // inputs. Messaging a nil `array` (failed allocation) is a no-op.
  if (@available(macOS 12.3, iOS 15.4, tvOS 15.4, watchOS 8.5, *)) {
    [array getMutableBytesWithHandler:^(void *mutableBytes,
                                        NSInteger size,
                                        NSArray<NSNumber *> *strides) {
      // `size` is the buffer size in bytes, so this is independent of dataType.
      memset(mutableBytes, 0, (size_t)size);
    }];
  }
  return array;
}

// Loads the function named `functionName` from the compiled multifunction
// model at `compiledModelURL`, configured for CPU + Neural Engine.
// Returns nil and reports through `error` on failure, including when the
// running OS predates MLModelConfiguration.functionName.
static MLModel *LoadModelWithFunction(NSURL *compiledModelURL, NSString *functionName, NSError **error) {
  if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, watchOS 11.0, *)) {
    MLModelConfiguration *config = [[MLModelConfiguration alloc] init];
    config.functionName = functionName;
    config.computeUnits = MLComputeUnitsCPUAndNeuralEngine;
    return [MLModel modelWithContentsOfURL:compiledModelURL configuration:config error:error];
  }
  // Selecting a function is impossible on older systems; fail with an error
  // instead of crashing on an unrecognized selector.
  if (error) {
    *error = [NSError errorWithDomain:MLModelErrorDomain
                                 code:MLModelErrorGeneric
                             userInfo:@{
                               NSLocalizedDescriptionKey :
                                   @"Multifunction models are not supported on this OS version"
                             }];
  }
  return nil;
}

// Returns compiled model URL. If path is already .mlmodelc, returns it directly.
// If path is .mlpackage, compiles it and returns the compiled URL.
static NSURL *GetCompiledModelURL(NSString *modelPath, NSError **error) {
  NSURL *modelURL = [NSURL fileURLWithPath:modelPath];
  if ([modelPath hasSuffix:@".mlmodelc"]) {
    // Already compiled
    return modelURL;
  }
  // Needs compilation; failures are reported through `error` with a nil result.
  return [MLModel compileModelAtURL:modelURL error:error];
}

// Resource keys of the benchmarked models: "mod1" ... "mod<kNumModels>".
static NSArray<NSString *> *ModelResourceNames(void) {
  NSMutableArray<NSString *> *names = [NSMutableArray arrayWithCapacity:kNumModels];
  for (int index = 1; index <= kNumModels; ++index) {
    [names addObject:[NSString stringWithFormat:@"mod%d", index]];
  }
  return names;
}

// Milliseconds on a monotonic clock. CFAbsoluteTimeGetCurrent follows the wall
// clock, which can jump when the system time is adjusted.
static double MonotonicNowMs(void) {
  return NSProcessInfo.processInfo.systemUptime * 1000.0;
}

// Loads `functionName` from every compiled model URL. On failure returns nil
// and stores a description in `failure`.
static NSArray<MLModel *> *LoadFunctionModels(NSArray<NSURL *> *compiledURLs,
                                              NSArray<NSString *> *names,
                                              NSString *functionName,
                                              NSString **failure) {
  NSMutableArray<MLModel *> *models = [NSMutableArray arrayWithCapacity:compiledURLs.count];
  for (NSUInteger m = 0; m < compiledURLs.count; ++m) {
    NSError *error = nil;
    MLModel *model = LoadModelWithFunction(compiledURLs[m], functionName, &error);
    if (!model) {
      *failure = [NSString stringWithFormat:@"Failed to load %@ model for %@: %@",
                                            functionName, names[m], error.localizedDescription];
      return nil;
    }
    [models addObject:model];
  }
  return models;
}

// Builds one dummy feature provider per model. On failure returns nil and
// stores a description in `failure`.
static NSArray<MLDictionaryFeatureProvider *> *DummyProvidersForModels(NSArray<MLModel *> *models,
                                                                       NSArray<NSString *> *names,
                                                                       NSString *functionName,
                                                                       NSString **failure) {
  NSMutableArray<MLDictionaryFeatureProvider *> *providers =
      [NSMutableArray arrayWithCapacity:models.count];
  for (NSUInteger m = 0; m < models.count; ++m) {
    NSError *error = nil;
    NSMutableDictionary *inputs = DummyInputsForModel(models[m], &error);
    // DummyInputsForModel's failure contract is not visible from here, so both
    // the result and the (freshly reset) error are checked.
    if (error || !inputs) {
      *failure = [NSString stringWithFormat:@"Failed to prepare %@ inputs for %@: %@",
                                            functionName, names[m], error.localizedDescription];
      return nil;
    }
    MLDictionaryFeatureProvider *provider =
        [[MLDictionaryFeatureProvider alloc] initWithDictionary:inputs error:&error];
    if (!provider) {
      *failure = [NSString stringWithFormat:@"Failed to create %@ input provider for %@: %@",
                                            functionName, names[m], error.localizedDescription];
      return nil;
    }
    [providers addObject:provider];
  }
  return providers;
}

// Runs `iterations` rounds; each round calls every model once with its
// provider. Stores the elapsed wall time in `elapsedMs` and returns YES, or
// returns NO with a description in `failure` at the first failed prediction.
static BOOL TimePredictionRounds(NSString *phase,
                                 int iterations,
                                 NSArray<MLModel *> *models,
                                 NSArray<MLDictionaryFeatureProvider *> *providers,
                                 NSArray<NSString *> *names,
                                 double *elapsedMs,
                                 NSString **failure) {
  const double start = MonotonicNowMs();
  for (int i = 0; i < iterations; ++i) {
    for (NSUInteger m = 0; m < models.count; ++m) {
      // Strong local: the message must outlive the autorelease pool below.
      NSString *message = nil;
      // Drain per prediction so output buffers do not pile up until the test ends.
      @autoreleasepool {
        NSError *error = nil;
        id<MLFeatureProvider> prediction = [models[m] predictionFromFeatures:providers[m]
                                                                       error:&error];
        if (!prediction) {
          message = [NSString stringWithFormat:@"%@ prediction failed on iteration %d for %@: %@",
                                               phase, i, names[m], error.localizedDescription];
        }
      }
      if (message) {
        *failure = message;
        return NO;
      }
    }
  }
  *elapsedMs = MonotonicNowMs() - start;
  return YES;
}

// Maps each model resource key to a predicate accepting that model's
// .mlpackage or .mlmodelc file name.
+ (NSDictionary *)predicates {
  NSMutableDictionary *predicates = [NSMutableDictionary dictionary];
  for (NSString *name in ModelResourceNames()) {
    NSString *packageSuffix = [name stringByAppendingString:@".mlpackage"];
    NSString *compiledSuffix = [name stringByAppendingString:@".mlmodelc"];
    BOOL (^predicate)(NSString *) = ^BOOL(NSString *filename) {
      return [filename hasSuffix:packageSuffix] || [filename hasSuffix:compiledSuffix];
    };
    // The block captures state, so copy it explicitly before storing it.
    predicates[name] = [predicate copy];
  }
  return predicates;
}

// Returns the "multifunction" benchmark: loads the prefill and decode
// functions of every enabled model, then times prefill, decode and a second
// prefill pass with dummy inputs and logs the results.
+ (NSDictionary *)dynamicTestsForResources:(NSDictionary *)resources {
  return @{
    @"multifunction" : ^(XCTestCase *testCase) {
      const BOOL kEnableDecode = YES;
      const int kNumPrefillIterations = 30;
      const int kNumDecodeIterations = 50;
      // Resource keys listed here are skipped, e.g. [NSSet setWithObject:@"mod2"].
      NSSet<NSString *> *disabledModels = [NSSet set];

      // Collect the enabled models; a missing resource fails the test instead
      // of raising when a nil path is inserted into a collection.
      NSMutableArray<NSString *> *modelNames = [NSMutableArray array];
      NSMutableArray<NSString *> *modelPaths = [NSMutableArray array];
      for (NSString *name in ModelResourceNames()) {
        if ([disabledModels containsObject:name]) {
          continue;
        }
        NSString *path = resources[name];
        if (!path) {
          XCTFail(@"No resource found for %@", name);
          return;
        }
        [modelNames addObject:name];
        [modelPaths addObject:path];
      }

      const int numEnabledModels = (int)modelNames.count;
      if (numEnabledModels == 0) {
        XCTFail(@"No models enabled");
        return;
      }

      // Get compiled model URLs (compile if needed)
      NSMutableArray<NSURL *> *compiledModelURLs =
          [NSMutableArray arrayWithCapacity:modelPaths.count];
      for (NSUInteger m = 0; m < modelPaths.count; ++m) {
        NSError *error = nil;
        NSURL *compiledURL = GetCompiledModelURL(modelPaths[m], &error);
        if (!compiledURL) {
          XCTFail(@"Failed to get compiled model for %@: %@", modelNames[m], error.localizedDescription);
          return;
        }
        [compiledModelURLs addObject:compiledURL];
      }

      NSString *failure = nil;

      // Load models and prepare dummy inputs; decode only when enabled.
      NSArray<MLModel *> *prefillModels =
          LoadFunctionModels(compiledModelURLs, modelNames, @"prefill", &failure);
      if (!prefillModels) {
        XCTFail(@"%@", failure);
        return;
      }
      NSArray<MLModel *> *decodeModels = @[];
      if (kEnableDecode) {
        decodeModels = LoadFunctionModels(compiledModelURLs, modelNames, @"decode", &failure);
        if (!decodeModels) {
          XCTFail(@"%@", failure);
          return;
        }
      }

      NSArray<MLDictionaryFeatureProvider *> *prefillProviders =
          DummyProvidersForModels(prefillModels, modelNames, @"prefill", &failure);
      if (!prefillProviders) {
        XCTFail(@"%@", failure);
        return;
      }
      NSArray<MLDictionaryFeatureProvider *> *decodeProviders = @[];
      if (kEnableDecode) {
        decodeProviders = DummyProvidersForModels(decodeModels, modelNames, @"decode", &failure);
        if (!decodeProviders) {
          XCTFail(@"%@", failure);
          return;
        }
      }

      // Total timing spans prefill 1, decode and prefill 2.
      const double totalStart = MonotonicNowMs();

      double prefillTimeMs = 0.0;
      if (!TimePredictionRounds(@"Prefill 1", kNumPrefillIterations, prefillModels,
                                prefillProviders, modelNames, &prefillTimeMs, &failure)) {
        XCTFail(@"%@", failure);
        return;
      }

      double decodeTimeMs = 0.0;
      if (kEnableDecode &&
          !TimePredictionRounds(@"Decode", kNumDecodeIterations, decodeModels,
                                decodeProviders, modelNames, &decodeTimeMs, &failure)) {
        XCTFail(@"%@", failure);
        return;
      }

      double prefill2TimeMs = 0.0;
      if (!TimePredictionRounds(@"Prefill 2", kNumPrefillIterations, prefillModels,
                                prefillProviders, modelNames, &prefill2TimeMs, &failure)) {
        XCTFail(@"%@", failure);
        return;
      }

      const double totalTimeMs = MonotonicNowMs() - totalStart;

      NSLog(@"=== Benchmark Results ===");
      NSLog(@"Prefill 1: %d iterations x %d models, total time: %.2f ms (%.2f ms/iter)", kNumPrefillIterations, numEnabledModels, prefillTimeMs, prefillTimeMs / kNumPrefillIterations);
      NSLog(@"Decode: %d iterations x %d models, total time: %.2f ms (%.2f ms/iter)", kNumDecodeIterations, numEnabledModels, decodeTimeMs, decodeTimeMs / kNumDecodeIterations);
      NSLog(@"Prefill 2: %d iterations x %d models, total time: %.2f ms (%.2f ms/iter)", kNumPrefillIterations, numEnabledModels, prefill2TimeMs, prefill2TimeMs / kNumPrefillIterations);
      NSLog(@"Total time (prefill 1 + decode + prefill 2): %.2f ms", totalTimeMs);
      NSLog(@"=========================");
    }
  };
}