diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h index 94d0a8356cd..9229c60512a 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h @@ -115,6 +115,21 @@ __attribute__((deprecated("This API is experimental."))) */ - (nullable NSSet *)methodNames:(NSError **)error; +/** + * Executes a specific method with the provided input values. + * + * The method is loaded on demand if not already loaded. + * + * @param methodName A string representing the method name. + * @param values An NSArray of ExecuTorchValue objects representing the inputs. + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error. + */ +- (nullable NSArray *)executeMethod:(NSString *)methodName + withInputs:(NSArray *)values + error:(NSError **)error + NS_SWIFT_NAME(execute(_:_:)); + + (instancetype)new NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE; diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm index 51c2c024fbd..5142a969c8f 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm @@ -16,6 +16,27 @@ using namespace executorch::extension; using namespace executorch::runtime; +static inline EValue toEValue(ExecuTorchValue *value) { + if (value.isTensor) { + auto *nativeTensorPtr = value.tensorValue.nativeInstance; + ET_CHECK(nativeTensorPtr); + auto nativeTensor = *reinterpret_cast(nativeTensorPtr); + ET_CHECK(nativeTensor); + return *nativeTensor; + } + ET_CHECK_MSG(false, "Unsupported ExecuTorchValue type"); + return EValue(); +} + +static inline ExecuTorchValue *toExecuTorchValue(EValue value) { + if (value.isTensor()) { + auto nativeInstance = make_tensor_ptr(value.toTensor()); + return [ExecuTorchValue valueWithTensor:[[ExecuTorchTensor alloc] initWithNativeInstance:&nativeInstance]]; + } + ET_CHECK_MSG(false, "Unsupported EValue type"); + return [ExecuTorchValue new]; +} + @implementation ExecuTorchModule { std::unique_ptr _module; } @@ -94,4 +115,28 @@ - (BOOL)isMethodLoaded:(NSString *)methodName { return methods; } +- (nullable NSArray *)executeMethod:(NSString *)methodName + withInputs:(NSArray *)values + error:(NSError **)error { + std::vector inputs; + inputs.reserve(values.count); + for (ExecuTorchValue *value in values) { + inputs.push_back(toEValue(value)); + } + const auto result = _module->execute(methodName.UTF8String, inputs); + if (!result.ok()) { + if (error) { + *error = [NSError errorWithDomain:ExecuTorchErrorDomain + code:(NSInteger)result.error() + userInfo:nil]; + } + return nil; + } + NSMutableArray *outputs = [NSMutableArray arrayWithCapacity:result->size()]; + for (const auto &value : *result) { + [outputs addObject:toExecuTorchValue(value)]; + } + return outputs; +} + @end diff --git a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift index 515e6f5874f..feaa0f19826 100644 --- a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift +++ b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift @@ -51,4 +51,25 @@ class ModuleTest: XCTestCase { XCTAssertNoThrow(methodNames = try module.methodNames()) XCTAssertEqual(methodNames, Set(["forward"])) } + + func testExecute() { + let bundle = Bundle(for: type(of: self)) + guard let modelPath = bundle.path(forResource: "add", ofType: "pte") else { + XCTFail("Couldn't find the model file") + return + } + let module = Module(filePath: modelPath) + var inputData: [Float] = [1.0] + let inputTensor = inputData.withUnsafeMutableBytes { + Tensor(bytesNoCopy: $0.baseAddress!, shape:[1], dataType: .float) + } + let inputs = [Value(inputTensor), Value(inputTensor)] + var outputs: [Value]? + XCTAssertNoThrow(outputs = try module.execute("forward", inputs)) + var outputData: [Float] = [2.0] + let outputTensor = outputData.withUnsafeMutableBytes { + Tensor(bytesNoCopy: $0.baseAddress!, shape:[1], dataType: .float, shapeDynamism: .static) + } + XCTAssertEqual(outputs?[0].tensor, outputTensor) + } }