Skip to content

Commit 6c0ad40

Browse files
authored
Fix cuda memset lower (#289)
1 parent 859ed1f commit 6c0ad40

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

lib/polygeist/Passes/ParallelLower.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,15 @@ void ParallelLower::runOnOperation() {
506506
} else if (callee == "cudaMemset") {
507507
OpBuilder bz(call);
508508
auto falsev = bz.create<ConstantIntOp>(call->getLoc(), false, 1);
509-
bz.create<LLVM::MemsetOp>(call->getLoc(), call->getOperand(0),
509+
auto dst = call->getOperand(0);
510+
if (auto mt = dst.getType().dyn_cast<MemRefType>()) {
511+
dst = bz.create<polygeist::Memref2PointerOp>(
512+
call->getLoc(),
513+
LLVM::LLVMPointerType::get(mt.getElementType(),
514+
mt.getMemorySpaceAsInt()),
515+
dst);
516+
}
517+
bz.create<LLVM::MemsetOp>(call->getLoc(), dst,
510518
bz.create<TruncIOp>(call->getLoc(),
511519
bz.getI8Type(),
512520
call->getOperand(1)),

0 commit comments

Comments
 (0)