diff --git a/extension/benchmark/apple/Benchmark/Tests/GenericTests.mm b/extension/benchmark/apple/Benchmark/Tests/GenericTests.mm index ce685335767..28e5fffce21 100644 --- a/extension/benchmark/apple/Benchmark/Tests/GenericTests.mm +++ b/extension/benchmark/apple/Benchmark/Tests/GenericTests.mm @@ -7,12 +7,7 @@ */ #import "ResourceTestCase.h" - -#import -#import - -using namespace ::executorch::extension; -using namespace ::executorch::runtime; +#import "test_function.h" #define ASSERT_OK_OR_RETURN(value__) \ ({ \ @@ -37,7 +32,7 @@ @implementation GenericTests + (NSDictionary *)predicates { return @{ @"model" : ^BOOL(NSString *filename){ - return [filename hasSuffix:@".pte"]; + return [filename hasSuffix:@".mlpackage"]; }, }; } @@ -50,46 +45,9 @@ @implementation GenericTests [testCase measureWithMetrics:@[ [XCTClockMetric new], [XCTMemoryMetric new] ] block:^{ - XCTAssertEqual( - Module(modelPath.UTF8String).load_forward(), - Error::Ok); + XCTAssertTrue(prefill_no_kv_cache()); }]; - }, - @"forward" : ^(XCTestCase *testCase) { - auto __block module = std::make_unique(modelPath.UTF8String); - - const auto method_meta = module->method_meta("forward"); - ASSERT_OK_OR_RETURN(method_meta); - - const auto num_inputs = method_meta->num_inputs(); - XCTAssertGreaterThan(num_inputs, 0); - - std::vector tensors; - tensors.reserve(num_inputs); - - for (auto index = 0; index < num_inputs; ++index) { - const auto input_tag = method_meta->input_tag(index); - ASSERT_OK_OR_RETURN(input_tag); - - switch (*input_tag) { - case Tag::Tensor: { - const auto tensor_meta = method_meta->input_tensor_meta(index); - ASSERT_OK_OR_RETURN(tensor_meta); - - const auto sizes = tensor_meta->sizes(); - tensors.emplace_back( - ones({sizes.begin(), sizes.end()}, tensor_meta->scalar_type())); - XCTAssertEqual(module->set_input(tensors.back(), index), Error::Ok); - } break; - default: - XCTFail("Unsupported tag %i at input %d", *input_tag, index); - } - } - [testCase measureWithMetrics:@[ 
[XCTClockMetric new], [XCTMemoryMetric new] ] - block:^{ - XCTAssertEqual(module->forward().error(), Error::Ok); - }]; - }, + } }; }
diff --git a/extension/benchmark/apple/Benchmark/Tests/test_function.h b/extension/benchmark/apple/Benchmark/Tests/test_function.h
new file mode 100644
index 00000000000..b2f048e0577
--- /dev/null
+++ b/extension/benchmark/apple/Benchmark/Tests/test_function.h
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#import <CoreML/CoreML.h>
+
+#import "model_no_kv.h"
+
+/// Number of token ids in the dummy prompt (second dimension of the input).
+static const NSInteger kPrefillSequenceLength = 512;
+
+/// Token id written into every position of the dummy prompt.
+static const int32_t kPrefillFillerToken = 2;
+
+/// Number of back-to-back predictions issued by one call.
+static const NSInteger kPrefillIterationCount = 100;
+
+/// Runs `model_no_kv` repeatedly over a constant dummy prompt.
+///
+/// Builds a 1 x kPrefillSequenceLength Int32 MLMultiArray filled with
+/// kPrefillFillerToken, then calls -predictionFromFeatures:error: on the model
+/// kPrefillIterationCount times with that input.
+///
+/// Declared static inline because it is defined in a header: without internal
+/// linkage, importing this file from two translation units would produce
+/// duplicate-symbol link errors.
+///
+/// NOTE(review): model_no_kv is presumably the Xcode-generated wrapper for the
+/// bundled .mlpackage - confirm against model_no_kv.h.
+///
+/// @return YES if the model loaded, the input was created, and every prediction
+///         returned an output; NO (after logging the reason) otherwise.
+static inline BOOL prefill_no_kv_cache(void) {
+  model_no_kv *model = [[model_no_kv alloc] init];
+  if (model == nil) {
+    NSLog(@"prefill_no_kv_cache: failed to create model_no_kv");
+    return NO;
+  }
+
+  NSError *arrayError = nil;
+  MLMultiArray *tokens =
+      [[MLMultiArray alloc] initWithShape:@[ @1, @(kPrefillSequenceLength) ]
+                                 dataType:MLMultiArrayDataTypeInt32
+                                    error:&arrayError];
+  if (tokens == nil) {
+    NSLog(@"prefill_no_kv_cache: failed to allocate tokens: %@", arrayError);
+    return NO;
+  }
+
+  // Single-index subscripting addresses the array by linear offset, so this
+  // touches all 1 x kPrefillSequenceLength elements.
+  for (NSInteger index = 0; index < kPrefillSequenceLength; index++) {
+    tokens[index] = @(kPrefillFillerToken);
+  }
+
+  model_no_kvInput *inputs = [[model_no_kvInput alloc] initWithTokens:tokens];
+
+  for (NSInteger iteration = 0; iteration < kPrefillIterationCount; iteration++) {
+    // Release each iteration's autoreleased temporaries promptly instead of
+    // letting them accumulate until the function returns.
+    @autoreleasepool {
+      NSError *predictionError = nil;
+      model_no_kvOutput *output = [model predictionFromFeatures:inputs
+                                                          error:&predictionError];
+      if (output == nil) {
+        NSLog(@"prefill_no_kv_cache: prediction %ld failed: %@",
+              (long)iteration, predictionError);
+        return NO;
+      }
+    }
+  }
+
+  return YES;
+}