Skip to content

Commit 87eddfd

Browse files
authored
Update ETCoreMLModelManager.mm
1 parent 7edbf9b commit 87eddfd

File tree

1 file changed

+46
-40
lines changed

1 file changed

+46
-40
lines changed

backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm

Lines changed: 46 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,8 @@ - (BOOL)executeModelWithHandle:(ModelHandle *)handle
709709
args.count);
710710
return NO;
711711
}
712+
NSError *localError = nil;
713+
NSArray<MLMultiArray *> *modelOutputs = nil;
712714
@autoreleasepool {
713715
NSArray<MLMultiArray *> *inputs = [args subarrayWithRange:NSMakeRange(0, model.orderedInputNames.count)];
714716
NSArray<MLMultiArray *> *outputs = [args subarrayWithRange:NSMakeRange(model.orderedInputNames.count, args.count - model.orderedInputNames.count)];
@@ -717,19 +719,22 @@ - (BOOL)executeModelWithHandle:(ModelHandle *)handle
717719
outputBackings = outputs;
718720
}
719721

720-
NSArray<MLMultiArray *> *modelOutputs = [self executeModelUsingExecutor:executor
721-
inputs:inputs
722-
outputBackings:outputBackings
723-
loggingOptions:loggingOptions
724-
eventLogger:eventLogger
725-
error:error];
726-
if (!modelOutputs) {
727-
return NO;
722+
modelOutputs = [self executeModelUsingExecutor:executor
723+
inputs:inputs
724+
outputBackings:outputBackings
725+
loggingOptions:loggingOptions
726+
eventLogger:eventLogger
727+
error:error];
728+
if (modelOutputs) {
729+
::set_outputs(outputs, modelOutputs);
728730
}
729-
730-
::set_outputs(outputs, modelOutputs);
731731
}
732-
732+
if (!modelOutputs) {
733+
if (error) {
734+
*error = localError;
735+
}
736+
return NO;
737+
}
733738
return YES;
734739
}
735740

@@ -760,40 +765,41 @@ - (BOOL)executeModelWithHandle:(ModelHandle *)handle
760765

761766
std::vector<executorchcoreml::MultiArray> inputArgs(argsVec.begin(), argsVec.begin() + model.orderedInputNames.count);
762767
std::vector<executorchcoreml::MultiArray> outputArgs(argsVec.begin() + model.orderedInputNames.count, argsVec.end());
768+
NSError *localError = nil;
769+
NSArray<MLMultiArray *> *modelOutputs = nil;
763770
@autoreleasepool {
764771
NSArray<MLMultiArray *> *inputs = [model prepareInputs:inputArgs error:error];
765-
if (!inputs) {
766-
return NO;
767-
}
768-
769-
NSArray<MLMultiArray *> *outputBackings = @[];
770-
if (executor.ignoreOutputBackings == NO) {
771-
outputBackings = [model prepareOutputBackings:outputArgs error:error];
772-
}
773-
774-
if (!outputBackings) {
775-
return NO;
776-
}
777-
778-
NSArray<MLMultiArray *> *modelOutputs = [self executeModelUsingExecutor:executor
779-
inputs:inputs
780-
outputBackings:outputBackings
781-
loggingOptions:loggingOptions
782-
eventLogger:eventLogger
783-
error:error];
784-
if (!modelOutputs) {
785-
return NO;
772+
if (inputs) {
773+
NSArray<MLMultiArray *> *outputBackings = @[];
774+
if (executor.ignoreOutputBackings == NO) {
775+
outputBackings = [model prepareOutputBackings:outputArgs error:error];
776+
}
777+
if (outputBackings) {
778+
modelOutputs = [self executeModelUsingExecutor:executor
779+
inputs:inputs
780+
outputBackings:outputBackings
781+
loggingOptions:loggingOptions
782+
eventLogger:eventLogger
783+
error:error];
784+
if (!modelOutputs) {
785+
// Resize for dynamic shapes
786+
for (int i = 0; i < outputArgs.size(); i++) {
787+
auto new_size = to_vector<size_t>(modelOutputs[i].shape);
788+
outputArgs[i].resize(new_size);
789+
argsVec[model.orderedInputNames.count + i].resize(new_size);
790+
}
791+
::set_outputs(outputArgs, modelOutputs);
792+
}
793+
}
786794
}
787-
788-
// Resize for dynamic shapes
789-
for (int i = 0; i < outputArgs.size(); i++) {
790-
auto new_size = to_vector<size_t>(modelOutputs[i].shape);
791-
outputArgs[i].resize(new_size);
792-
argsVec[model.orderedInputNames.count + i].resize(new_size);
795+
}
796+
if (!modelOutputs) {
797+
if (error) {
798+
*error = localError;
793799
}
794-
::set_outputs(outputArgs, modelOutputs);
795-
return YES;
800+
return NO;
796801
}
802+
return YES;
797803
}
798804

799805
- (BOOL)unloadModelWithHandle:(ModelHandle *)handle {

0 commit comments

Comments
 (0)