Skip to content

Commit ae74ef9

Browse files
Overloads for Module execute API. (#9688)
Summary: #8363 Reviewed By: mergennachin Differential Revision: D71921054 Co-authored-by: Anthony Shoumikhin <[email protected]>
1 parent 80138b8 commit ae74ef9

File tree

3 files changed

+71
-1
lines changed

3 files changed

+71
-1
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchModule.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,49 @@ __attribute__((deprecated("This API is experimental.")))
130130
error:(NSError **)error
131131
NS_SWIFT_NAME(execute(_:_:));
132132

133+
/**
134+
* Executes a specific method with the provided single input value.
135+
*
136+
* The method is loaded on demand if not already loaded.
137+
*
138+
* @param methodName A string representing the method name.
139+
* @param value An ExecuTorchValue object representing the input.
140+
* @param error A pointer to an NSError pointer that is set if an error occurs.
141+
* @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error.
142+
*/
143+
- (nullable NSArray<ExecuTorchValue *> *)executeMethod:(NSString *)methodName
144+
withInput:(ExecuTorchValue *)value
145+
error:(NSError **)error
146+
NS_SWIFT_NAME(execute(_:_:));
147+
148+
/**
149+
* Executes a specific method with no input values.
150+
*
151+
* The method is loaded on demand if not already loaded.
152+
*
153+
* @param methodName A string representing the method name.
154+
* @param error A pointer to an NSError pointer that is set if an error occurs.
155+
* @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error.
156+
*/
157+
- (nullable NSArray<ExecuTorchValue *> *)executeMethod:(NSString *)methodName
158+
error:(NSError **)error
159+
NS_SWIFT_NAME(execute(_:));
160+
161+
/**
162+
* Executes a specific method with the provided input tensors.
163+
*
164+
* The method is loaded on demand if not already loaded.
165+
*
166+
* @param methodName A string representing the method name.
167+
* @param tensors An NSArray of ExecuTorchTensor objects representing the inputs.
168+
* @param error A pointer to an NSError pointer that is set if an error occurs.
169+
* @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error.
170+
*/
171+
- (nullable NSArray<ExecuTorchValue *> *)executeMethod:(NSString *)methodName
172+
withTensors:(NSArray<ExecuTorchTensor *> *)tensors
173+
error:(NSError **)error
174+
NS_SWIFT_NAME(execute(_:_:));
175+
133176
+ (instancetype)new NS_UNAVAILABLE;
134177
- (instancetype)init NS_UNAVAILABLE;
135178

extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,4 +139,31 @@ - (BOOL)isMethodLoaded:(NSString *)methodName {
139139
return outputs;
140140
}
141141

142+
- (nullable NSArray<ExecuTorchValue *> *)executeMethod:(NSString *)methodName
143+
withInput:(ExecuTorchValue *)value
144+
error:(NSError **)error {
145+
return [self executeMethod:methodName
146+
withInputs:@[value]
147+
error:error];
148+
}
149+
150+
- (nullable NSArray<ExecuTorchValue *> *)executeMethod:(NSString *)methodName
151+
error:(NSError **)error {
152+
return [self executeMethod:methodName
153+
withInputs:@[]
154+
error:error];
155+
}
156+
157+
- (nullable NSArray<ExecuTorchValue *> *)executeMethod:(NSString *)methodName
158+
withTensors:(NSArray<ExecuTorchTensor *> *)tensors
159+
error:(NSError **)error {
160+
NSMutableArray<ExecuTorchValue *> *values = [NSMutableArray arrayWithCapacity:tensors.count];
161+
for (ExecuTorchTensor *tensor in tensors) {
162+
[values addObject:[ExecuTorchValue valueWithTensor:tensor]];
163+
}
164+
return [self executeMethod:methodName
165+
withInputs:values
166+
error:error];
167+
}
168+
142169
@end

extension/apple/ExecuTorch/__tests__/ModuleTest.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class ModuleTest: XCTestCase {
6363
let inputTensor = inputData.withUnsafeMutableBytes {
6464
Tensor(bytesNoCopy: $0.baseAddress!, shape:[1], dataType: .float)
6565
}
66-
let inputs = [Value(inputTensor), Value(inputTensor)]
66+
let inputs = [inputTensor, inputTensor]
6767
var outputs: [Value]?
6868
XCTAssertNoThrow(outputs = try module.execute("forward", inputs))
6969
var outputData: [Float] = [2.0]

0 commit comments

Comments
 (0)