@@ -252,8 +252,18 @@ ValueCategory MLIRScanner::CallHelper(
252
252
auto l0 = Visit (CU->getConfig ()->getArg (0 ));
253
253
assert (l0.isReference );
254
254
mlir::Value blocks[3 ];
255
+ mlir::Value val = l0.val ;
256
+ if (auto MT = val.getType ().dyn_cast <MemRefType>()) {
257
+ if (MT.getElementType ().isa <LLVM::LLVMStructType>() &&
258
+ MT.getShape ().size () == 1 ) {
259
+ val = builder.create <polygeist::Memref2PointerOp>(
260
+ loc,
261
+ LLVM::LLVMPointerType::get (MT.getElementType (),
262
+ MT.getMemorySpaceAsInt ()),
263
+ val);
264
+ }
265
+ }
255
266
for (int i = 0 ; i < 3 ; i++) {
256
- mlir::Value val = l0.val ;
257
267
if (auto MT = val.getType ().dyn_cast <MemRefType>()) {
258
268
mlir::Value idx[] = {getConstantIndex (0 ), getConstantIndex (i)};
259
269
assert (MT.getShape ().size () == 2 );
@@ -278,8 +288,18 @@ ValueCategory MLIRScanner::CallHelper(
278
288
auto t0 = Visit (CU->getConfig ()->getArg (1 ));
279
289
assert (t0.isReference );
280
290
mlir::Value threads[3 ];
291
+ val = t0.val ;
292
+ if (auto MT = val.getType ().dyn_cast <MemRefType>()) {
293
+ if (MT.getElementType ().isa <LLVM::LLVMStructType>() &&
294
+ MT.getShape ().size () == 1 ) {
295
+ val = builder.create <polygeist::Memref2PointerOp>(
296
+ loc,
297
+ LLVM::LLVMPointerType::get (MT.getElementType (),
298
+ MT.getMemorySpaceAsInt ()),
299
+ val);
300
+ }
301
+ }
281
302
for (int i = 0 ; i < 3 ; i++) {
282
- mlir::Value val = t0.val ;
283
303
if (auto MT = val.getType ().dyn_cast <MemRefType>()) {
284
304
mlir::Value idx[] = {getConstantIndex (0 ), getConstantIndex (i)};
285
305
assert (MT.getShape ().size () == 2 );
0 commit comments