Skip to content

Commit ce84995

Browse files
committed
amend comments
1 parent 6aad0e7 commit ce84995

File tree

2 files changed

+79
-126
lines changed

2 files changed

+79
-126
lines changed

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,6 @@ def FlattenMemrefsPass : Pass<"flatten-memref"> {
250250
let description = [{
251251

252252
}];
253-
254-
let constructor = "mlir::memref::createFlattenMemrefsPass()";
255253
let dependentDialects = [
256254
"affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
257255
];

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp

Lines changed: 79 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/IR/PatternMatch.h"
2828
#include "mlir/Pass/Pass.h"
2929
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30+
#include "llvm/ADT/TypeSwitch.h"
3031

3132
namespace mlir {
3233
namespace memref {
@@ -148,143 +149,98 @@ static bool checkLayout(Value val) {
148149
}
149150

150151
namespace {
151-
template <typename T>
152-
static Value getTargetMemref(T op) {
153-
if constexpr (std::is_same_v<T, memref::LoadOp>) {
154-
return op.getMemref();
155-
} else if constexpr (std::is_same_v<T, vector::LoadOp>) {
156-
return op.getBase();
157-
} else if constexpr (std::is_same_v<T, memref::StoreOp>) {
158-
return op.getMemref();
159-
} else if constexpr (std::is_same_v<T, vector::StoreOp>) {
160-
return op.getBase();
161-
} else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
162-
return op.getBase();
163-
} else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
164-
return op.getBase();
165-
} else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
166-
return op.getSource();
167-
} else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
168-
return op.getSource();
169-
}
170-
return {};
152+
static Value getTargetMemref(Operation *op) {
153+
return llvm::TypeSwitch<Operation *, Value>(op)
154+
.template Case<memref::LoadOp, memref::StoreOp>(
155+
[](auto op) { return op.getMemref(); })
156+
.template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
157+
vector::MaskedStoreOp>(
158+
[](auto op) { return op.getBase(); })
159+
.template Case<vector::TransferReadOp, vector::TransferWriteOp>(
160+
[](auto op) { return op.getSource(); })
161+
.Default([](auto) { return Value{}; });
171162
}
172163

173-
template <typename T>
174-
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
175-
Value offset) {
176-
if constexpr (std::is_same_v<T, memref::LoadOp>) {
177-
auto newLoad = rewriter.create<memref::LoadOp>(
178-
op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
179-
newLoad->setAttrs(op->getAttrs());
180-
rewriter.replaceOp(op, newLoad.getResult());
181-
} else if constexpr (std::is_same_v<T, vector::LoadOp>) {
182-
auto newLoad = rewriter.create<vector::LoadOp>(
183-
op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
184-
newLoad->setAttrs(op->getAttrs());
185-
rewriter.replaceOp(op, newLoad.getResult());
186-
} else if constexpr (std::is_same_v<T, memref::StoreOp>) {
187-
auto newStore = rewriter.create<memref::StoreOp>(
188-
op->getLoc(), op->getOperands().front(), flatMemref,
189-
ValueRange{offset});
190-
newStore->setAttrs(op->getAttrs());
191-
rewriter.replaceOp(op, newStore);
192-
} else if constexpr (std::is_same_v<T, vector::StoreOp>) {
193-
auto newStore = rewriter.create<vector::StoreOp>(
194-
op->getLoc(), op->getOperands().front(), flatMemref,
195-
ValueRange{offset});
196-
newStore->setAttrs(op->getAttrs());
197-
rewriter.replaceOp(op, newStore);
198-
} else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
199-
auto newTransferRead = rewriter.create<vector::TransferReadOp>(
200-
op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
201-
op.getPadding());
202-
rewriter.replaceOp(op, newTransferRead.getResult());
203-
} else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
204-
auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
205-
op->getLoc(), op.getVector(), flatMemref, ValueRange{offset});
206-
rewriter.replaceOp(op, newTransferWrite);
207-
} else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
208-
auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
209-
op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
210-
op.getMask(), op.getPassThru());
211-
newMaskedLoad->setAttrs(op->getAttrs());
212-
rewriter.replaceOp(op, newMaskedLoad.getResult());
213-
} else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
214-
auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
215-
op->getLoc(), flatMemref, ValueRange{offset}, op.getMask(),
216-
op.getValueToStore());
217-
newMaskedStore->setAttrs(op->getAttrs());
218-
rewriter.replaceOp(op, newMaskedStore);
219-
} else {
220-
op.emitOpError("unimplemented: do not know how to replace op.");
221-
}
164+
static void replaceOp(Operation *op, PatternRewriter &rewriter,
165+
Value flatMemref, Value offset) {
166+
auto loc = op->getLoc();
167+
llvm::TypeSwitch<Operation *>(op)
168+
.Case<memref::LoadOp>([&](auto op) {
169+
auto newLoad = rewriter.create<memref::LoadOp>(
170+
loc, op->getResultTypes(), flatMemref, ValueRange{offset});
171+
newLoad->setAttrs(op->getAttrs());
172+
rewriter.replaceOp(op, newLoad.getResult());
173+
})
174+
.Case<memref::StoreOp>([&](auto op) {
175+
auto newStore = rewriter.create<memref::StoreOp>(
176+
loc, op->getOperands().front(), flatMemref, ValueRange{offset});
177+
newStore->setAttrs(op->getAttrs());
178+
rewriter.replaceOp(op, newStore);
179+
})
180+
.Case<vector::LoadOp>([&](auto op) {
181+
auto newLoad = rewriter.create<vector::LoadOp>(
182+
loc, op->getResultTypes(), flatMemref, ValueRange{offset});
183+
newLoad->setAttrs(op->getAttrs());
184+
rewriter.replaceOp(op, newLoad.getResult());
185+
})
186+
.Case<vector::StoreOp>([&](auto op) {
187+
auto newStore = rewriter.create<vector::StoreOp>(
188+
loc, op->getOperands().front(), flatMemref, ValueRange{offset});
189+
newStore->setAttrs(op->getAttrs());
190+
rewriter.replaceOp(op, newStore);
191+
})
192+
.Case<vector::MaskedLoadOp>([&](auto op) {
193+
auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
194+
loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
195+
op.getPassThru());
196+
newMaskedLoad->setAttrs(op->getAttrs());
197+
rewriter.replaceOp(op, newMaskedLoad.getResult());
198+
})
199+
.Case<vector::MaskedStoreOp>([&](auto op) {
200+
auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
201+
loc, flatMemref, ValueRange{offset}, op.getMask(),
202+
op.getValueToStore());
203+
newMaskedStore->setAttrs(op->getAttrs());
204+
rewriter.replaceOp(op, newMaskedStore);
205+
})
206+
.Case<vector::TransferReadOp>([&](auto op) {
207+
auto newTransferRead = rewriter.create<vector::TransferReadOp>(
208+
loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding());
209+
rewriter.replaceOp(op, newTransferRead.getResult());
210+
})
211+
.Case<vector::TransferWriteOp>([&](auto op) {
212+
auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
213+
loc, op.getVector(), flatMemref, ValueRange{offset});
214+
rewriter.replaceOp(op, newTransferWrite);
215+
})
216+
.Default([&](auto op) {
217+
op->emitOpError("unimplemented: do not know how to replace op.");
218+
});
222219
}
223220

224221
template <typename T>
225-
struct MemRefRewritePatternBase : public OpRewritePattern<T> {
222+
struct MemRefRewritePattern : public OpRewritePattern<T> {
226223
using OpRewritePattern<T>::OpRewritePattern;
227224
LogicalResult matchAndRewrite(T op,
228225
PatternRewriter &rewriter) const override {
229-
Value memref = getTargetMemref<T>(op);
226+
Value memref = getTargetMemref(op);
230227
if (!needFlattening(memref) || !checkLayout(memref))
231-
return rewriter.notifyMatchFailure(op,
232-
"nothing to do or unsupported layout");
228+
return failure();
233229
auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
234230
rewriter, op->getLoc(), memref, op.getIndices());
235-
replaceOp<T>(op, rewriter, flatMemref, offset);
231+
replaceOp(op, rewriter, flatMemref, offset);
236232
return success();
237233
}
238234
};
239235

240-
struct FlattenMemrefLoad : public MemRefRewritePatternBase<memref::LoadOp> {
241-
using MemRefRewritePatternBase<memref::LoadOp>::MemRefRewritePatternBase;
242-
};
243-
244-
struct FlattenVectorLoad : public MemRefRewritePatternBase<vector::LoadOp> {
245-
using MemRefRewritePatternBase<vector::LoadOp>::MemRefRewritePatternBase;
246-
};
247-
248-
struct FlattenMemrefStore : public MemRefRewritePatternBase<memref::StoreOp> {
249-
using MemRefRewritePatternBase<memref::StoreOp>::MemRefRewritePatternBase;
250-
};
251-
252-
struct FlattenVectorStore : public MemRefRewritePatternBase<vector::StoreOp> {
253-
using MemRefRewritePatternBase<vector::StoreOp>::MemRefRewritePatternBase;
254-
};
255-
256-
struct FlattenVectorMaskedLoad
257-
: public MemRefRewritePatternBase<vector::MaskedLoadOp> {
258-
using MemRefRewritePatternBase<
259-
vector::MaskedLoadOp>::MemRefRewritePatternBase;
260-
};
261-
262-
struct FlattenVectorMaskedStore
263-
: public MemRefRewritePatternBase<vector::MaskedStoreOp> {
264-
using MemRefRewritePatternBase<
265-
vector::MaskedStoreOp>::MemRefRewritePatternBase;
266-
};
267-
268-
struct FlattenVectorTransferRead
269-
: public MemRefRewritePatternBase<vector::TransferReadOp> {
270-
using MemRefRewritePatternBase<
271-
vector::TransferReadOp>::MemRefRewritePatternBase;
272-
};
273-
274-
struct FlattenVectorTransferWrite
275-
: public MemRefRewritePatternBase<vector::TransferWriteOp> {
276-
using MemRefRewritePatternBase<
277-
vector::TransferWriteOp>::MemRefRewritePatternBase;
278-
};
279-
280236
struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
281237
using OpRewritePattern::OpRewritePattern;
282238

283239
LogicalResult matchAndRewrite(memref::SubViewOp op,
284240
PatternRewriter &rewriter) const override {
285241
Value memref = op.getSource();
286242
if (!needFlattening(memref))
287-
return rewriter.notifyMatchFailure(op, "nothing to do");
243+
return rewriter.notifyMatchFailure(op, "already flattened");
288244

289245
if (!checkLayout(memref))
290246
return rewriter.notifyMatchFailure(op, "unsupported layout");
@@ -344,13 +300,12 @@ struct FlattenMemrefsPass
344300
} // namespace
345301

346302
void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
347-
patterns.insert<FlattenMemrefLoad, FlattenMemrefStore, FlattenSubview,
348-
FlattenVectorMaskedLoad, FlattenVectorMaskedStore,
349-
FlattenVectorLoad, FlattenVectorStore,
350-
FlattenVectorTransferRead, FlattenVectorTransferWrite>(
351-
patterns.getContext());
352-
}
353-
354-
std::unique_ptr<Pass> mlir::memref::createFlattenMemrefsPass() {
355-
return std::make_unique<FlattenMemrefsPass>();
303+
patterns
304+
.insert<MemRefRewritePattern<memref::LoadOp>,
305+
MemRefRewritePattern<memref::StoreOp>,
306+
MemRefRewritePattern<vector::LoadOp>,
307+
MemRefRewritePattern<vector::StoreOp>,
308+
MemRefRewritePattern<vector::TransferReadOp>,
309+
MemRefRewritePattern<vector::TransferWriteOp>, FlattenSubview>(
310+
patterns.getContext());
356311
}

0 commit comments

Comments
 (0)