Skip to content

Commit c235e45

Browse files
authored
Misc LULESH fixes (#310)
* Misc LULESH fixes * Fix glob * fixes
1 parent 911b6a5 commit c235e45

File tree

15 files changed

+459
-115
lines changed

15 files changed

+459
-115
lines changed

lib/polygeist/Ops.cpp

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,14 +1257,17 @@ class Memref2Pointer2MemrefCast final
12571257
auto src = op.getSource().getDefiningOp<Memref2PointerOp>();
12581258
if (!src)
12591259
return failure();
1260-
if (src.getSource().getType().cast<MemRefType>().getShape().size() !=
1261-
op.getType().cast<MemRefType>().getShape().size())
1260+
auto smt = src.getSource().getType().cast<MemRefType>();
1261+
auto omt = op.getType().cast<MemRefType>();
1262+
if (smt.getShape().size() != omt.getShape().size())
12621263
return failure();
1263-
if (src.getSource().getType().cast<MemRefType>().getElementType() !=
1264-
op.getType().cast<MemRefType>().getElementType())
1264+
for (int i = 1; i < smt.getShape().size(); i++) {
1265+
if (smt.getShape()[i] != omt.getShape()[i])
1266+
return failure();
1267+
}
1268+
if (smt.getElementType() != omt.getElementType())
12651269
return failure();
1266-
if (src.getSource().getType().cast<MemRefType>().getMemorySpace() !=
1267-
op.getType().cast<MemRefType>().getMemorySpace())
1270+
if (smt.getMemorySpace() != omt.getMemorySpace())
12681271
return failure();
12691272

12701273
rewriter.replaceOpWithNewOp<memref::CastOp>(op, op.getType(),
@@ -1551,12 +1554,30 @@ OpFoldResult Memref2PointerOp::fold(ArrayRef<Attribute> operands) {
15511554
return nullptr;
15521555
}
15531556

1557+
/// Simplify memref2pointer(pointer2memref(x)) to cast(x)
1558+
class Memref2PointerBitCast final : public OpRewritePattern<LLVM::BitcastOp> {
1559+
public:
1560+
using OpRewritePattern<LLVM::BitcastOp>::OpRewritePattern;
1561+
1562+
LogicalResult matchAndRewrite(LLVM::BitcastOp op,
1563+
PatternRewriter &rewriter) const override {
1564+
auto src = op.getOperand().getDefiningOp<Memref2PointerOp>();
1565+
if (!src)
1566+
return failure();
1567+
1568+
rewriter.replaceOpWithNewOp<Memref2PointerOp>(op, op.getType(),
1569+
src.getOperand());
1570+
return success();
1571+
}
1572+
};
1573+
15541574
void Memref2PointerOp::getCanonicalizationPatterns(RewritePatternSet &results,
15551575
MLIRContext *context) {
1556-
results.insert<Memref2Pointer2MemrefCast, Memref2PointerIndex,
1557-
SetSimplification<LLVM::MemsetOp>,
1558-
CopySimplification<LLVM::MemcpyOp>,
1559-
CopySimplification<LLVM::MemmoveOp>>(context);
1576+
results.insert<
1577+
Memref2Pointer2MemrefCast, Memref2PointerIndex, Memref2PointerBitCast,
1578+
1579+
SetSimplification<LLVM::MemsetOp>, CopySimplification<LLVM::MemcpyOp>,
1580+
CopySimplification<LLVM::MemmoveOp>>(context);
15601581
}
15611582

15621583
/// Simplify cast(pointer2memref(x)) to pointer2memref(x)

lib/polygeist/Passes/Mem2Reg.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,6 +1087,12 @@ bool Mem2Reg::forwardStoreToLoad(
10871087
mlir::Location loc = AI.getLoc();
10881088
std::set<mlir::Operation *> allStoreOps;
10891089

1090+
Type elType;
1091+
if (auto MT = AI.getType().dyn_cast<MemRefType>())
1092+
elType = MT.getElementType();
1093+
else
1094+
elType = AI.getType().cast<LLVM::LLVMPointerType>().getElementType();
1095+
10901096
std::deque<std::pair<mlir::Value, /*indexed*/ bool>> list = {{AI, false}};
10911097

10921098
SmallPtrSet<Operation *, 4> AliasingStoreOperations;
@@ -1147,16 +1153,20 @@ bool Mem2Reg::forwardStoreToLoad(
11471153
if (!modified &&
11481154
matchesIndices(loadOp.getIndices(), idx) == Match::Exact) {
11491155
subType = loadOp.getType();
1150-
loadOps.insert(loadOp);
1151-
LLVM_DEBUG(llvm::dbgs() << "Matching Load: " << loadOp << "\n");
1156+
if (subType == elType) {
1157+
loadOps.insert(loadOp);
1158+
LLVM_DEBUG(llvm::dbgs() << "Matching Load: " << loadOp << "\n");
1159+
}
11521160
}
11531161
continue;
11541162
}
11551163
if (auto loadOp = dyn_cast<mlir::LLVM::LoadOp>(user)) {
11561164
if (!modified) {
11571165
subType = loadOp.getType();
1158-
loadOps.insert(loadOp);
1159-
LLVM_DEBUG(llvm::dbgs() << "Matching Load: " << loadOp << "\n");
1166+
if (subType == elType) {
1167+
loadOps.insert(loadOp);
1168+
LLVM_DEBUG(llvm::dbgs() << "Matching Load: " << loadOp << "\n");
1169+
}
11601170
}
11611171
continue;
11621172
}
@@ -1165,8 +1175,10 @@ bool Mem2Reg::forwardStoreToLoad(
11651175
matchesIndices(loadOp.getAffineMapAttr().getValue(),
11661176
loadOp.getMapOperands(), idx) == Match::Exact) {
11671177
subType = loadOp.getType();
1168-
loadOps.insert(loadOp);
1169-
LLVM_DEBUG(llvm::dbgs() << "Matching Load: " << loadOp << "\n");
1178+
if (subType == elType) {
1179+
loadOps.insert(loadOp);
1180+
LLVM_DEBUG(llvm::dbgs() << "Matching Load: " << loadOp << "\n");
1181+
}
11701182
}
11711183
continue;
11721184
}
@@ -1384,12 +1396,6 @@ bool Mem2Reg::forwardStoreToLoad(
13841396
}
13851397
}
13861398

1387-
Type elType;
1388-
if (auto MT = AI.getType().dyn_cast<MemRefType>())
1389-
elType = MT.getElementType();
1390-
else
1391-
elType = AI.getType().cast<LLVM::LLVMPointerType>().getElementType();
1392-
13931399
ReplacementHandler metaMap(elType);
13941400

13951401
// Last value stored in an individual block and the operation which stored it

lib/polygeist/Passes/ParallelLower.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,8 @@ void ParallelLower::runOnOperation() {
561561
Value vals[] = {retv};
562562
call->replaceAllUsesWith(ArrayRef<Value>(vals));
563563
call->erase();
564-
} else if (callee == "cudaGetLastError") {
564+
} else if (callee == "cudaGetLastError" ||
565+
callee == "cudaPeekAtLastError") {
565566
OpBuilder bz(call);
566567
auto retv = bz.create<ConstantIntOp>(
567568
call->getLoc(), 0,

tools/cgeist/Lib/CGCall.cc

Lines changed: 173 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
534534
}
535535
}
536536

537-
auto getLLVM = [&](Expr *E) -> mlir::Value {
537+
auto getLLVM = [&](Expr *E, bool isRef = false) -> mlir::Value {
538538
auto sub = Visit(E);
539539
if (!sub.val) {
540540
expr->dump();
@@ -564,23 +564,46 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
564564
auto shape = std::vector<int64_t>(mt.getShape());
565565
assert(shape.size() == 2);
566566

567-
OpBuilder abuilder(builder.getContext());
568-
abuilder.setInsertionPointToStart(allocationScope);
569-
auto one = abuilder.create<ConstantIntOp>(loc, 1, 64);
570-
auto alloc = abuilder.create<mlir::LLVM::AllocaOp>(
571-
loc,
567+
auto PT =
572568
LLVM::LLVMPointerType::get(Glob.typeTranslator.translateType(
573569
anonymize(getLLVMType(E->getType()))),
574-
0),
575-
one, 0);
576-
ValueCategory(alloc, /*isRef*/ true)
577-
.store(loc, builder, sub, /*isArray*/ isArray);
578-
sub = ValueCategory(alloc, /*isRef*/ true);
570+
0);
571+
if (true) {
572+
sub = ValueCategory(
573+
builder.create<polygeist::Memref2PointerOp>(loc, PT, sub.val),
574+
sub.isReference);
575+
} else {
576+
OpBuilder abuilder(builder.getContext());
577+
abuilder.setInsertionPointToStart(allocationScope);
578+
auto one = abuilder.create<ConstantIntOp>(loc, 1, 64);
579+
auto alloc = abuilder.create<mlir::LLVM::AllocaOp>(loc, PT, one, 0);
580+
ValueCategory(alloc, /*isRef*/ true)
581+
.store(loc, builder, sub, /*isArray*/ isArray);
582+
sub = ValueCategory(alloc, /*isRef*/ true);
583+
}
584+
}
585+
mlir::Value val;
586+
clang::QualType ct;
587+
if (!isRef) {
588+
val = sub.getValue(loc, builder);
589+
ct = E->getType();
590+
} else {
591+
if (!sub.isReference) {
592+
OpBuilder abuilder(builder.getContext());
593+
abuilder.setInsertionPointToStart(allocationScope);
594+
auto one = abuilder.create<ConstantIntOp>(loc, 1, 64);
595+
auto alloc = abuilder.create<mlir::LLVM::AllocaOp>(
596+
loc, LLVM::LLVMPointerType::get(sub.val.getType()), one, 0);
597+
ValueCategory(alloc, /*isRef*/ true)
598+
.store(loc, builder, sub, /*isArray*/ isArray);
599+
sub = ValueCategory(alloc, /*isRef*/ true);
600+
}
601+
assert(sub.isReference);
602+
val = sub.val;
603+
ct = Glob.CGM.getContext().getLValueReferenceType(E->getType());
579604
}
580-
auto val = sub.getValue(loc, builder);
581605
if (auto mt = val.getType().dyn_cast<MemRefType>()) {
582-
auto nt = Glob.typeTranslator
583-
.translateType(anonymize(getLLVMType(E->getType())))
606+
auto nt = Glob.typeTranslator.translateType(anonymize(getLLVMType(ct)))
584607
.cast<LLVM::LLVMPointerType>();
585608
val = builder.create<polygeist::Memref2PointerOp>(loc, nt, val);
586609
}
@@ -1483,7 +1506,7 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
14831506

14841507
std::vector<mlir::Value> args;
14851508
for (auto *a : expr->arguments()) {
1486-
args.push_back(getLLVM(a));
1509+
args.push_back(getLLVM(a, /*isRef*/ false));
14871510
}
14881511
mlir::Value called;
14891512

@@ -1492,7 +1515,8 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
14921515
called = builder.create<mlir::LLVM::CallOp>(loc, strcmpF, args)
14931516
.getResult();
14941517
} else {
1495-
args.insert(args.begin(), getLLVM(expr->getCallee()));
1518+
args.insert(args.begin(),
1519+
getLLVM(expr->getCallee(), /*isRef*/ false));
14961520
SmallVector<mlir::Type> RTs = {Glob.typeTranslator.translateType(
14971521
anonymize(getLLVMType(expr->getType())))};
14981522
if (RTs[0].isa<LLVM::LLVMVoidType>())
@@ -1509,31 +1533,154 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
15091533
if (!callee || callee->isVariadic()) {
15101534
bool isReference = expr->isLValue() || expr->isXValue();
15111535
std::vector<mlir::Value> args;
1512-
for (auto *a : expr->arguments()) {
1513-
args.push_back(getLLVM(a));
1514-
}
15151536
mlir::Value called;
15161537
if (callee) {
15171538
auto strcmpF = Glob.GetOrCreateLLVMFunction(callee);
1539+
std::vector<clang::QualType> types;
1540+
if (auto CC = dyn_cast<CXXMethodDecl>(callee)) {
1541+
types.push_back(CC->getThisType());
1542+
}
1543+
for (auto parm : callee->parameters()) {
1544+
types.push_back(parm->getOriginalType());
1545+
}
1546+
int i = 0;
1547+
for (auto *a : expr->arguments()) {
1548+
bool isRef = false;
1549+
if (i < types.size())
1550+
isRef = types[i]->isReferenceType();
1551+
i++;
1552+
args.push_back(getLLVM(a, isRef));
1553+
}
15181554
called =
15191555
builder.create<mlir::LLVM::CallOp>(loc, strcmpF, args).getResult();
15201556
} else {
1521-
args.insert(args.begin(), getLLVM(expr->getCallee()));
1557+
mlir::Value fn = Visit(expr->getCallee()).getValue(loc, builder);
1558+
if (auto MT = fn.getType().dyn_cast<MemRefType>()) {
1559+
fn = builder.create<polygeist::Memref2PointerOp>(
1560+
loc, LLVM::LLVMPointerType::get(MT.getElementType(), 0), fn);
1561+
}
1562+
auto PTF = fn.getType()
1563+
.cast<LLVM::LLVMPointerType>()
1564+
.getElementType()
1565+
.cast<LLVM::LLVMFunctionType>();
1566+
SmallVector<mlir::Type, 1> argtys;
1567+
bool needsChange = false;
1568+
for (auto FT : PTF.getParams()) {
1569+
if (auto mt = FT.dyn_cast<MemRefType>()) {
1570+
argtys.push_back(LLVM::LLVMPointerType::get(mt.getElementType(), 0));
1571+
needsChange = true;
1572+
} else
1573+
argtys.push_back(FT);
1574+
}
1575+
auto rt = PTF.getReturnType();
1576+
if (auto mt = rt.dyn_cast<MemRefType>()) {
1577+
rt = LLVM::LLVMPointerType::get(mt.getElementType(), 0);
1578+
needsChange = true;
1579+
}
1580+
if (needsChange)
1581+
fn = builder.create<LLVM::BitcastOp>(
1582+
loc,
1583+
LLVM::LLVMPointerType::get(
1584+
LLVM::LLVMFunctionType::get(rt, argtys, PTF.isVarArg()), 0),
1585+
fn);
1586+
1587+
args.push_back(fn);
15221588
auto CT = expr->getType();
1523-
if (isReference)
1524-
CT = Glob.CGM.getContext().getLValueReferenceType(CT);
1525-
SmallVector<mlir::Type> RTs = {
1526-
Glob.typeTranslator.translateType(anonymize(getLLVMType(CT)))};
1589+
// if (isReference)
1590+
// CT = Glob.CGM.getContext().getLValueReferenceType(CT);
1591+
SmallVector<mlir::Type> RTs = {rt};
1592+
// getMLIRType(CT)};
15271593

15281594
auto ft = args[0]
15291595
.getType()
15301596
.cast<LLVM::LLVMPointerType>()
15311597
.getElementType()
15321598
.cast<LLVM::LLVMFunctionType>();
1533-
assert(RTs[0] == ft.getReturnType());
1534-
if (RTs[0].isa<LLVM::LLVMVoidType>())
1599+
auto ETy = expr->getCallee()->getType()->getUnqualifiedDesugaredType();
1600+
ETy = cast<clang::PointerType>(ETy)
1601+
->getPointeeType()
1602+
->getUnqualifiedDesugaredType();
1603+
auto CFT = dyn_cast<clang::FunctionProtoType>(ETy);
1604+
std::vector<clang::QualType> types;
1605+
if (CFT) {
1606+
for (auto t : CFT->getParamTypes())
1607+
types.push_back(t);
1608+
} else {
1609+
assert(isa<clang::FunctionNoProtoType>(ETy));
1610+
}
1611+
1612+
auto ETy2 = ETy->getCanonicalTypeUnqualified();
1613+
1614+
const clang::CodeGen::CGFunctionInfo *FI;
1615+
if (const FunctionProtoType *FPT = dyn_cast<FunctionProtoType>(ETy2)) {
1616+
FI = &Glob.CGM.getTypes().arrangeFreeFunctionType(
1617+
CanQual<FunctionProtoType>::CreateUnsafe(QualType(FPT, 0)));
1618+
} else {
1619+
const FunctionNoProtoType *FNPT = cast<FunctionNoProtoType>(ETy2);
1620+
FI = &Glob.CGM.getTypes().arrangeFreeFunctionType(
1621+
CanQual<FunctionNoProtoType>::CreateUnsafe(QualType(FNPT, 0)));
1622+
}
1623+
1624+
int i = 0;
1625+
for (auto *a : expr->arguments()) {
1626+
bool isRef = false;
1627+
bool isArray = false;
1628+
if (i < types.size()) {
1629+
isRef = types[i]->isReferenceType();
1630+
// auto inf = FI->arguments()[i].info;
1631+
// isRef |= inf.isIndirect();
1632+
Glob.getMLIRType(types[i], &isArray);
1633+
isRef |= isArray;
1634+
}
1635+
1636+
auto sub = Visit(a);
1637+
mlir::Value v;
1638+
if (isRef) {
1639+
if (!sub.isReference) {
1640+
OpBuilder abuilder(builder.getContext());
1641+
abuilder.setInsertionPointToStart(allocationScope);
1642+
auto one = abuilder.create<ConstantIntOp>(loc, 1, 64);
1643+
auto alloc = abuilder.create<mlir::LLVM::AllocaOp>(
1644+
loc, LLVM::LLVMPointerType::get(sub.val.getType()), one, 0);
1645+
ValueCategory(alloc, /*isRef*/ true)
1646+
.store(loc, builder, sub, /*isArray*/ false);
1647+
sub = ValueCategory(alloc, /*isRef*/ true);
1648+
}
1649+
assert(sub.isReference);
1650+
v = sub.val;
1651+
} else {
1652+
v = sub.getValue(loc, builder);
1653+
}
1654+
if (i < FI->arg_size()) {
1655+
// TODO expand full calling conv
1656+
/*
1657+
auto inf = FI->arguments()[i].info;
1658+
if (inf.isIgnore() || inf.isInAlloca()) {
1659+
i++;
1660+
continue;
1661+
}
1662+
if (inf.isExpand()) {
1663+
i++;
1664+
continue;
1665+
}
1666+
*/
1667+
}
1668+
i++;
1669+
if (auto mt = v.getType().dyn_cast<MemRefType>()) {
1670+
v = builder.create<polygeist::Memref2PointerOp>(
1671+
loc, LLVM::LLVMPointerType::get(mt.getElementType(), 0), v);
1672+
}
1673+
args.push_back(v);
1674+
}
1675+
if (RTs[0].isa<mlir::NoneType>() || RTs[0].isa<LLVM::LLVMVoidType>())
15351676
RTs.clear();
1677+
else
1678+
assert(RTs[0] == ft.getReturnType());
15361679
called = builder.create<mlir::LLVM::CallOp>(loc, RTs, args).getResult();
1680+
if (PTF.getReturnType() != ft.getReturnType()) {
1681+
called = builder.create<polygeist::Pointer2MemrefOp>(
1682+
loc, PTF.getReturnType(), called);
1683+
}
15371684
}
15381685
if (isReference) {
15391686
if (!(called.getType().isa<LLVM::LLVMPointerType>() ||

tools/cgeist/Lib/CGStmt.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1019,7 +1019,8 @@ ValueCategory MLIRScanner::VisitDeclStmt(clang::DeclStmt *decl) {
10191019
if (auto *vd = dyn_cast<VarDecl>(sub)) {
10201020
VisitVarDecl(vd);
10211021
} else if (isa<TypeAliasDecl, RecordDecl, StaticAssertDecl, TypedefDecl,
1022-
UsingDecl, UsingDirectiveDecl>(sub)) {
1022+
UsingDecl, UsingDirectiveDecl, EnumConstantDecl, EnumDecl>(
1023+
sub)) {
10231024
} else {
10241025
emitError(getMLIRLocation(decl->getBeginLoc()))
10251026
<< " + visiting unknonwn sub decl stmt\n";

0 commit comments

Comments
 (0)