|
27 | 27 | #include "mlir/IR/PatternMatch.h" |
28 | 28 | #include "mlir/Pass/Pass.h" |
29 | 29 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 30 | +#include "llvm/ADT/TypeSwitch.h" |
30 | 31 |
|
31 | 32 | namespace mlir { |
32 | 33 | namespace memref { |
@@ -148,143 +149,98 @@ static bool checkLayout(Value val) { |
148 | 149 | } |
149 | 150 |
|
150 | 151 | namespace { |
151 | | -template <typename T> |
152 | | -static Value getTargetMemref(T op) { |
153 | | - if constexpr (std::is_same_v<T, memref::LoadOp>) { |
154 | | - return op.getMemref(); |
155 | | - } else if constexpr (std::is_same_v<T, vector::LoadOp>) { |
156 | | - return op.getBase(); |
157 | | - } else if constexpr (std::is_same_v<T, memref::StoreOp>) { |
158 | | - return op.getMemref(); |
159 | | - } else if constexpr (std::is_same_v<T, vector::StoreOp>) { |
160 | | - return op.getBase(); |
161 | | - } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) { |
162 | | - return op.getBase(); |
163 | | - } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) { |
164 | | - return op.getBase(); |
165 | | - } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) { |
166 | | - return op.getSource(); |
167 | | - } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) { |
168 | | - return op.getSource(); |
169 | | - } |
170 | | - return {}; |
| 152 | +static Value getTargetMemref(Operation *op) { |
| 153 | + return llvm::TypeSwitch<Operation *, Value>(op) |
| 154 | + .template Case<memref::LoadOp, memref::StoreOp>( |
| 155 | + [](auto op) { return op.getMemref(); }) |
| 156 | + .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp, |
| 157 | + vector::MaskedStoreOp>( |
| 158 | + [](auto op) { return op.getBase(); }) |
| 159 | + .template Case<vector::TransferReadOp, vector::TransferWriteOp>( |
| 160 | + [](auto op) { return op.getSource(); }) |
| 161 | + .Default([](auto) { return Value{}; }); |
171 | 162 | } |
172 | 163 |
|
173 | | -template <typename T> |
174 | | -static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, |
175 | | - Value offset) { |
176 | | - if constexpr (std::is_same_v<T, memref::LoadOp>) { |
177 | | - auto newLoad = rewriter.create<memref::LoadOp>( |
178 | | - op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset}); |
179 | | - newLoad->setAttrs(op->getAttrs()); |
180 | | - rewriter.replaceOp(op, newLoad.getResult()); |
181 | | - } else if constexpr (std::is_same_v<T, vector::LoadOp>) { |
182 | | - auto newLoad = rewriter.create<vector::LoadOp>( |
183 | | - op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset}); |
184 | | - newLoad->setAttrs(op->getAttrs()); |
185 | | - rewriter.replaceOp(op, newLoad.getResult()); |
186 | | - } else if constexpr (std::is_same_v<T, memref::StoreOp>) { |
187 | | - auto newStore = rewriter.create<memref::StoreOp>( |
188 | | - op->getLoc(), op->getOperands().front(), flatMemref, |
189 | | - ValueRange{offset}); |
190 | | - newStore->setAttrs(op->getAttrs()); |
191 | | - rewriter.replaceOp(op, newStore); |
192 | | - } else if constexpr (std::is_same_v<T, vector::StoreOp>) { |
193 | | - auto newStore = rewriter.create<vector::StoreOp>( |
194 | | - op->getLoc(), op->getOperands().front(), flatMemref, |
195 | | - ValueRange{offset}); |
196 | | - newStore->setAttrs(op->getAttrs()); |
197 | | - rewriter.replaceOp(op, newStore); |
198 | | - } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) { |
199 | | - auto newTransferRead = rewriter.create<vector::TransferReadOp>( |
200 | | - op->getLoc(), op.getType(), flatMemref, ValueRange{offset}, |
201 | | - op.getPadding()); |
202 | | - rewriter.replaceOp(op, newTransferRead.getResult()); |
203 | | - } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) { |
204 | | - auto newTransferWrite = rewriter.create<vector::TransferWriteOp>( |
205 | | - op->getLoc(), op.getVector(), flatMemref, ValueRange{offset}); |
206 | | - rewriter.replaceOp(op, newTransferWrite); |
207 | | - } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) { |
208 | | - auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>( |
209 | | - op->getLoc(), op.getType(), flatMemref, ValueRange{offset}, |
210 | | - op.getMask(), op.getPassThru()); |
211 | | - newMaskedLoad->setAttrs(op->getAttrs()); |
212 | | - rewriter.replaceOp(op, newMaskedLoad.getResult()); |
213 | | - } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) { |
214 | | - auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>( |
215 | | - op->getLoc(), flatMemref, ValueRange{offset}, op.getMask(), |
216 | | - op.getValueToStore()); |
217 | | - newMaskedStore->setAttrs(op->getAttrs()); |
218 | | - rewriter.replaceOp(op, newMaskedStore); |
219 | | - } else { |
220 | | - op.emitOpError("unimplemented: do not know how to replace op."); |
221 | | - } |
| 164 | +static void replaceOp(Operation *op, PatternRewriter &rewriter, |
| 165 | + Value flatMemref, Value offset) { |
| 166 | + auto loc = op->getLoc(); |
| 167 | + llvm::TypeSwitch<Operation *>(op) |
| 168 | + .Case<memref::LoadOp>([&](auto op) { |
| 169 | + auto newLoad = rewriter.create<memref::LoadOp>( |
| 170 | + loc, op->getResultTypes(), flatMemref, ValueRange{offset}); |
| 171 | + newLoad->setAttrs(op->getAttrs()); |
| 172 | + rewriter.replaceOp(op, newLoad.getResult()); |
| 173 | + }) |
| 174 | + .Case<memref::StoreOp>([&](auto op) { |
| 175 | + auto newStore = rewriter.create<memref::StoreOp>( |
| 176 | + loc, op->getOperands().front(), flatMemref, ValueRange{offset}); |
| 177 | + newStore->setAttrs(op->getAttrs()); |
| 178 | + rewriter.replaceOp(op, newStore); |
| 179 | + }) |
| 180 | + .Case<vector::LoadOp>([&](auto op) { |
| 181 | + auto newLoad = rewriter.create<vector::LoadOp>( |
| 182 | + loc, op->getResultTypes(), flatMemref, ValueRange{offset}); |
| 183 | + newLoad->setAttrs(op->getAttrs()); |
| 184 | + rewriter.replaceOp(op, newLoad.getResult()); |
| 185 | + }) |
| 186 | + .Case<vector::StoreOp>([&](auto op) { |
| 187 | + auto newStore = rewriter.create<vector::StoreOp>( |
| 188 | + loc, op->getOperands().front(), flatMemref, ValueRange{offset}); |
| 189 | + newStore->setAttrs(op->getAttrs()); |
| 190 | + rewriter.replaceOp(op, newStore); |
| 191 | + }) |
| 192 | + .Case<vector::MaskedLoadOp>([&](auto op) { |
| 193 | + auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>( |
| 194 | + loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(), |
| 195 | + op.getPassThru()); |
| 196 | + newMaskedLoad->setAttrs(op->getAttrs()); |
| 197 | + rewriter.replaceOp(op, newMaskedLoad.getResult()); |
| 198 | + }) |
| 199 | + .Case<vector::MaskedStoreOp>([&](auto op) { |
| 200 | + auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>( |
| 201 | + loc, flatMemref, ValueRange{offset}, op.getMask(), |
| 202 | + op.getValueToStore()); |
| 203 | + newMaskedStore->setAttrs(op->getAttrs()); |
| 204 | + rewriter.replaceOp(op, newMaskedStore); |
| 205 | + }) |
| 206 | + .Case<vector::TransferReadOp>([&](auto op) { |
| 207 | + auto newTransferRead = rewriter.create<vector::TransferReadOp>( |
| 208 | + loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding()); |
| 209 | + rewriter.replaceOp(op, newTransferRead.getResult()); |
| 210 | + }) |
| 211 | + .Case<vector::TransferWriteOp>([&](auto op) { |
| 212 | + auto newTransferWrite = rewriter.create<vector::TransferWriteOp>( |
| 213 | + loc, op.getVector(), flatMemref, ValueRange{offset}); |
| 214 | + rewriter.replaceOp(op, newTransferWrite); |
| 215 | + }) |
| 216 | + .Default([&](auto op) { |
| 217 | + op->emitOpError("unimplemented: do not know how to replace op."); |
| 218 | + }); |
222 | 219 | } |
223 | 220 |
|
224 | 221 | template <typename T> |
225 | | -struct MemRefRewritePatternBase : public OpRewritePattern<T> { |
| 222 | +struct MemRefRewritePattern : public OpRewritePattern<T> { |
226 | 223 | using OpRewritePattern<T>::OpRewritePattern; |
227 | 224 | LogicalResult matchAndRewrite(T op, |
228 | 225 | PatternRewriter &rewriter) const override { |
229 | | - Value memref = getTargetMemref<T>(op); |
| 226 | + Value memref = getTargetMemref(op); |
230 | 227 | if (!needFlattening(memref) || !checkLayout(memref)) |
231 | | - return rewriter.notifyMatchFailure(op, |
232 | | - "nothing to do or unsupported layout"); |
| 228 | + return failure(); |
233 | 229 | auto &&[flatMemref, offset] = getFlattenMemrefAndOffset( |
234 | 230 | rewriter, op->getLoc(), memref, op.getIndices()); |
235 | | - replaceOp<T>(op, rewriter, flatMemref, offset); |
| 231 | + replaceOp(op, rewriter, flatMemref, offset); |
236 | 232 | return success(); |
237 | 233 | } |
238 | 234 | }; |
239 | 235 |
|
240 | | -struct FlattenMemrefLoad : public MemRefRewritePatternBase<memref::LoadOp> { |
241 | | - using MemRefRewritePatternBase<memref::LoadOp>::MemRefRewritePatternBase; |
242 | | -}; |
243 | | - |
244 | | -struct FlattenVectorLoad : public MemRefRewritePatternBase<vector::LoadOp> { |
245 | | - using MemRefRewritePatternBase<vector::LoadOp>::MemRefRewritePatternBase; |
246 | | -}; |
247 | | - |
248 | | -struct FlattenMemrefStore : public MemRefRewritePatternBase<memref::StoreOp> { |
249 | | - using MemRefRewritePatternBase<memref::StoreOp>::MemRefRewritePatternBase; |
250 | | -}; |
251 | | - |
252 | | -struct FlattenVectorStore : public MemRefRewritePatternBase<vector::StoreOp> { |
253 | | - using MemRefRewritePatternBase<vector::StoreOp>::MemRefRewritePatternBase; |
254 | | -}; |
255 | | - |
256 | | -struct FlattenVectorMaskedLoad |
257 | | - : public MemRefRewritePatternBase<vector::MaskedLoadOp> { |
258 | | - using MemRefRewritePatternBase< |
259 | | - vector::MaskedLoadOp>::MemRefRewritePatternBase; |
260 | | -}; |
261 | | - |
262 | | -struct FlattenVectorMaskedStore |
263 | | - : public MemRefRewritePatternBase<vector::MaskedStoreOp> { |
264 | | - using MemRefRewritePatternBase< |
265 | | - vector::MaskedStoreOp>::MemRefRewritePatternBase; |
266 | | -}; |
267 | | - |
268 | | -struct FlattenVectorTransferRead |
269 | | - : public MemRefRewritePatternBase<vector::TransferReadOp> { |
270 | | - using MemRefRewritePatternBase< |
271 | | - vector::TransferReadOp>::MemRefRewritePatternBase; |
272 | | -}; |
273 | | - |
274 | | -struct FlattenVectorTransferWrite |
275 | | - : public MemRefRewritePatternBase<vector::TransferWriteOp> { |
276 | | - using MemRefRewritePatternBase< |
277 | | - vector::TransferWriteOp>::MemRefRewritePatternBase; |
278 | | -}; |
279 | | - |
280 | 236 | struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> { |
281 | 237 | using OpRewritePattern::OpRewritePattern; |
282 | 238 |
|
283 | 239 | LogicalResult matchAndRewrite(memref::SubViewOp op, |
284 | 240 | PatternRewriter &rewriter) const override { |
285 | 241 | Value memref = op.getSource(); |
286 | 242 | if (!needFlattening(memref)) |
287 | | - return rewriter.notifyMatchFailure(op, "nothing to do"); |
| 243 | + return rewriter.notifyMatchFailure(op, "already flattened"); |
288 | 244 |
|
289 | 245 | if (!checkLayout(memref)) |
290 | 246 | return rewriter.notifyMatchFailure(op, "unsupported layout"); |
@@ -344,13 +300,12 @@ struct FlattenMemrefsPass |
344 | 300 | } // namespace |
345 | 301 |
|
346 | 302 | void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) { |
347 | | - patterns.insert<FlattenMemrefLoad, FlattenMemrefStore, FlattenSubview, |
348 | | - FlattenVectorMaskedLoad, FlattenVectorMaskedStore, |
349 | | - FlattenVectorLoad, FlattenVectorStore, |
350 | | - FlattenVectorTransferRead, FlattenVectorTransferWrite>( |
351 | | - patterns.getContext()); |
352 | | -} |
353 | | - |
354 | | -std::unique_ptr<Pass> mlir::memref::createFlattenMemrefsPass() { |
355 | | - return std::make_unique<FlattenMemrefsPass>(); |
| 303 | + patterns |
| 304 | + .insert<MemRefRewritePattern<memref::LoadOp>, |
| 305 | + MemRefRewritePattern<memref::StoreOp>, |
| 306 | + MemRefRewritePattern<vector::LoadOp>, |
| 307 | + MemRefRewritePattern<vector::StoreOp>, |
| 308 | + MemRefRewritePattern<vector::TransferReadOp>, |
| 309 | + MemRefRewritePattern<vector::TransferWriteOp>, FlattenSubview>( |
| 310 | + patterns.getContext()); |
356 | 311 | } |
0 commit comments