diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h index a73379ff4c7..94d0a8356cd 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h @@ -104,6 +104,17 @@ __attribute__((deprecated("This API is experimental."))) */ - (BOOL)isMethodLoaded:(NSString *)methodName NS_SWIFT_NAME(isLoaded(_:)); +/** + * Retrieves the set of method names available in the loaded program. + * + * The method names are returned as an unordered set of strings. The program and methods + * are loaded as needed. + * + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return An unordered set of method names, or nil in case of an error. + */ +- (nullable NSSet *)methodNames:(NSError **)error; + + (instancetype)new NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE; diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm index 246d6324de0..51c2c024fbd 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm @@ -77,4 +77,21 @@ - (BOOL)isMethodLoaded:(NSString *)methodName { return _module->is_method_loaded(methodName.UTF8String); } +- (nullable NSSet *)methodNames:(NSError **)error { + const auto result = _module->method_names(); + if (!result.ok()) { + if (error) { + *error = [NSError errorWithDomain:ExecuTorchErrorDomain + code:(NSInteger)result.error() + userInfo:nil]; + } + return nil; + } + NSMutableSet *methods = [NSMutableSet setWithCapacity:result->size()]; + for (const auto &name : *result) { + [methods addObject:(NSString *)@(name.c_str())]; + } + return methods; +} + @end diff --git a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift index e94820a43c3..515e6f5874f 100644 --- a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift +++ b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift @@ -39,4 +39,16 @@ class ModuleTest: XCTestCase { XCTAssertNoThrow(try module.load("forward")) XCTAssertTrue(module.isLoaded("forward")) } + + func testMethodNames() { + let bundle = Bundle(for: type(of: self)) + guard let modelPath = bundle.path(forResource: "add", ofType: "pte") else { + XCTFail("Couldn't find the model file") + return + } + let module = Module(filePath: modelPath) + var methodNames: Set? + XCTAssertNoThrow(methodNames = try module.methodNames()) + XCTAssertEqual(methodNames, Set(["forward"])) + } }