Skip to content

Commit 35cf8d1

Browse files
Add support for two return values
This commit adds support for two return values of type memref f32 and i64. Signed-Off By: Vivek Khandelwal <[email protected]>
1 parent ca662dc commit 35cf8d1

File tree

5 files changed

+51
-4
lines changed

5 files changed

+51
-4
lines changed

e2e_testing/torchscript/basic.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,3 +998,19 @@ def forward(self):
998998
def TensorOpaqueLiteralModule_basic(module, tu: TestUtils):
999999
module.forward()
10001000

1001+
class ReturnTwoTensorF32I64(torch.nn.Module):
1002+
def __init__(self):
1003+
super().__init__()
1004+
1005+
@export
1006+
@annotate_args([
1007+
None,
1008+
([-1, -1], torch.float32, True),
1009+
([-1, -1], torch.int64, True),
1010+
])
1011+
def forward(self, a, b):
1012+
return a, b
1013+
1014+
@register_test_case(module_factory=lambda: ReturnTwoTensorF32I64())
1015+
def ReturnTwoTensorF32I64_basic(module, tu: TestUtils):
1016+
module.forward(tu.rand(2, 3), torch.randint(5, (2, 3)))

e2e_testing/torchscript/xfail_sets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,5 @@
4949
"SqueezeDimModule_static",
5050
"SqueezeDimModule_identity",
5151
"SqueezeDimModule_unitDim",
52+
"ReturnTwoTensorF32I64_basic",
5253
}

lib/RefBackend/RefBackend.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,9 @@ static LogicalResult mungeFunction(
163163
auto supportedFuncsEnd = supportedConsumeFuncReturnFuncs.end();
164164
std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
165165
if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
166-
op.emitError(
167-
"must have one return value of memref types or scalar types "
168-
"of i32, i64, f32, f64, i1, or three return values of memref f32");
166+
op.emitError("must have one return value of memref types or scalar types "
167+
"of i32, i64, f32, f64, i1, or two return values of memref "
168+
"f32 and i64, or three return values of memref f32");
169169
isSupported = false;
170170
}
171171

@@ -194,7 +194,8 @@ static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
194194
Type f64 = b.getF64Type();
195195

196196
SmallVector<TypeRange> supportedReturnTypes = {
197-
mri1, mri32, mri64, mrf32, mrf64, i64, f32, f64, {mrf32, mrf32, mrf32}};
197+
mri1, mri32, mri64, mrf32, mrf64,
198+
i64, f32, f64, {mrf32, mri64}, {mrf32, mrf32, mrf32}};
198199

199200
llvm::for_each(supportedReturnTypes, [&](TypeRange &types) {
200201
funcNames.insert(getConsumeReturnFunctionNameForReturnTypes(types));

python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ def consume_return_f32(a):
6565
def consume_return_f64(a):
6666
self.result = a
6767

68+
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
69+
ctypes.POINTER(UnrankedMemRefDescriptor))
70+
def consume_return_mrf32_mri64(arg0, arg1):
71+
self.result = unranked_memref_to_numpy(
72+
arg0, np.float32), unranked_memref_to_numpy(
73+
arg1,
74+
np.int64)
75+
6876
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
6977
ctypes.POINTER(UnrankedMemRefDescriptor),
7078
ctypes.POINTER(UnrankedMemRefDescriptor))
@@ -98,6 +106,10 @@ def consume_return_mrf32_mrf32_mrf32(arg0, arg1, arg2):
98106
self.ee.register_runtime("refbackend_consume_func_return_f64",
99107
consume_return_f64)
100108

109+
self.ee.register_runtime(
110+
"refbackend_consume_func_return_mrf32_mri64",
111+
consume_return_mrf32_mri64)
112+
101113
self.ee.register_runtime(
102114
"refbackend_consume_func_return_mrf32_mrf32_mrf32",
103115
consume_return_mrf32_mrf32_mrf32)

test/RefBackend/munge-calling-conventions.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,20 @@ func @elemental_type(%arg0: memref<i64>) -> i64 {
5353
func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>, memref<?xf32>) {
5454
return %arg0 ,%arg1, %arg2 : memref<?xf32>, memref<?xf32>, memref<?xf32>
5555
}
56+
57+
// -----
58+
59+
// CHECK-LABEL: func @two_return_values(
60+
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>, %[[ARG1:.*]]: memref<*xi64>)
61+
// CHECK-SAME: attributes {llvm.emit_c_interface} {
62+
// CHECK: %[[VAL0:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
63+
// CHECK: %[[VAL1:.*]] = memref.cast %[[ARG1]] : memref<*xi64> to memref<?xi64>
64+
// CHECK: %[[RET0:.*]] = memref.cast %[[VAL0]] : memref<?xf32> to memref<*xf32>
65+
// CHECK: %[[RET1:.*]] = memref.cast %[[VAL1]] : memref<?xi64> to memref<*xi64>
66+
// CHECK: call @refbackend_consume_func_return_mrf32_mri64(%[[RET0]], %[[RET1]])
67+
// CHECK-SAME: : (memref<*xf32>, memref<*xi64>) -> ()
68+
// CHECK: return
69+
70+
func @two_return_values(%arg0: memref<?xf32>, %arg1: memref<?xi64>) -> (memref<?xf32>, memref<?xi64>) {
71+
return %arg0 ,%arg1 : memref<?xf32>, memref<?xi64>
72+
}

0 commit comments

Comments
 (0)