@@ -179,22 +179,17 @@ def make_return_consumer(kernel_func):
179
179
# This function will be called KERNEL_NAME_capi_wrapper and will have a {llvm.emit_c_interface} attribute.
180
180
# Note there might be other such functions in the final module (gpu-lower-to-nvvm-pipeline somehow also inserts some like this).
181
181
def make_kernel_wrapper (kernel_func , return_consumer = None ):
182
- c_api_compatible_types = [
183
- T .memref (element_type = t .element_type ) if MemRefType .isinstance (t ) else t
184
- for t in kernel_func .function_type .value .results
185
- ]
186
-
187
182
input_types = kernel_func .function_type .value .inputs
188
183
189
184
@FuncOp .from_py_func (* input_types , name = f"{ kernel_func .name .value } _capi_wrapper" )
190
185
def wrapper (* args , ** _kwargs ):
191
186
results = CallOp (kernel_func , list (args )).results
192
- c_api_compatible_results = []
193
- for i , a in enumerate (results ):
194
- if MemRefType .isinstance (a .type ):
195
- a = cast (c_api_compatible_types [i ], a )
196
- c_api_compatible_results .append (a )
197
187
if return_consumer is not None :
188
+ c_api_compatible_results = []
189
+ for i , a in enumerate (results ):
190
+ if MemRefType .isinstance (a .type ):
191
+ a = cast (T .memref (element_type = a .type .element_type ), a )
192
+ c_api_compatible_results .append (a )
198
193
CallOp (return_consumer , c_api_compatible_results )
199
194
200
195
wrapper_func_op = wrapper .func_op
0 commit comments