Skip to content

Commit 510bc58

Browse files
authored
Primitive copy buffer elimination (#257)
* Eliminate read-write copy buffers * Fix * Fix test * add doublebuf test * Fix test
1 parent 11eb047 commit 510bc58

File tree

3 files changed

+460
-1
lines changed

3 files changed

+460
-1
lines changed

lib/polygeist/Ops.cpp

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3951,6 +3951,258 @@ struct RemoveAffineParallelSingleIter
39513951
}
39523952
};
39533953

3954+
template <typename T> struct BufferElimination : public OpRewritePattern<T> {
3955+
using OpRewritePattern<T>::OpRewritePattern;
3956+
3957+
static bool legalFor(T op, AffineForOp afFor) {
3958+
for (auto lb : afFor.getLowerBoundMap().getResults()) {
3959+
auto opd = lb.dyn_cast<AffineConstantExpr>();
3960+
if (!opd)
3961+
return false;
3962+
if (opd.getValue() != 0)
3963+
return false;
3964+
}
3965+
auto S = op.getType().getShape();
3966+
if (S.size() != 1)
3967+
return false;
3968+
3969+
for (auto lb : afFor.getUpperBoundMap().getResults()) {
3970+
if (auto opd = lb.dyn_cast<AffineConstantExpr>()) {
3971+
if (S[0] != -1) {
3972+
if (S[0] != opd.getValue())
3973+
return false;
3974+
} else {
3975+
IntegerAttr iattr;
3976+
if (!matchPattern(op.getOperand(0), m_Constant(&iattr)))
3977+
return false;
3978+
if (iattr.getValue() != opd.getValue())
3979+
return false;
3980+
}
3981+
continue;
3982+
}
3983+
if (auto opd = lb.dyn_cast<AffineDimExpr>()) {
3984+
if (S[0] != -1)
3985+
return false;
3986+
if (afFor.getUpperBoundOperands()[opd.getPosition()] !=
3987+
op.getOperand(0))
3988+
return false;
3989+
continue;
3990+
}
3991+
if (auto opd = lb.dyn_cast<AffineSymbolExpr>()) {
3992+
if (S[0] != -1)
3993+
return false;
3994+
if (afFor.getUpperBoundOperands()
3995+
[opd.getPosition() + afFor.getUpperBoundMap().getNumDims()] !=
3996+
op.getOperand(0))
3997+
return false;
3998+
continue;
3999+
}
4000+
4001+
return false;
4002+
}
4003+
return true;
4004+
}
4005+
4006+
LogicalResult matchAndRewrite(T op,
4007+
PatternRewriter &rewriter) const override {
4008+
if (isCaptured(op))
4009+
return failure();
4010+
4011+
for (auto U : op->getResult(0).getUsers()) {
4012+
if (auto load = dyn_cast<AffineLoadOp>(U)) {
4013+
AffineMap map = load.getAffineMapAttr().getValue();
4014+
if (map.getNumResults() != 1)
4015+
continue;
4016+
auto opd = map.getResults()[0].dyn_cast<AffineDimExpr>();
4017+
if (!opd)
4018+
continue;
4019+
auto val = ((Value)load.getMapOperands()[opd.getPosition()])
4020+
.dyn_cast<BlockArgument>();
4021+
if (!val)
4022+
continue;
4023+
4024+
AffineForOp copyOutOfBuffer =
4025+
dyn_cast<AffineForOp>(val.getOwner()->getParentOp());
4026+
if (!copyOutOfBuffer)
4027+
continue;
4028+
if (copyOutOfBuffer.getNumResults())
4029+
continue;
4030+
4031+
if (!legalFor(op, copyOutOfBuffer))
4032+
continue;
4033+
4034+
if (load->getParentOp() != copyOutOfBuffer)
4035+
continue;
4036+
if (!llvm::hasNItems(*copyOutOfBuffer.getBody(), 3))
4037+
continue;
4038+
4039+
auto store = dyn_cast<AffineStoreOp>(load->getNextNode());
4040+
if (!store)
4041+
continue;
4042+
4043+
Value otherBuf = store.memref();
4044+
4045+
if (load.getAffineMapAttr().getValue() !=
4046+
store.getAffineMapAttr().getValue())
4047+
continue;
4048+
if (!llvm::all_of(
4049+
llvm::zip(load.getMapOperands(), store.getMapOperands()),
4050+
[](std::tuple<Value, Value> v) -> bool {
4051+
return std::get<0>(v) == std::get<1>(v);
4052+
}))
4053+
continue;
4054+
4055+
// Needs to be noalias, otherwise we cannot tell if intermediate users
4056+
// also use the other buffer.
4057+
if (!(otherBuf.getDefiningOp<memref::AllocOp>()) &&
4058+
!(otherBuf.getDefiningOp<memref::AllocaOp>()))
4059+
continue;
4060+
4061+
for (auto U2 : otherBuf.getUsers()) {
4062+
if (auto load = dyn_cast<AffineLoadOp>(U2)) {
4063+
AffineMap map = load.getAffineMapAttr().getValue();
4064+
if (map.getNumResults() != 1)
4065+
continue;
4066+
auto opd = map.getResults()[0].dyn_cast<AffineDimExpr>();
4067+
if (!opd)
4068+
continue;
4069+
auto val = ((Value)load.getMapOperands()[opd.getPosition()])
4070+
.dyn_cast<BlockArgument>();
4071+
if (!val)
4072+
continue;
4073+
4074+
AffineForOp copyIntoBuffer =
4075+
dyn_cast<AffineForOp>(val.getOwner()->getParentOp());
4076+
if (!copyIntoBuffer)
4077+
continue;
4078+
if (copyIntoBuffer.getNumResults())
4079+
continue;
4080+
4081+
if (load->getParentOp() != copyIntoBuffer)
4082+
continue;
4083+
if (!llvm::hasNItems(*copyIntoBuffer.getBody(), 3))
4084+
continue;
4085+
4086+
auto store = dyn_cast<AffineStoreOp>(load->getNextNode());
4087+
if (!store)
4088+
continue;
4089+
4090+
if (load.getAffineMapAttr().getValue() !=
4091+
store.getAffineMapAttr().getValue())
4092+
continue;
4093+
if (!llvm::all_of(
4094+
llvm::zip(load.getMapOperands(), store.getMapOperands()),
4095+
[](std::tuple<Value, Value> v) -> bool {
4096+
return std::get<0>(v) == std::get<1>(v);
4097+
}))
4098+
continue;
4099+
4100+
if (store.memref() != op)
4101+
continue;
4102+
4103+
if (copyIntoBuffer->getBlock() != copyOutOfBuffer->getBlock())
4104+
continue;
4105+
4106+
bool legal = true;
4107+
for (Operation *mod = copyIntoBuffer->getNextNode();
4108+
mod != copyOutOfBuffer; mod = mod->getNextNode()) {
4109+
if (!mod) {
4110+
legal = false;
4111+
break;
4112+
}
4113+
for (auto U3 : otherBuf.getUsers()) {
4114+
if (mod->isAncestor(U3)) {
4115+
legal = false;
4116+
break;
4117+
}
4118+
}
4119+
}
4120+
if (!legal)
4121+
continue;
4122+
4123+
if (!legalFor(op, copyIntoBuffer))
4124+
continue;
4125+
4126+
assert(otherBuf.getType() == op.getType());
4127+
4128+
rewriter.replaceOpWithIf(
4129+
op, otherBuf, nullptr, [&](OpOperand &use) {
4130+
Operation *owner = use.getOwner();
4131+
while (owner &&
4132+
owner->getBlock() != copyIntoBuffer->getBlock()) {
4133+
owner = owner->getParentOp();
4134+
}
4135+
if (!owner)
4136+
return false;
4137+
4138+
return copyIntoBuffer->isBeforeInBlock(owner) &&
4139+
owner->isBeforeInBlock(copyOutOfBuffer);
4140+
});
4141+
4142+
rewriter.setInsertionPoint(copyOutOfBuffer);
4143+
rewriter.clone(*copyIntoBuffer);
4144+
rewriter.eraseOp(copyOutOfBuffer);
4145+
rewriter.eraseOp(copyIntoBuffer);
4146+
// TODO remove
4147+
//
4148+
// %op = alloc
4149+
// stuff(op)
4150+
//
4151+
// copyIntoBuffer(op, otherBuf)
4152+
//
4153+
// stuffToReplace(op)
4154+
//
4155+
// copyOutOfBuffer(otherBuf, op)
4156+
//
4157+
// stuff2(op)
4158+
//
4159+
//
4160+
// BECOMES
4161+
//
4162+
// %op = alloc
4163+
// stuff(op)
4164+
//
4165+
//
4166+
// stuffToReplace(otherBuf)
4167+
//
4168+
// # ERASED copyOutOfBuffer(otherBuf, op)
4169+
// copyIntoBuffer(op, otherBuf)
4170+
//
4171+
// stuff2(op)
4172+
//
4173+
return success();
4174+
}
4175+
}
4176+
}
4177+
}
4178+
return failure();
4179+
}
4180+
};
4181+
4182+
template <typename T> struct SimplifyDeadAllocV2 : public OpRewritePattern<T> {
4183+
using OpRewritePattern<T>::OpRewritePattern;
4184+
4185+
LogicalResult matchAndRewrite(T alloc,
4186+
PatternRewriter &rewriter) const override {
4187+
if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
4188+
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
4189+
return storeOp.value() == alloc;
4190+
if (auto storeOp = dyn_cast<AffineStoreOp>(op))
4191+
return storeOp.value() == alloc;
4192+
if (auto storeOp = dyn_cast<LLVM::StoreOp>(op))
4193+
return storeOp.getValue() == alloc;
4194+
return !isa<memref::DeallocOp>(op);
4195+
}))
4196+
return failure();
4197+
4198+
for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
4199+
rewriter.eraseOp(user);
4200+
4201+
rewriter.eraseOp(alloc);
4202+
return success();
4203+
}
4204+
};
4205+
39544206
void TypeAlignOp::getCanonicalizationPatterns(RewritePatternSet &results,
39554207
MLIRContext *context) {
39564208
results.insert<
@@ -3962,6 +4214,9 @@ void TypeAlignOp::getCanonicalizationPatterns(RewritePatternSet &results,
39624214
AffineIfSinking, AffineIfSimplification, CombineAffineIfs,
39634215
MergeNestedAffineParallelLoops, PrepMergeNestedAffineParallelLoops,
39644216
MergeNestedAffineParallelIf, RemoveAffineParallelSingleIter,
4217+
BufferElimination<memref::AllocaOp>, BufferElimination<memref::AllocOp>,
4218+
SimplifyDeadAllocV2<memref::AllocaOp>,
4219+
SimplifyDeadAllocV2<memref::AllocOp>, SimplifyDeadAllocV2<LLVM::AllocaOp>,
39654220
// RankReduction<memref::AllocaOp, scf::ParallelOp>,
39664221
AggressiveAllocaScopeInliner, InductiveVarRemoval>(context);
39674222
}

0 commit comments

Comments
 (0)