Skip to content

Commit f7d7d25

Browse files
authored
Add parallel lower fix (#287)
* Add parallel lower fix
* fixup
1 parent a8cbb3e commit f7d7d25

File tree

1 file changed

+119
-129
lines changed

1 file changed

+119
-129
lines changed

lib/polygeist/Passes/ParallelLower.cpp

Lines changed: 119 additions & 129 deletions
Original file line number | Diff line number | Diff line change
@@ -448,140 +448,130 @@ void ParallelLower::runOnOperation() {
448448
builder.eraseOp(launchOp);
449449
}
450450

451+
std::function<void(Operation * call, StringRef callee)> replace =
452+
[&](Operation *call, StringRef callee) {
453+
if (callee == "cudaMemcpy" || callee == "cudaMemcpyAsync") {
454+
OpBuilder bz(call);
455+
auto falsev = bz.create<ConstantIntOp>(call->getLoc(), false, 1);
456+
auto dst = call->getOperand(0);
457+
if (auto mt = dst.getType().dyn_cast<MemRefType>()) {
458+
dst = bz.create<polygeist::Memref2PointerOp>(
459+
call->getLoc(),
460+
LLVM::LLVMPointerType::get(mt.getElementType(),
461+
mt.getMemorySpaceAsInt()),
462+
dst);
463+
}
464+
auto src = call->getOperand(1);
465+
if (auto mt = src.getType().dyn_cast<MemRefType>()) {
466+
src = bz.create<polygeist::Memref2PointerOp>(
467+
call->getLoc(),
468+
LLVM::LLVMPointerType::get(mt.getElementType(),
469+
mt.getMemorySpaceAsInt()),
470+
src);
471+
}
472+
bz.create<LLVM::MemcpyOp>(call->getLoc(), dst, src,
473+
call->getOperand(2),
474+
/*isVolatile*/ falsev);
475+
call->replaceAllUsesWith(bz.create<ConstantIntOp>(
476+
call->getLoc(), 0, call->getResult(0).getType()));
477+
call->erase();
478+
} else if (callee == "cudaMemcpyToSymbol") {
479+
OpBuilder bz(call);
480+
auto falsev = bz.create<ConstantIntOp>(call->getLoc(), false, 1);
481+
auto dst = call->getOperand(0);
482+
if (auto mt = dst.getType().dyn_cast<MemRefType>()) {
483+
dst = bz.create<polygeist::Memref2PointerOp>(
484+
call->getLoc(),
485+
LLVM::LLVMPointerType::get(mt.getElementType(),
486+
mt.getMemorySpaceAsInt()),
487+
dst);
488+
}
489+
auto src = call->getOperand(1);
490+
if (auto mt = src.getType().dyn_cast<MemRefType>()) {
491+
src = bz.create<polygeist::Memref2PointerOp>(
492+
call->getLoc(),
493+
LLVM::LLVMPointerType::get(mt.getElementType(),
494+
mt.getMemorySpaceAsInt()),
495+
src);
496+
}
497+
bz.create<LLVM::MemcpyOp>(
498+
call->getLoc(),
499+
bz.create<LLVM::GEPOp>(call->getLoc(), dst.getType(), dst,
500+
std::vector<Value>({call->getOperand(3)})),
501+
src, call->getOperand(2),
502+
/*isVolatile*/ falsev);
503+
call->replaceAllUsesWith(bz.create<ConstantIntOp>(
504+
call->getLoc(), 0, call->getResult(0).getType()));
505+
call->erase();
506+
} else if (callee == "cudaMemset") {
507+
OpBuilder bz(call);
508+
auto falsev = bz.create<ConstantIntOp>(call->getLoc(), false, 1);
509+
bz.create<LLVM::MemsetOp>(call->getLoc(), call->getOperand(0),
510+
bz.create<TruncIOp>(call->getLoc(),
511+
bz.getI8Type(),
512+
call->getOperand(1)),
513+
call->getOperand(2),
514+
/*isVolatile*/ falsev);
515+
call->replaceAllUsesWith(bz.create<ConstantIntOp>(
516+
call->getLoc(), 0, call->getResult(0).getType()));
517+
call->erase();
518+
} else if (callee == "cudaMalloc" || callee == "cudaMallocHost") {
519+
OpBuilder bz(call);
520+
Value arg = call->getOperand(1);
521+
if (arg.getType().cast<IntegerType>().getWidth() < 64)
522+
arg =
523+
bz.create<arith::ExtUIOp>(call->getLoc(), bz.getI64Type(), arg);
524+
mlir::Value alloc =
525+
callMalloc(bz, getOperation(), call->getLoc(), arg);
526+
bz.create<LLVM::StoreOp>(call->getLoc(), alloc, call->getOperand(0));
527+
{
528+
auto retv = bz.create<ConstantIntOp>(
529+
call->getLoc(), 0,
530+
call->getResult(0).getType().cast<IntegerType>().getWidth());
531+
Value vals[] = {retv};
532+
call->replaceAllUsesWith(ArrayRef<Value>(vals));
533+
call->erase();
534+
}
535+
} else if (callee == "cudaFree" || callee == "cudaFreeHost") {
536+
auto mf = GetOrCreateFreeFunction(getOperation());
537+
OpBuilder bz(call);
538+
Value args[] = {call->getOperand(0)};
539+
bz.create<mlir::LLVM::CallOp>(call->getLoc(), mf, args);
540+
{
541+
auto retv = bz.create<ConstantIntOp>(
542+
call->getLoc(), 0,
543+
call->getResult(0).getType().cast<IntegerType>().getWidth());
544+
Value vals[] = {retv};
545+
call->replaceAllUsesWith(ArrayRef<Value>(vals));
546+
call->erase();
547+
}
548+
} else if (callee == "cudaDeviceSynchronize") {
549+
OpBuilder bz(call);
550+
auto retv = bz.create<ConstantIntOp>(
551+
call->getLoc(), 0,
552+
call->getResult(0).getType().cast<IntegerType>().getWidth());
553+
Value vals[] = {retv};
554+
call->replaceAllUsesWith(ArrayRef<Value>(vals));
555+
call->erase();
556+
} else if (callee == "cudaGetLastError") {
557+
OpBuilder bz(call);
558+
auto retv = bz.create<ConstantIntOp>(
559+
call->getLoc(), 0,
560+
call->getResult(0).getType().cast<IntegerType>().getWidth());
561+
Value vals[] = {retv};
562+
call->replaceAllUsesWith(ArrayRef<Value>(vals));
563+
call->erase();
564+
}
565+
};
566+
451567
getOperation().walk([&](LLVM::CallOp call) {
452568
if (!call.getCallee())
453569
return;
454-
if (call.getCallee().value() == "cudaMemcpy" ||
455-
call.getCallee().value() == "cudaMemcpyAsync") {
456-
OpBuilder bz(call);
457-
auto falsev = bz.create<ConstantIntOp>(call.getLoc(), false, 1);
458-
bz.create<LLVM::MemcpyOp>(call.getLoc(), call.getOperand(0),
459-
call.getOperand(1), call.getOperand(2),
460-
/*isVolatile*/ falsev);
461-
call.replaceAllUsesWith(
462-
bz.create<ConstantIntOp>(call.getLoc(), 0, call.getType(0)));
463-
call.erase();
464-
} else if (call.getCallee().value() == "cudaMemcpyToSymbol") {
465-
OpBuilder bz(call);
466-
auto falsev = bz.create<ConstantIntOp>(call.getLoc(), false, 1);
467-
bz.create<LLVM::MemcpyOp>(
468-
call.getLoc(),
469-
bz.create<LLVM::GEPOp>(call.getLoc(), call.getOperand(0).getType(),
470-
call.getOperand(0),
471-
std::vector<Value>({call.getOperand(3)})),
472-
call.getOperand(1), call.getOperand(2),
473-
/*isVolatile*/ falsev);
474-
call.replaceAllUsesWith(
475-
bz.create<ConstantIntOp>(call.getLoc(), 0, call.getType(0)));
476-
call.erase();
477-
} else if (call.getCallee().value() == "cudaMemset") {
478-
OpBuilder bz(call);
479-
auto falsev = bz.create<ConstantIntOp>(call.getLoc(), false, 1);
480-
bz.create<LLVM::MemsetOp>(call.getLoc(), call.getOperand(0),
481-
bz.create<TruncIOp>(call.getLoc(),
482-
bz.getI8Type(),
483-
call.getOperand(1)),
484-
call.getOperand(2),
485-
/*isVolatile*/ falsev);
486-
call.replaceAllUsesWith(
487-
bz.create<ConstantIntOp>(call.getLoc(), 0, call.getType(0)));
488-
call.erase();
489-
} else if (call.getCallee().value() == "cudaMalloc" ||
490-
call.getCallee().value() == "cudaMallocHost") {
491-
OpBuilder bz(call);
492-
Value arg = call.getOperand(1);
493-
if (arg.getType().cast<IntegerType>().getWidth() < 64)
494-
arg = bz.create<arith::ExtUIOp>(call.getLoc(), bz.getI64Type(), arg);
495-
mlir::Value alloc = callMalloc(bz, getOperation(), call.getLoc(), arg);
496-
bz.create<LLVM::StoreOp>(call.getLoc(), alloc, call.getOperand(0));
497-
{
498-
auto retv = bz.create<ConstantIntOp>(
499-
call.getLoc(), 0,
500-
call.getResult().getType().cast<IntegerType>().getWidth());
501-
Value vals[] = {retv};
502-
call.replaceAllUsesWith(ArrayRef<Value>(vals));
503-
call.erase();
504-
}
505-
} else if (call.getCallee().value() == "cudaFree" ||
506-
call.getCallee().value() == "cudaFreeHost") {
507-
auto mf = GetOrCreateFreeFunction(getOperation());
508-
OpBuilder bz(call);
509-
Value args[] = {call.getOperand(0)};
510-
bz.create<mlir::LLVM::CallOp>(call.getLoc(), mf, args);
511-
{
512-
auto retv = bz.create<ConstantIntOp>(
513-
call.getLoc(), 0,
514-
call.getResult().getType().cast<IntegerType>().getWidth());
515-
Value vals[] = {retv};
516-
call.replaceAllUsesWith(ArrayRef<Value>(vals));
517-
call.erase();
518-
}
519-
} else if (call.getCallee().value() == "cudaDeviceSynchronize") {
520-
OpBuilder bz(call);
521-
auto retv = bz.create<ConstantIntOp>(
522-
call.getLoc(), 0,
523-
call.getResult().getType().cast<IntegerType>().getWidth());
524-
Value vals[] = {retv};
525-
call.replaceAllUsesWith(ArrayRef<Value>(vals));
526-
call.erase();
527-
} else if (call.getCallee().value() == "cudaGetLastError") {
528-
OpBuilder bz(call);
529-
auto retv = bz.create<ConstantIntOp>(
530-
call.getLoc(), 0,
531-
call.getResult().getType().cast<IntegerType>().getWidth());
532-
Value vals[] = {retv};
533-
call.replaceAllUsesWith(ArrayRef<Value>(vals));
534-
call.erase();
535-
}
536-
});
537-
getOperation().walk([&](CallOp call) {
538-
if (call.getCallee() == "cudaDeviceSynchronize") {
539-
OpBuilder bz(call);
540-
auto retv = bz.create<ConstantIntOp>(
541-
call.getLoc(), 0,
542-
call.getResult(0).getType().cast<IntegerType>().getWidth());
543-
Value vals[] = {retv};
544-
call.replaceAllUsesWith(ArrayRef<Value>(vals));
545-
call.erase();
546-
} else if (call.getCallee() == "cudaMemcpyToSymbol") {
547-
OpBuilder bz(call);
548-
auto falsev = bz.create<ConstantIntOp>(call.getLoc(), false, 1);
549-
auto dst = call.getOperand(0);
550-
if (auto mt = dst.getType().cast<MemRefType>()) {
551-
dst = bz.create<polygeist::Memref2PointerOp>(
552-
call.getLoc(),
553-
LLVM::LLVMPointerType::get(mt.getElementType(),
554-
mt.getMemorySpaceAsInt()),
555-
dst);
556-
}
557-
auto src = call.getOperand(1);
558-
if (auto mt = src.getType().cast<MemRefType>()) {
559-
src = bz.create<polygeist::Memref2PointerOp>(
560-
call.getLoc(),
561-
LLVM::LLVMPointerType::get(mt.getElementType(),
562-
mt.getMemorySpaceAsInt()),
563-
src);
564-
}
565-
bz.create<LLVM::MemcpyOp>(
566-
call.getLoc(),
567-
bz.create<LLVM::GEPOp>(call.getLoc(), dst.getType(), dst,
568-
std::vector<Value>({call.getOperand(3)})),
569-
src, call.getOperand(2),
570-
/*isVolatile*/ falsev);
571-
call.replaceAllUsesWith(
572-
bz.create<ConstantIntOp>(call.getLoc(), 0, call.getType(0)));
573-
call.erase();
574-
} else if (call.getCallee() == "cudaGetLastError") {
575-
OpBuilder bz(call);
576-
auto retv = bz.create<ConstantIntOp>(
577-
call.getLoc(), 0,
578-
call.getResult(0).getType().cast<IntegerType>().getWidth());
579-
Value vals[] = {retv};
580-
call.replaceAllUsesWith(ArrayRef<Value>(vals));
581-
call.erase();
582-
}
570+
replace(call, *call.getCallee());
583571
});
584572

573+
getOperation().walk([&](CallOp call) { replace(call, call.getCallee()); });
574+
585575
// Fold the copy memtype cast
586576
{
587577
mlir::RewritePatternSet rpl(getOperation()->getContext());

0 commit comments

Comments
 (0)