Skip to content

Commit 1d5a214

Browse files
authored
Introduce unload method API to Module. (#13364)
Summary: . Differential Revision: D80149419
1 parent 3254ddf commit 1d5a214

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchModule.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,14 @@ __attribute__((deprecated("This API is experimental.")))
187187
*/
188188
- (BOOL)isMethodLoaded:(NSString *)methodName NS_SWIFT_NAME(isLoaded(_:));
189189

190+
/**
191+
* Unloads a method and releases its native resources and planned buffers.
192+
*
193+
* @param methodName The method to unload.
194+
* @return YES if the method was unloaded; NO if it was not loaded at all.
195+
*/
196+
- (BOOL)unloadMethod:(NSString *)methodName NS_SWIFT_NAME(unload(_:));
197+
190198
/**
191199
* Retrieves the set of method names available in the loaded program.
192200
*

extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,13 @@ - (BOOL)isMethodLoaded:(NSString *)methodName {
304304
return _module->is_method_loaded(methodName.UTF8String);
305305
}
306306

307+
- (BOOL)unloadMethod:(NSString *)methodName {
308+
const auto didUnload = _module->unload_method(methodName.UTF8String);
309+
[_inputs removeObjectForKey:methodName];
310+
[_outputs removeObjectForKey:methodName];
311+
return didUnload;
312+
}
313+
307314
- (nullable NSSet<NSString *> *)methodNames:(NSError **)error {
308315
const auto result = _module->method_names();
309316
if (!result.ok()) {

extension/apple/ExecuTorch/__tests__/ModuleTest.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,4 +171,26 @@ class ModuleTest: XCTestCase {
171171

172172
XCTAssertThrowsError(try module.setInputs(Tensor<Float>([1])))
173173
}
174+
175+
func testUnloadMethod() {
176+
guard let modelPath = resourceBundle.path(forResource: "add", ofType: "pte") else {
177+
XCTFail("Couldn't find the model file")
178+
return
179+
}
180+
let module = Module(filePath: modelPath)
181+
XCTAssertNoThrow(try module.load("forward"))
182+
XCTAssertTrue(module.isLoaded("forward"))
183+
184+
XCTAssertNoThrow(try module.setInputs(Tensor<Float>([1]), Tensor<Float>([2])))
185+
XCTAssertEqual(try module.forward(), Tensor<Float>([3]))
186+
187+
XCTAssertTrue(module.unload("forward"))
188+
XCTAssertFalse(module.isLoaded("forward"))
189+
XCTAssertFalse(module.unload("forward"))
190+
191+
XCTAssertThrowsError(try module.forward())
192+
XCTAssertTrue(module.isLoaded("forward"))
193+
XCTAssertNoThrow(try module.setInputs(Tensor<Float>([2]), Tensor<Float>([3])))
194+
XCTAssertEqual(try module.forward(), Tensor<Float>([5]))
195+
}
174196
}

0 commit comments

Comments
 (0)