@@ -3951,6 +3951,258 @@ struct RemoveAffineParallelSingleIter
3951
3951
}
3952
3952
};
3953
3953
3954
+ template <typename T> struct BufferElimination : public OpRewritePattern <T> {
3955
+ using OpRewritePattern<T>::OpRewritePattern;
3956
+
3957
+ static bool legalFor (T op, AffineForOp afFor) {
3958
+ for (auto lb : afFor.getLowerBoundMap ().getResults ()) {
3959
+ auto opd = lb.dyn_cast <AffineConstantExpr>();
3960
+ if (!opd)
3961
+ return false ;
3962
+ if (opd.getValue () != 0 )
3963
+ return false ;
3964
+ }
3965
+ auto S = op.getType ().getShape ();
3966
+ if (S.size () != 1 )
3967
+ return false ;
3968
+
3969
+ for (auto lb : afFor.getUpperBoundMap ().getResults ()) {
3970
+ if (auto opd = lb.dyn_cast <AffineConstantExpr>()) {
3971
+ if (S[0 ] != -1 ) {
3972
+ if (S[0 ] != opd.getValue ())
3973
+ return false ;
3974
+ } else {
3975
+ IntegerAttr iattr;
3976
+ if (!matchPattern (op.getOperand (0 ), m_Constant (&iattr)))
3977
+ return false ;
3978
+ if (iattr.getValue () != opd.getValue ())
3979
+ return false ;
3980
+ }
3981
+ continue ;
3982
+ }
3983
+ if (auto opd = lb.dyn_cast <AffineDimExpr>()) {
3984
+ if (S[0 ] != -1 )
3985
+ return false ;
3986
+ if (afFor.getUpperBoundOperands ()[opd.getPosition ()] !=
3987
+ op.getOperand (0 ))
3988
+ return false ;
3989
+ continue ;
3990
+ }
3991
+ if (auto opd = lb.dyn_cast <AffineSymbolExpr>()) {
3992
+ if (S[0 ] != -1 )
3993
+ return false ;
3994
+ if (afFor.getUpperBoundOperands ()
3995
+ [opd.getPosition () + afFor.getUpperBoundMap ().getNumDims ()] !=
3996
+ op.getOperand (0 ))
3997
+ return false ;
3998
+ continue ;
3999
+ }
4000
+
4001
+ return false ;
4002
+ }
4003
+ return true ;
4004
+ }
4005
+
4006
+ LogicalResult matchAndRewrite (T op,
4007
+ PatternRewriter &rewriter) const override {
4008
+ if (isCaptured (op))
4009
+ return failure ();
4010
+
4011
+ for (auto U : op->getResult (0 ).getUsers ()) {
4012
+ if (auto load = dyn_cast<AffineLoadOp>(U)) {
4013
+ AffineMap map = load.getAffineMapAttr ().getValue ();
4014
+ if (map.getNumResults () != 1 )
4015
+ continue ;
4016
+ auto opd = map.getResults ()[0 ].dyn_cast <AffineDimExpr>();
4017
+ if (!opd)
4018
+ continue ;
4019
+ auto val = ((Value)load.getMapOperands ()[opd.getPosition ()])
4020
+ .dyn_cast <BlockArgument>();
4021
+ if (!val)
4022
+ continue ;
4023
+
4024
+ AffineForOp copyOutOfBuffer =
4025
+ dyn_cast<AffineForOp>(val.getOwner ()->getParentOp ());
4026
+ if (!copyOutOfBuffer)
4027
+ continue ;
4028
+ if (copyOutOfBuffer.getNumResults ())
4029
+ continue ;
4030
+
4031
+ if (!legalFor (op, copyOutOfBuffer))
4032
+ continue ;
4033
+
4034
+ if (load->getParentOp () != copyOutOfBuffer)
4035
+ continue ;
4036
+ if (!llvm::hasNItems (*copyOutOfBuffer.getBody (), 3 ))
4037
+ continue ;
4038
+
4039
+ auto store = dyn_cast<AffineStoreOp>(load->getNextNode ());
4040
+ if (!store)
4041
+ continue ;
4042
+
4043
+ Value otherBuf = store.memref ();
4044
+
4045
+ if (load.getAffineMapAttr ().getValue () !=
4046
+ store.getAffineMapAttr ().getValue ())
4047
+ continue ;
4048
+ if (!llvm::all_of (
4049
+ llvm::zip (load.getMapOperands (), store.getMapOperands ()),
4050
+ [](std::tuple<Value, Value> v) -> bool {
4051
+ return std::get<0 >(v) == std::get<1 >(v);
4052
+ }))
4053
+ continue ;
4054
+
4055
+ // Needs to be noalias, otherwise we cannot tell if intermediate users
4056
+ // also use the other buffer.
4057
+ if (!(otherBuf.getDefiningOp <memref::AllocOp>()) &&
4058
+ !(otherBuf.getDefiningOp <memref::AllocaOp>()))
4059
+ continue ;
4060
+
4061
+ for (auto U2 : otherBuf.getUsers ()) {
4062
+ if (auto load = dyn_cast<AffineLoadOp>(U2)) {
4063
+ AffineMap map = load.getAffineMapAttr ().getValue ();
4064
+ if (map.getNumResults () != 1 )
4065
+ continue ;
4066
+ auto opd = map.getResults ()[0 ].dyn_cast <AffineDimExpr>();
4067
+ if (!opd)
4068
+ continue ;
4069
+ auto val = ((Value)load.getMapOperands ()[opd.getPosition ()])
4070
+ .dyn_cast <BlockArgument>();
4071
+ if (!val)
4072
+ continue ;
4073
+
4074
+ AffineForOp copyIntoBuffer =
4075
+ dyn_cast<AffineForOp>(val.getOwner ()->getParentOp ());
4076
+ if (!copyIntoBuffer)
4077
+ continue ;
4078
+ if (copyIntoBuffer.getNumResults ())
4079
+ continue ;
4080
+
4081
+ if (load->getParentOp () != copyIntoBuffer)
4082
+ continue ;
4083
+ if (!llvm::hasNItems (*copyIntoBuffer.getBody (), 3 ))
4084
+ continue ;
4085
+
4086
+ auto store = dyn_cast<AffineStoreOp>(load->getNextNode ());
4087
+ if (!store)
4088
+ continue ;
4089
+
4090
+ if (load.getAffineMapAttr ().getValue () !=
4091
+ store.getAffineMapAttr ().getValue ())
4092
+ continue ;
4093
+ if (!llvm::all_of (
4094
+ llvm::zip (load.getMapOperands (), store.getMapOperands ()),
4095
+ [](std::tuple<Value, Value> v) -> bool {
4096
+ return std::get<0 >(v) == std::get<1 >(v);
4097
+ }))
4098
+ continue ;
4099
+
4100
+ if (store.memref () != op)
4101
+ continue ;
4102
+
4103
+ if (copyIntoBuffer->getBlock () != copyOutOfBuffer->getBlock ())
4104
+ continue ;
4105
+
4106
+ bool legal = true ;
4107
+ for (Operation *mod = copyIntoBuffer->getNextNode ();
4108
+ mod != copyOutOfBuffer; mod = mod->getNextNode ()) {
4109
+ if (!mod) {
4110
+ legal = false ;
4111
+ break ;
4112
+ }
4113
+ for (auto U3 : otherBuf.getUsers ()) {
4114
+ if (mod->isAncestor (U3)) {
4115
+ legal = false ;
4116
+ break ;
4117
+ }
4118
+ }
4119
+ }
4120
+ if (!legal)
4121
+ continue ;
4122
+
4123
+ if (!legalFor (op, copyIntoBuffer))
4124
+ continue ;
4125
+
4126
+ assert (otherBuf.getType () == op.getType ());
4127
+
4128
+ rewriter.replaceOpWithIf (
4129
+ op, otherBuf, nullptr , [&](OpOperand &use) {
4130
+ Operation *owner = use.getOwner ();
4131
+ while (owner &&
4132
+ owner->getBlock () != copyIntoBuffer->getBlock ()) {
4133
+ owner = owner->getParentOp ();
4134
+ }
4135
+ if (!owner)
4136
+ return false ;
4137
+
4138
+ return copyIntoBuffer->isBeforeInBlock (owner) &&
4139
+ owner->isBeforeInBlock (copyOutOfBuffer);
4140
+ });
4141
+
4142
+ rewriter.setInsertionPoint (copyOutOfBuffer);
4143
+ rewriter.clone (*copyIntoBuffer);
4144
+ rewriter.eraseOp (copyOutOfBuffer);
4145
+ rewriter.eraseOp (copyIntoBuffer);
4146
+ // TODO remove
4147
+ //
4148
+ // %op = alloc
4149
+ // stuff(op)
4150
+ //
4151
+ // copyIntoBuffer(op, otherBuf)
4152
+ //
4153
+ // stuffToReplace(op)
4154
+ //
4155
+ // copyOutOfBuffer(otherBuf, op)
4156
+ //
4157
+ // stuff2(op)
4158
+ //
4159
+ //
4160
+ // BECOMES
4161
+ //
4162
+ // %op = alloc
4163
+ // stuff(op)
4164
+ //
4165
+ //
4166
+ // stuffToReplace(otherBuf)
4167
+ //
4168
+ // # ERASED copyOutOfBuffer(otherBuf, op)
4169
+ // copyIntoBuffer(op, otherBuf)
4170
+ //
4171
+ // stuff2(op)
4172
+ //
4173
+ return success ();
4174
+ }
4175
+ }
4176
+ }
4177
+ }
4178
+ return failure ();
4179
+ }
4180
+ };
4181
+
4182
+ template <typename T> struct SimplifyDeadAllocV2 : public OpRewritePattern <T> {
4183
+ using OpRewritePattern<T>::OpRewritePattern;
4184
+
4185
+ LogicalResult matchAndRewrite (T alloc,
4186
+ PatternRewriter &rewriter) const override {
4187
+ if (llvm::any_of (alloc->getUsers (), [&](Operation *op) {
4188
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
4189
+ return storeOp.value () == alloc;
4190
+ if (auto storeOp = dyn_cast<AffineStoreOp>(op))
4191
+ return storeOp.value () == alloc;
4192
+ if (auto storeOp = dyn_cast<LLVM::StoreOp>(op))
4193
+ return storeOp.getValue () == alloc;
4194
+ return !isa<memref::DeallocOp>(op);
4195
+ }))
4196
+ return failure ();
4197
+
4198
+ for (Operation *user : llvm::make_early_inc_range (alloc->getUsers ()))
4199
+ rewriter.eraseOp (user);
4200
+
4201
+ rewriter.eraseOp (alloc);
4202
+ return success ();
4203
+ }
4204
+ };
4205
+
3954
4206
void TypeAlignOp::getCanonicalizationPatterns (RewritePatternSet &results,
3955
4207
MLIRContext *context) {
3956
4208
results.insert <
@@ -3962,6 +4214,9 @@ void TypeAlignOp::getCanonicalizationPatterns(RewritePatternSet &results,
3962
4214
AffineIfSinking, AffineIfSimplification, CombineAffineIfs,
3963
4215
MergeNestedAffineParallelLoops, PrepMergeNestedAffineParallelLoops,
3964
4216
MergeNestedAffineParallelIf, RemoveAffineParallelSingleIter,
4217
+ BufferElimination<memref::AllocaOp>, BufferElimination<memref::AllocOp>,
4218
+ SimplifyDeadAllocV2<memref::AllocaOp>,
4219
+ SimplifyDeadAllocV2<memref::AllocOp>, SimplifyDeadAllocV2<LLVM::AllocaOp>,
3965
4220
// RankReduction<memref::AllocaOp, scf::ParallelOp>,
3966
4221
AggressiveAllocaScopeInliner, InductiveVarRemoval>(context);
3967
4222
}
0 commit comments