Skip to content

Commit 6485e4f

Browse files
authored
Use ValueConstructible for execute/forward return type in Module. (#13090)
Summary: . Reviewed By: f-meloni Differential Revision: D79381683
1 parent 7750116 commit 6485e4f

File tree

3 files changed

+112
-2
lines changed

3 files changed

+112
-2
lines changed

docs/source/using-executorch-ios.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ let imageBuffer: UnsafeMutableRawPointer = ... // Existing image buffer
243243
let inputTensor = Tensor<Float>(&imageBuffer, shape: [1, 3, 224, 224])
244244
245245
// Execute the 'forward' method with the given input tensor and get an output tensor back.
246-
let outputTensor: Tensor<Float> = try module.forward(inputTensor)[0].tensor()!
246+
let outputTensor: Tensor<Float> = try module.forward(inputTensor)!
247247
248248
// Copy the tensor data into logits array for easier access.
249249
let logits = outputTensor.scalars()

extension/apple/ExecuTorch/Exported/ExecuTorch+Module.swift

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,86 @@ public extension Module {
9393
try forward(inputs)
9494
}
9595
}
96+
97+
@available(*, deprecated, message: "This API is experimental.")
98+
public extension Module {
99+
/// Executes a specific method and decodes the outputs into `Output` generic type.
100+
///
101+
/// - Parameters:
102+
/// - method: The name of the method to execute.
103+
/// - inputs: An array of `ValueConvertible` inputs.
104+
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
105+
/// - Throws: An error if loading, execution or result conversion fails.
106+
func execute<Output: ValueSequenceConstructible>(_ method: String, _ inputs: [ValueConvertible]) throws -> Output {
107+
try Output(__executeMethod(method, withInputs: inputs.map { $0.asValue() }))
108+
}
109+
110+
/// Executes a specific method with variadic inputs and decodes into `Output` generic type.
111+
///
112+
/// - Parameters:
113+
/// - method: The name of the method to execute.
114+
/// - inputs: A variadic list of `ValueConvertible` inputs.
115+
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
116+
/// - Throws: An error if loading, execution or result conversion fails.
117+
func execute<Output: ValueSequenceConstructible>(_ method: String, _ inputs: ValueConvertible...) throws -> Output {
118+
try execute(method, inputs)
119+
}
120+
121+
/// Executes a specific method with a single input and decodes into `Output` generic type.
122+
///
123+
/// - Parameters:
124+
/// - method: The name of the method to execute.
125+
/// - input: A single `ValueConvertible` input.
126+
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
127+
/// - Throws: An error if loading, execution or result conversion fails.
128+
func execute<Output: ValueSequenceConstructible>(_ method: String, _ input: ValueConvertible) throws -> Output {
129+
try execute(method, [input])
130+
}
131+
132+
/// Executes a specific method with no inputs and decodes into `Output` generic type.
133+
///
134+
/// - Parameter method: The name of the method to execute.
135+
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
136+
/// - Throws: An error if loading, execution or result conversion fails.
137+
func execute<Output: ValueSequenceConstructible>(_ method: String) throws -> Output {
138+
try execute(method, [])
139+
}
140+
141+
/// Executes the "forward" method and decodes into `Output` generic type.
142+
///
143+
/// - Parameters:
144+
/// - inputs: An array of `ValueConvertible` inputs to pass to "forward".
145+
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
146+
/// - Throws: An error if loading, execution or result conversion fails.
147+
func forward<Output: ValueSequenceConstructible>(_ inputs: [ValueConvertible]) throws -> Output {
148+
try execute("forward", inputs)
149+
}
150+
151+
/// Executes the "forward" method with variadic inputs and decodes into `Output` generic type.
152+
///
153+
/// - Parameters:
154+
/// - inputs: A variadic list of `ValueConvertible` inputs.
155+
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
156+
/// - Throws: An error if loading, execution or result conversion fails.
157+
func forward<Output: ValueSequenceConstructible>(_ inputs: ValueConvertible...) throws -> Output {
158+
try forward(inputs)
159+
}
160+
161+
/// Executes the "forward" method with a single input and decodes into `Output` generic type.
162+
///
163+
/// - Parameters:
164+
/// - input: A single `ValueConvertible` to pass to "forward".
165+
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
166+
/// - Throws: An error if loading, execution or result conversion fails.
167+
func forward<Output: ValueSequenceConstructible>(_ input: ValueConvertible) throws -> Output {
168+
try forward([input])
169+
}
170+
171+
/// Executes the "forward" method with no inputs and decodes into `Output` generic type.
172+
///
173+
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
174+
/// - Throws: An error if loading, execution or result conversion fails.
175+
func forward<Output: ValueSequenceConstructible>() throws -> Output {
176+
try execute("forward")
177+
}
178+
}

extension/apple/ExecuTorch/__tests__/ModuleTest.swift

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,34 @@ class ModuleTest: XCTestCase {
8181
XCTAssertEqual(outputs4?.first?.tensor(), Tensor([Float(5)]))
8282
}
8383

84-
func testmethodMetadata() throws {
84+
func testForwardReturnConversion() throws {
85+
guard let modelPath = resourceBundle.path(forResource: "add", ofType: "pte") else {
86+
XCTFail("Couldn't find the model file")
87+
return
88+
}
89+
let module = Module(filePath: modelPath)
90+
let inputs: [Tensor<Float>] = [Tensor([1]), Tensor([1])]
91+
92+
let outputValues: [Value] = try module.forward(inputs)
93+
XCTAssertEqual(outputValues, [Value(Tensor<Float>([2]))])
94+
95+
let outputValue: Value = try module.forward(inputs)
96+
XCTAssertEqual(outputValue, Value(Tensor<Float>([2])))
97+
98+
let outputTensors: [Tensor<Float>] = try module.forward(inputs)
99+
XCTAssertEqual(outputTensors, [Tensor([2])])
100+
101+
let outputTensor: Tensor<Float> = try module.forward(Tensor<Float>([1]), Tensor<Float>([1]))
102+
XCTAssertEqual(outputTensor, Tensor([2]))
103+
104+
let scalars = (try module.forward(Tensor<Float>([1]), Tensor<Float>([1])) as Tensor<Float>).scalars()
105+
XCTAssertEqual(scalars, [2])
106+
107+
let scalars2 = try Tensor<Float>(module.forward(Tensor<Float>([1]), Tensor<Float>([1]))).scalars()
108+
XCTAssertEqual(scalars2, [2])
109+
}
110+
111+
func testMethodMetadata() throws {
85112
guard let modelPath = resourceBundle.path(forResource: "add", ofType: "pte") else {
86113
XCTFail("Couldn't find the model file")
87114
return

0 commit comments

Comments
 (0)