#import <XCTest/XCTest.h>
#import <executorch/runtime/platform/runtime.h>
#import <model_logging_options.h>
+#import <multiarray.h>
+
+using namespace executorchcoreml;

@interface ETCoreMLModelManagerTests : XCTestCase

@@ -110,7 +113,7 @@ - (void)testAddModelExecution { |
    XCTAssertNotNil(inputs);
    MLMultiArray *output = [ETCoreMLTestUtils filledMultiArrayWithShape:inputs[0].shape dataType:inputs[0].dataType repeatedValue:@(0) error:&localError];
    NSArray<MLMultiArray *> *args = [inputs arrayByAddingObject:output];
-    XCTAssertTrue([self.modelManager executeModelWithHandle:handle
+    XCTAssertTrue([self.modelManager executeModelWithHandle:handle
                                                        args:args
                                              loggingOptions:executorchcoreml::ModelLoggingOptions()
                                                 eventLogger:nullptr
@@ -148,4 +151,77 @@ - (void)testMulModelExecution { |
    }
}

+// See https://github.com/pytorch/executorch/pull/10465
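+// Regression test: a failed execution must return NO and populate the NSError out-parameter.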
+- (void)testAutoreleasepoolError {
+    NSURL *modelURL = [self.class bundledResourceWithName:@"add_coreml_all" extension:@"bin"];
+    NSError *localError = nil;
+    XCTAssertNotNil(modelURL);
+
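+    // Load the compiled model from its AOT payload and obtain a handle from the model manager.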
+    NSData *modelData = [NSData dataWithContentsOfURL:modelURL];
+    MLModelConfiguration *configuration = [[MLModelConfiguration alloc] init];
+    configuration.computeUnits = MLComputeUnitsAll;
+    ModelHandle *modelHandle = [self.modelManager loadModelFromAOTData:modelData
+                                                         configuration:configuration
+                                                                 error:&localError];
+    XCTAssert(modelHandle);
+
+    ETCoreMLModel *model = [self.modelManager modelWithHandle:modelHandle];
+    XCTAssert(model);
+
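+    // Create inputs filled with the repeated values 2 and 3 via the test utilities.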
+    NSArray<MLMultiArray *> *inputArrays =
+        [ETCoreMLTestUtils inputsForModel:model repeatedValues:@[@(2), @(3)] error:&localError];
+    XCTAssert(inputArrays);
+
+    std::vector<MultiArray> multiArrays;
+    multiArrays.reserve(inputArrays.count + model.orderedOutputNames.count);
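+    // Wrap each input's backing buffer in a MultiArray view that records its data type, shape, and strides.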
+    for (MLMultiArray *array in inputArrays) {
+        auto dataTypeOpt = to_multiarray_data_type(array.dataType);
+        XCTAssert(dataTypeOpt.has_value());
+        auto dataType = dataTypeOpt.value();
+
+        std::vector<size_t> dims;
+        for (NSNumber *n in array.shape) {
+            dims.push_back(n.unsignedLongValue);
+        }
+
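+        // Compute C-contiguous (row-major) strides: the innermost dimension has stride 1.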
+        std::vector<ssize_t> strides(dims.size());
+        ssize_t currentStride = 1;
+        for (NSInteger i = (NSInteger)dims.size() - 1; i >= 0; --i) {
+            strides[i] = currentStride;
+            currentStride *= dims[i];
+        }
+
+        multiArrays.emplace_back(array.dataPointer,
+                                 MultiArray::MemoryLayout(dataType, dims, strides));
+    }
+
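+    // Allocate zero-initialized output buffers with the same layout as the first input; they are freed at the end of the test.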
+    auto inputLayout = multiArrays[0].layout();
+    size_t bufferSize = inputLayout.num_bytes();
+    for (NSUInteger i = 0; i < model.orderedOutputNames.count; ++i) {
+        multiArrays.emplace_back(calloc(1, bufferSize), inputLayout);
+    }
+    // Corrupt the first input's shape so that execution fails and an error is reported.
+    {
+        auto originalLayout = multiArrays[0].layout();
+        auto corruptedDims = originalLayout.shape();
+        corruptedDims[0] += 1;
+        multiArrays[0] = MultiArray(multiArrays[0].data(),
+                                    MultiArray::MemoryLayout(originalLayout.dataType(),
+                                                             corruptedDims,
+                                                             originalLayout.strides()));
+    }
+
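+    // Execution is expected to fail because the first input's shape no longer matches the model.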
+    BOOL success = [self.modelManager executeModelWithHandle:modelHandle
+                                                     argsVec:multiArrays
+                                              loggingOptions:ModelLoggingOptions()
+                                                 eventLogger:nullptr
+                                                       error:&localError];
+    XCTAssertFalse(success);
+    XCTAssertNotNil(localError);
+
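+    // Release the calloc'd output buffers.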
+    for (size_t i = inputArrays.count; i < multiArrays.size(); ++i) {
+        free(multiArrays[i].data());
+    }
+}
+
@end