Skip to content

Commit 0d70105

Browse files
committed
up
1 parent 80e3981 commit 0d70105

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

examples/models/llama/coreml_enumerated_shape.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@ def get_example_inputs(max_batch_size, args, coreml=False, use_enumerated_shapes
4949

5050

5151
# Batch with kv cache runs into issues
52-
max_batch_size = args.max_seq_length
52+
if args.use_kv_cache and not args.enable_dynamic_shape:
53+
max_batch_size = 1
54+
else:
55+
max_batch_size = args.max_seq_length
56+
5357
example_inputs = get_example_inputs(max_batch_size, args)
5458

5559

extension/benchmark/apple/Benchmark/Tests/CoreMLTests.mm

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,20 +85,20 @@ @implementation CoreMLTests
8585
XCTFail(@"Failed to load model: %@", error.localizedDescription);
8686
return;
8787
}
88-
// NSMutableDictionary *inputs = DummyInputsForModel(model, &error);
89-
// if (error || !inputs) {
90-
// XCTFail(@"Failed to prepare inputs: %@", error.localizedDescription);
91-
// return;
92-
// }
93-
// MLDictionaryFeatureProvider *featureProvider = [[MLDictionaryFeatureProvider alloc] initWithDictionary:inputs error:&error];
94-
// if (error || !featureProvider) {
95-
// if (error) {
96-
// XCTFail(@"Failed to create input provider: %@", error.localizedDescription);
97-
// } else {
98-
// XCTFail(@"Failed with unknown error");
99-
// }
100-
// return;
101-
// }
88+
NSMutableDictionary *inputs = DummyInputsForModel(model, &error);
89+
if (error || !inputs) {
90+
XCTFail(@"Failed to prepare inputs: %@", error.localizedDescription);
91+
return;
92+
}
93+
MLDictionaryFeatureProvider *featureProvider = [[MLDictionaryFeatureProvider alloc] initWithDictionary:inputs error:&error];
94+
if (error || !featureProvider) {
95+
if (error) {
96+
XCTFail(@"Failed to create input provider: %@", error.localizedDescription);
97+
} else {
98+
XCTFail(@"Failed with unknown error");
99+
}
100+
return;
101+
}
102102

103103

104104
MLMultiArray *tokensArray1x1 = [[MLMultiArray alloc] initWithShape:@[@1, @1] dataType:MLMultiArrayDataTypeInt32 error:&error];
@@ -122,6 +122,8 @@ @implementation CoreMLTests
122122
id<MLFeatureProvider> prediction;
123123
for (int i = 0; i < 50; i++) {
124124
// prediction = [model predictionFromFeatures:featureProvider usingState:state error:&error];
125+
126+
125127
if (i % 2 == 0) {
126128
prediction = [model predictionFromFeatures:features1x128 error:&error];
127129
} else {

0 commit comments

Comments
 (0)