Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,14 @@ __attribute__((deprecated("This API is experimental.")))
*/
- (BOOL)isMethodLoaded:(NSString *)methodName NS_SWIFT_NAME(isLoaded(_:));

/**
* Unloads a method and releases its native resources and planned buffers.
*
* @param methodName The method to unload.
* @return YES if the method was unloaded; NO if it was not loaded at all.
*/
- (BOOL)unloadMethod:(NSString *)methodName NS_SWIFT_NAME(unload(_:));

/**
* Retrieves the set of method names available in the loaded program.
*
Expand Down
7 changes: 7 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,13 @@ - (BOOL)isMethodLoaded:(NSString *)methodName {
return _module->is_method_loaded(methodName.UTF8String);
}

- (BOOL)unloadMethod:(NSString *)methodName {
const auto didUnload = _module->unload_method(methodName.UTF8String);
[_inputs removeObjectForKey:methodName];
[_outputs removeObjectForKey:methodName];
return didUnload;
}

- (nullable NSSet<NSString *> *)methodNames:(NSError **)error {
const auto result = _module->method_names();
if (!result.ok()) {
Expand Down
22 changes: 22 additions & 0 deletions extension/apple/ExecuTorch/__tests__/ModuleTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,26 @@ class ModuleTest: XCTestCase {

XCTAssertThrowsError(try module.setInputs(Tensor<Float>([1])))
}

func testUnloadMethod() {
guard let modelPath = resourceBundle.path(forResource: "add", ofType: "pte") else {
XCTFail("Couldn't find the model file")
return
}
let module = Module(filePath: modelPath)
XCTAssertNoThrow(try module.load("forward"))
XCTAssertTrue(module.isLoaded("forward"))

XCTAssertNoThrow(try module.setInputs(Tensor<Float>([1]), Tensor<Float>([2])))
XCTAssertEqual(try module.forward(), Tensor<Float>([3]))

XCTAssertTrue(module.unload("forward"))
XCTAssertFalse(module.isLoaded("forward"))
XCTAssertFalse(module.unload("forward"))

XCTAssertThrowsError(try module.forward())
XCTAssertTrue(module.isLoaded("forward"))
XCTAssertNoThrow(try module.setInputs(Tensor<Float>([2]), Tensor<Float>([3])))
XCTAssertEqual(try module.forward(), Tensor<Float>([5]))
}
}
Loading