Skip to content

Commit 448d4a5

Browse files
authored
double check two returns from refbackend (#135)
1 parent d209fe0 commit 448d4a5

File tree

2 files changed

+9
-12
lines changed

2 files changed

+9
-12
lines changed

mlir/extras/runtime/refbackend.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -179,22 +179,17 @@ def make_return_consumer(kernel_func):
179179
# This function will be called KERNEL_NAME_capi_wrapper and will have a {llvm.emit_c_interface} attribute.
180180
# Note there might be other such functions in the final module (gpu-lower-to-nvvm-pipeline somehow also inserts some like this).
181181
def make_kernel_wrapper(kernel_func, return_consumer=None):
182-
c_api_compatible_types = [
183-
T.memref(element_type=t.element_type) if MemRefType.isinstance(t) else t
184-
for t in kernel_func.function_type.value.results
185-
]
186-
187182
input_types = kernel_func.function_type.value.inputs
188183

189184
@FuncOp.from_py_func(*input_types, name=f"{kernel_func.name.value}_capi_wrapper")
190185
def wrapper(*args, **_kwargs):
191186
results = CallOp(kernel_func, list(args)).results
192-
c_api_compatible_results = []
193-
for i, a in enumerate(results):
194-
if MemRefType.isinstance(a.type):
195-
a = cast(c_api_compatible_types[i], a)
196-
c_api_compatible_results.append(a)
197187
if return_consumer is not None:
188+
c_api_compatible_results = []
189+
for i, a in enumerate(results):
190+
if MemRefType.isinstance(a.type):
191+
a = cast(T.memref(element_type=a.type.element_type), a)
192+
c_api_compatible_results.append(a)
198193
CallOp(return_consumer, c_api_compatible_results)
199194

200195
wrapper_func_op = wrapper.func_op

tests/test_runtime.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ def memfoo(mem: ranked_memref_kxk_f32):
505505
T.f32(), index_cast(T.i32(), i)
506506
)
507507
res = yield it_mem
508-
return res
508+
return res, res
509509

510510
memfoo.emit()
511511

@@ -520,8 +520,9 @@ def memfoo(mem: ranked_memref_kxk_f32):
520520
A = np.ones((K, K)).astype(np.float32)
521521
AA = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(A)))
522522

523-
results = invoker.memfoo_capi_wrapper(AA)
523+
results, results1 = invoker.memfoo_capi_wrapper(AA)
524524
assert np.array_equal(np.diagonal(results), np.arange(1, K + 1))
525+
assert np.array_equal(np.diagonal(results1), np.arange(1, K + 1))
525526

526527

527528
def test_setting_memref_diagonal_no_iter(ctx: MLIRContext, backend: LLVMJITBackend):
@@ -535,6 +536,7 @@ def memfoo(mem: ranked_memref_kxk_f32):
535536
mem[i, i] = mem[i, i] + mem[i, i] * sitofp(T.f32(), index_cast(T.i32(), i))
536537

537538
memfoo.emit()
539+
print(ctx.module)
538540

539541
module = backend.compile(
540542
ctx.module,

0 commit comments

Comments
 (0)