|
16 | 16 | using namespace executorch::extension; |
17 | 17 | using namespace executorch::runtime; |
18 | 18 |
|
| 19 | +static inline EValue toEValue(ExecuTorchValue *value) { |
| 20 | + if (value.isTensor) { |
| 21 | + auto *nativeTensorPtr = value.tensorValue.nativeInstance; |
| 22 | + ET_CHECK(nativeTensorPtr); |
| 23 | + auto nativeTensor = *reinterpret_cast<TensorPtr *>(nativeTensorPtr); |
| 24 | + ET_CHECK(nativeTensor); |
| 25 | + return *nativeTensor; |
| 26 | + } |
| 27 | + ET_CHECK_MSG(false, "Unsupported ExecuTorchValue type"); |
| 28 | + return EValue(); |
| 29 | +} |
| 30 | + |
| 31 | +static inline ExecuTorchValue *toExecuTorchValue(EValue value) { |
| 32 | + if (value.isTensor()) { |
| 33 | + auto nativeInstance = make_tensor_ptr(value.toTensor()); |
| 34 | + return [ExecuTorchValue valueWithTensor:[[ExecuTorchTensor alloc] initWithNativeInstance:&nativeInstance]]; |
| 35 | + } |
| 36 | + ET_CHECK_MSG(false, "Unsupported EValue type"); |
| 37 | + return [ExecuTorchValue new]; |
| 38 | +} |
| 39 | + |
19 | 40 | @implementation ExecuTorchModule { |
20 | 41 | std::unique_ptr<Module> _module; |
21 | 42 | } |
@@ -94,4 +115,28 @@ - (BOOL)isMethodLoaded:(NSString *)methodName { |
94 | 115 | return methods; |
95 | 116 | } |
96 | 117 |
|
| 118 | +- (nullable NSArray<ExecuTorchValue *> *)executeMethod:(NSString *)methodName |
| 119 | + withInputs:(NSArray<ExecuTorchValue *> *)values |
| 120 | + error:(NSError **)error { |
| 121 | + std::vector<EValue> inputs; |
| 122 | + inputs.reserve(values.count); |
| 123 | + for (ExecuTorchValue *value in values) { |
| 124 | + inputs.push_back(toEValue(value)); |
| 125 | + } |
| 126 | + const auto result = _module->execute(methodName.UTF8String, inputs); |
| 127 | + if (!result.ok()) { |
| 128 | + if (error) { |
| 129 | + *error = [NSError errorWithDomain:ExecuTorchErrorDomain |
| 130 | + code:(NSInteger)result.error() |
| 131 | + userInfo:nil]; |
| 132 | + } |
| 133 | + return nil; |
| 134 | + } |
| 135 | + NSMutableArray<ExecuTorchValue *> *outputs = [NSMutableArray arrayWithCapacity:result->size()]; |
| 136 | + for (const auto &value : *result) { |
| 137 | + [outputs addObject:toExecuTorchValue(value)]; |
| 138 | + } |
| 139 | + return outputs; |
| 140 | +} |
| 141 | + |
97 | 142 | @end |
0 commit comments