Skip to content

Commit 11ed9b3

Browse files
ftynse, wsmoses, ivanradanov
authored
Alternative lowering for memrefs (#276)
* WIP: alternative lowering for memrefs

  Introduce an alternative lowering scheme for memref types. An n-D memref is always lowered to a pointer to an (n-1)-D nested LLVM array type; 0-D and 1-D memrefs are lowered to a bare pointer. This applies both at function boundaries and within functions, unlike MLIR's bare-pointer calling convention.

  This lowering is only possible for memrefs with static shapes except for the outermost dimension, which is consistent with the kinds of nested array structure C allows and therefore suitable for a C frontend. However, it supports only the subset of memref dialect operations that can be implemented using this data structure. For example, view, subview and transpose cannot be implemented, as they result in non-row-major layouts that cannot be expressed with C nested arrays.

  This simplifies the lowering and the generated IR for all supported memref operations and removes the need for ABI duplication at the function boundary: memrefs are lowered to the same LLVM IR as equivalent nested arrays.

* Enable c memref
* Fix lowering c abi
* Fix compile issues
* Fix format
* Add unit tests and fix the bug discovered in the process

  nD dynamic allocation was not taking into account the static part of the shape, leading to a smaller allocation than required.

* Update tests

  Golden tests are simply rewritten with the new version. The main difference is the pervasiveness of `memref` in the function signatures and the corresponding memref-to-pointer conversions (or lack thereof) around function calls.

* Create a pointer2memref op if we cast between different types
* Do not use argv special case
* Fix bug with multiple definitions of malloc
* subindexlowering: bitcast after gepop
* Update test
* Bugfixes
* Cleanup
* Fix LLVM bump
* Fix tests
* Generalize tests

Co-authored-by: William S. Moses <[email protected]>
Co-authored-by: Ivan Radanov Ivanov <[email protected]>
1 parent ec6dc13 commit 11ed9b3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+1984
-675
lines changed

include/polygeist/BarrierUtils.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#include "mlir/IR/Block.h"
1818
#include "polygeist/Ops.h"
1919
#include "llvm/ADT/SetVector.h"
20-
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
20+
#include <mlir/Dialect/Arith/IR/Arith.h>
2121

2222
std::pair<mlir::Block *, mlir::Block::iterator>
2323
findInsertionPointAfterLoopOperands(mlir::scf::ParallelOp op);
@@ -41,14 +41,15 @@ emitIterationCounts(mlir::OpBuilder &rewriter, mlir::scf::ParallelOp op) {
4141
return iterationCounts;
4242
}
4343

44-
mlir::LLVM::LLVMFuncOp GetOrCreateMallocFunction(mlir::ModuleOp module);
44+
mlir::Value callMalloc(mlir::OpBuilder &builder, mlir::ModuleOp module,
45+
mlir::Location loc, mlir::Value arg);
4546
mlir::LLVM::LLVMFuncOp GetOrCreateFreeFunction(mlir::ModuleOp module);
4647

4748
template <typename T>
48-
static T allocateTemporaryBuffer(mlir::OpBuilder &rewriter, mlir::Value value,
49-
mlir::ValueRange iterationCounts,
50-
bool alloca = true,
51-
mlir::DataLayout *DLI = nullptr) {
49+
static mlir::Value
50+
allocateTemporaryBuffer(mlir::OpBuilder &rewriter, mlir::Value value,
51+
mlir::ValueRange iterationCounts, bool alloca = true,
52+
mlir::DataLayout *DLI = nullptr) {
5253
using namespace mlir;
5354
SmallVector<int64_t> bufferSize(iterationCounts.size(),
5455
ShapedType::kDynamicSize);
@@ -75,7 +76,7 @@ static T allocateTemporaryBuffer(mlir::OpBuilder &rewriter, mlir::Value value,
7576
}
7677

7778
template <>
78-
mlir::LLVM::AllocaOp allocateTemporaryBuffer<mlir::LLVM::AllocaOp>(
79+
mlir::Value allocateTemporaryBuffer<mlir::LLVM::AllocaOp>(
7980
mlir::OpBuilder &rewriter, mlir::Value value,
8081
mlir::ValueRange iterationCounts, bool alloca, mlir::DataLayout *DLI) {
8182
using namespace mlir;
@@ -92,7 +93,7 @@ mlir::LLVM::AllocaOp allocateTemporaryBuffer<mlir::LLVM::AllocaOp>(
9293
}
9394

9495
template <>
95-
mlir::LLVM::CallOp allocateTemporaryBuffer<mlir::LLVM::CallOp>(
96+
mlir::Value allocateTemporaryBuffer<mlir::LLVM::CallOp>(
9697
mlir::OpBuilder &rewriter, mlir::Value value,
9798
mlir::ValueRange iterationCounts, bool alloca, mlir::DataLayout *DLI) {
9899
using namespace mlir;
@@ -113,7 +114,7 @@ mlir::LLVM::CallOp allocateTemporaryBuffer<mlir::LLVM::CallOp>(
113114
value.getLoc(), sz.getType(), iter));
114115
}
115116
auto m = val->getParentOfType<ModuleOp>();
116-
auto allocfn = GetOrCreateMallocFunction(m);
117-
return rewriter.create<LLVM::CallOp>(value.getLoc(), allocfn, sz);
117+
return callMalloc(rewriter, m, value.getLoc(), sz);
118118
}
119+
119120
#endif // MLIR_LIB_DIALECT_SCF_TRANSFORMS_BARRIERUTILS_H_

include/polygeist/Passes/Passes.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ std::unique_ptr<Pass> detectReductionPass();
2424
std::unique_ptr<Pass> createRemoveTrivialUsePass();
2525
std::unique_ptr<Pass> createParallelLowerPass();
2626
std::unique_ptr<Pass>
27-
createConvertPolygeistToLLVMPass(const LowerToLLVMOptions &options);
27+
createConvertPolygeistToLLVMPass(const LowerToLLVMOptions &options,
28+
bool useCStyleMemRef);
2829
std::unique_ptr<Pass> createConvertPolygeistToLLVMPass();
2930
std::unique_ptr<Pass> createForBreakToWhilePass();
3031

@@ -43,7 +44,7 @@ template <typename ConcreteDialect>
4344
void registerDialect(DialectRegistry &registry);
4445

4546
namespace arith {
46-
class ArithmeticDialect;
47+
class ArithDialect;
4748
} // end namespace arith
4849

4950
namespace scf {

include/polygeist/Passes/Passes.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def SCFCanonicalizeFor : Pass<"canonicalize-scf-for"> {
6868
def ForBreakToWhile : Pass<"for-break-to-while"> {
6969
let summary = "Rewrite scf.for(scf.if) to scf.while";
7070
let constructor = "mlir::polygeist::createForBreakToWhilePass()";
71-
let dependentDialects = ["arith::ArithmeticDialect"];
71+
let dependentDialects = ["arith::ArithDialect"];
7272
}
7373

7474
def ParallelLICM : Pass<"parallel-licm"> {
@@ -130,7 +130,11 @@ def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp">
130130
Option<"dataLayout", "data-layout", "std::string",
131131
/*default=*/"\"\"",
132132
"String description (LLVM format) of the data layout that is "
133-
"expected on the produced module">
133+
"expected on the produced module">,
134+
Option<"useCStyleMemRef", "use-c-style-memref", "bool",
135+
/*default=*/"true",
136+
"Use C-style nested-array lowering of memref instead of "
137+
"the default MLIR descriptor structure">
134138
];
135139
}
136140

lib/polygeist/Ops.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "polygeist/Ops.h"
10-
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10+
#include "mlir/Dialect/Arith/IR/Arith.h"
1111
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1212
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1313
#include "mlir/IR/AffineExpr.h"
@@ -20,7 +20,7 @@
2020
#include "polygeist/PolygeistOps.cpp.inc"
2121

2222
#include "mlir/Dialect/Affine/IR/AffineOps.h"
23-
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
23+
#include "mlir/Dialect/Arith/Utils/Utils.h"
2424
#include "mlir/Dialect/Func/IR/FuncOps.h"
2525
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2626
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -384,7 +384,7 @@ static bool mayAlias(Value v, Value v2) {
384384

385385
if (auto glob = v.getDefiningOp<memref::GetGlobalOp>()) {
386386
if (auto Aglob = v2.getDefiningOp<memref::GetGlobalOp>()) {
387-
return glob.name() == Aglob.name();
387+
return glob.getName() == Aglob.getName();
388388
}
389389
}
390390

@@ -644,7 +644,7 @@ struct SimplifySubIndexUsers : public OpRewritePattern<SubIndexOp> {
644644
subindex.getSource());
645645
} else if (auto loadOp = dyn_cast<memref::LoadOp>(use.getOwner())) {
646646
if (loadOp.getMemref() == subindex) {
647-
SmallVector<Value, 4> indices = loadOp.indices();
647+
SmallVector<Value, 4> indices = loadOp.getIndices();
648648
if (subindex.getType().cast<MemRefType>().getShape().size() ==
649649
subindex.getSource()
650650
.getType()
@@ -676,7 +676,7 @@ struct SimplifySubIndexUsers : public OpRewritePattern<SubIndexOp> {
676676
}
677677
} else if (auto storeOp = dyn_cast<memref::StoreOp>(use.getOwner())) {
678678
if (storeOp.getMemref() == subindex) {
679-
SmallVector<Value, 4> indices = storeOp.indices();
679+
SmallVector<Value, 4> indices = storeOp.getIndices();
680680
if (subindex.getType().cast<MemRefType>().getShape().size() ==
681681
subindex.getSource()
682682
.getType()
@@ -707,7 +707,7 @@ struct SimplifySubIndexUsers : public OpRewritePattern<SubIndexOp> {
707707
}
708708
} else if (auto storeOp = dyn_cast<memref::AtomicRMWOp>(use.getOwner())) {
709709
if (storeOp.getMemref() == subindex) {
710-
SmallVector<Value, 4> indices = storeOp.indices();
710+
SmallVector<Value, 4> indices = storeOp.getIndices();
711711
if (subindex.getType().cast<MemRefType>().getShape().size() ==
712712
subindex.getSource()
713713
.getType()
@@ -733,7 +733,7 @@ struct SimplifySubIndexUsers : public OpRewritePattern<SubIndexOp> {
733733
.getShape()
734734
.size() == indices.size());
735735
rewriter.replaceOpWithNewOp<memref::AtomicRMWOp>(
736-
storeOp, storeOp.getType(), storeOp.kind(), storeOp.getValue(),
736+
storeOp, storeOp.getType(), storeOp.getKind(), storeOp.getValue(),
737737
subindex.getSource(), indices);
738738
changed = true;
739739
}
@@ -843,7 +843,7 @@ struct SimplifySubViewUsers : public OpRewritePattern<memref::SubViewOp> {
843843
subindex.getSource());
844844
} else if (auto loadOp = dyn_cast<memref::LoadOp>(use.getOwner())) {
845845
if (loadOp.getMemref() == subindex) {
846-
SmallVector<Value, 4> indices = loadOp.indices();
846+
SmallVector<Value, 4> indices = loadOp.getIndices();
847847
if (subindex.getType().cast<MemRefType>().getShape().size() ==
848848
subindex.getSource()
849849
.getType()
@@ -878,7 +878,7 @@ struct SimplifySubViewUsers : public OpRewritePattern<memref::SubViewOp> {
878878
}
879879
} else if (auto storeOp = dyn_cast<memref::StoreOp>(use.getOwner())) {
880880
if (storeOp.getMemref() == subindex) {
881-
SmallVector<Value, 4> indices = storeOp.indices();
881+
SmallVector<Value, 4> indices = storeOp.getIndices();
882882
if (subindex.getType().cast<MemRefType>().getShape().size() ==
883883
subindex.getSource()
884884
.getType()
@@ -1539,7 +1539,7 @@ class MetaPointer2Memref final : public OpRewritePattern<Op> {
15391539
template <>
15401540
Value MetaPointer2Memref<memref::LoadOp>::computeIndex(
15411541
memref::LoadOp op, size_t i, PatternRewriter &rewriter) const {
1542-
return op.indices()[i];
1542+
return op.getIndices()[i];
15431543
}
15441544

15451545
template <>
@@ -1551,7 +1551,7 @@ void MetaPointer2Memref<memref::LoadOp>::rewrite(
15511551
template <>
15521552
Value MetaPointer2Memref<memref::StoreOp>::computeIndex(
15531553
memref::StoreOp op, size_t i, PatternRewriter &rewriter) const {
1554-
return op.indices()[i];
1554+
return op.getIndices()[i];
15551555
}
15561556

15571557
template <>
@@ -2551,15 +2551,15 @@ struct ConstantRankReduction : public OpRewritePattern<memref::AllocaOp> {
25512551
for (auto u : op->getResult(0).getUsers()) {
25522552
if (auto load = dyn_cast<memref::LoadOp>(u)) {
25532553
if (!set) {
2554-
for (auto i : load.indices()) {
2554+
for (auto i : load.getIndices()) {
25552555
IntegerAttr constValue;
25562556
if (!matchPattern(i, m_Constant(&constValue)))
25572557
return failure();
25582558
v.push_back(constValue.getValue().getZExtValue());
25592559
}
25602560
set = true;
25612561
} else {
2562-
for (auto pair : llvm::zip(load.indices(), v)) {
2562+
for (auto pair : llvm::zip(load.getIndices(), v)) {
25632563
IntegerAttr constValue;
25642564
if (!matchPattern(std::get<0>(pair), m_Constant(&constValue)))
25652565
return failure();
@@ -2626,7 +2626,7 @@ struct ConstantRankReduction : public OpRewritePattern<memref::AllocaOp> {
26262626
}
26272627
if (auto store = dyn_cast<memref::StoreOp>(u)) {
26282628
Value cond = nullptr;
2629-
for (auto pair : llvm::zip(store.indices(), v)) {
2629+
for (auto pair : llvm::zip(store.getIndices(), v)) {
26302630
auto val = rewriter.create<arith::CmpIOp>(
26312631
store.getLoc(), CmpIPredicate::eq, std::get<0>(pair),
26322632
rewriter.create<arith::ConstantIndexOp>(store.getLoc(),

lib/polygeist/Passes/AffineCFG.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include "llvm/ADT/SmallSet.h"
1616
#include "llvm/Support/Debug.h"
1717
#include <deque>
18-
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
18+
#include <mlir/Dialect/Arith/IR/Arith.h>
1919

2020
#define DEBUG_TYPE "affine-cfg"
2121

lib/polygeist/Passes/BarrierRemovalContinuation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include "PassDetails.h"
1515

1616
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
17-
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17+
#include "mlir/Dialect/Arith/IR/Arith.h"
1818
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1919
#include "mlir/Dialect/Func/IR/FuncOps.h"
2020
#include "mlir/Dialect/MemRef/IR/MemRef.h"

lib/polygeist/Passes/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ add_mlir_dialect_library(MLIRPolygeistTransforms
2424

2525
LINK_LIBS PUBLIC
2626
MLIRAffineDialect
27-
MLIRArithmeticDialect
27+
MLIRArithDialect
2828
MLIRAsyncDialect
2929
MLIRAffineUtils
3030
MLIRFuncDialect

lib/polygeist/Passes/CanonicalizeFor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include "mlir/IR/Matchers.h"
1010
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1111
#include "polygeist/Passes/Passes.h"
12-
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
12+
#include <mlir/Dialect/Arith/IR/Arith.h>
1313

1414
using namespace mlir;
1515
using namespace mlir::scf;

0 commit comments

Comments
 (0)