|
20 | 20 | #include <mlir/IR/BuiltinTypes.h>
|
21 | 21 | #include <mlir/IR/DialectImplementation.h>
|
22 | 22 | #include <mlir/IR/Dominance.h>
|
| 23 | +#include <mlir/IR/Matchers.h> |
23 | 24 | #include <mlir/IR/PatternMatch.h>
|
24 | 25 | #include <mlir/Transforms/InliningUtils.h>
|
25 | 26 |
|
@@ -269,14 +270,91 @@ struct GenGlobalId : public mlir::OpRewritePattern<mlir::arith::AddIOp> {
|
269 | 270 | return mlir::failure();
|
270 | 271 | }
|
271 | 272 | };
|
| 273 | + |
| 274 | +struct InvertCmpi : public mlir::OpRewritePattern<mlir::arith::CmpIOp> { |
| 275 | + using OpRewritePattern::OpRewritePattern; |
| 276 | + |
| 277 | + mlir::LogicalResult |
| 278 | + matchAndRewrite(mlir::arith::CmpIOp op, |
| 279 | + mlir::PatternRewriter &rewriter) const override { |
| 280 | + |
| 281 | + if (!mlir::matchPattern(op.getLhs(), mlir::m_Constant()) || |
| 282 | + mlir::matchPattern(op.getRhs(), mlir::m_Constant())) |
| 283 | + return mlir::failure(); |
| 284 | + |
| 285 | + using Pred = mlir::arith::CmpIPredicate; |
| 286 | + const std::pair<Pred, Pred> inv[] = { |
| 287 | + // clang-format off |
| 288 | + {Pred::slt, Pred::sgt}, |
| 289 | + {Pred::sle, Pred::sge}, |
| 290 | + {Pred::ult, Pred::ugt}, |
| 291 | + {Pred::ule, Pred::uge}, |
| 292 | + {Pred::eq, Pred::eq}, |
| 293 | + {Pred::ne, Pred::ne}, |
| 294 | + // clang-format on |
| 295 | + }; |
| 296 | + |
| 297 | + auto newPred = [&]() -> Pred { |
| 298 | + auto oldPred = op.getPredicate(); |
| 299 | + for (auto it : inv) { |
| 300 | + if (it.first == oldPred) |
| 301 | + return it.second; |
| 302 | + if (it.second == oldPred) |
| 303 | + return it.first; |
| 304 | + } |
| 305 | + |
| 306 | + llvm_unreachable("Unknown predicate"); |
| 307 | + }(); |
| 308 | + |
| 309 | + rewriter.replaceOpWithNewOp<mlir::arith::CmpIOp>(op, newPred, op.getRhs(), |
| 310 | + op.getLhs()); |
| 311 | + ; |
| 312 | + return mlir::success(); |
| 313 | + } |
| 314 | +}; |
| 315 | + |
| 316 | +struct ReshapeAlloca : public mlir::OpRewritePattern<mlir::memref::ReshapeOp> { |
| 317 | + using OpRewritePattern::OpRewritePattern; |
| 318 | + |
| 319 | + mlir::LogicalResult |
| 320 | + matchAndRewrite(mlir::memref::ReshapeOp op, |
| 321 | + mlir::PatternRewriter &rewriter) const override { |
| 322 | + auto shapeOp = op.shape().getDefiningOp<mlir::memref::AllocOp>(); |
| 323 | + if (!shapeOp) |
| 324 | + return mlir::failure(); |
| 325 | + |
| 326 | + for (auto user : shapeOp->getUsers()) |
| 327 | + if (!mlir::isa<mlir::memref::StoreOp, mlir::memref::ReshapeOp>(user)) |
| 328 | + return mlir::failure(); |
| 329 | + |
| 330 | + if (!shapeOp.dynamicSizes().empty() || !shapeOp.symbolOperands().empty()) |
| 331 | + return mlir::failure(); |
| 332 | + |
| 333 | + auto func = op->getParentOfType<mlir::FuncOp>(); |
| 334 | + if (!func) |
| 335 | + return mlir::failure(); |
| 336 | + |
| 337 | + if (shapeOp->getParentOp() != func) { |
| 338 | + rewriter.setInsertionPointToStart(&func.getBody().front()); |
| 339 | + } else { |
| 340 | + rewriter.setInsertionPoint(shapeOp); |
| 341 | + } |
| 342 | + |
| 343 | + auto type = shapeOp.getType().cast<mlir::MemRefType>(); |
| 344 | + auto alignment = shapeOp.alignmentAttr().cast<mlir::IntegerAttr>(); |
| 345 | + rewriter.replaceOpWithNewOp<mlir::memref::AllocaOp>(shapeOp, type, |
| 346 | + alignment); |
| 347 | + return mlir::success(); |
| 348 | + } |
| 349 | +}; |
272 | 350 | } // namespace
|
273 | 351 |
|
274 | 352 | void PlierUtilDialect::getCanonicalizationPatterns(
|
275 | 353 | mlir::RewritePatternSet &results) const {
|
276 | 354 | results.add<DimExpandShape<mlir::tensor::DimOp, mlir::tensor::ExpandShapeOp>,
|
277 | 355 | DimExpandShape<mlir::memref::DimOp, mlir::memref::ExpandShapeOp>,
|
278 |
| - DimInsertSlice, FillExtractSlice, SpirvInputCSE, GenGlobalId>( |
279 |
| - getContext()); |
| 356 | + DimInsertSlice, FillExtractSlice, SpirvInputCSE, GenGlobalId, |
| 357 | + InvertCmpi, ReshapeAlloca>(getContext()); |
280 | 358 | }
|
281 | 359 |
|
282 | 360 | OpaqueType OpaqueType::get(mlir::MLIRContext *context) {
|
@@ -330,35 +408,35 @@ EnforceShapeOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
330 | 408 | }
|
331 | 409 |
|
332 | 410 | namespace {
|
333 |
| -struct EnforceShapeDim : public mlir::OpRewritePattern<mlir::memref::DimOp> { |
334 |
| - using mlir::OpRewritePattern<mlir::memref::DimOp>::OpRewritePattern; |
| 411 | +template <typename DimOp> |
| 412 | +struct EnforceShapeDim : public mlir::OpRewritePattern<DimOp> { |
| 413 | + using mlir::OpRewritePattern<DimOp>::OpRewritePattern; |
335 | 414 |
|
336 | 415 | mlir::LogicalResult
|
337 |
| - matchAndRewrite(mlir::memref::DimOp op, |
338 |
| - mlir::PatternRewriter &rewriter) const override { |
339 |
| - auto enforce_op = mlir::dyn_cast_or_null<plier::EnforceShapeOp>( |
340 |
| - op.source().getDefiningOp()); |
341 |
| - if (!enforce_op) { |
| 416 | + matchAndRewrite(DimOp op, mlir::PatternRewriter &rewriter) const override { |
| 417 | + auto enforceOp = |
| 418 | + op.source().template getDefiningOp<plier::EnforceShapeOp>(); |
| 419 | + if (!enforceOp) |
342 | 420 | return mlir::failure();
|
343 |
| - } |
344 |
| - auto const_ind = plier::getConstVal<mlir::IntegerAttr>(op.index()); |
345 |
| - if (!const_ind) { |
| 421 | + |
| 422 | + auto constInd = mlir::getConstantIntValue(op.index()); |
| 423 | + if (!constInd) |
346 | 424 | return mlir::failure();
|
347 |
| - } |
348 |
| - auto index = const_ind.getInt(); |
349 |
| - if (index < 0 || index >= static_cast<int64_t>(enforce_op.sizes().size())) { |
| 425 | + |
| 426 | + auto index = *constInd; |
| 427 | + if (index < 0 || index >= static_cast<int64_t>(enforceOp.sizes().size())) |
350 | 428 | return mlir::failure();
|
351 |
| - } |
352 | 429 |
|
353 |
| - rewriter.replaceOp(op, enforce_op.sizes()[static_cast<unsigned>(index)]); |
| 430 | + rewriter.replaceOp(op, enforceOp.sizes()[static_cast<unsigned>(index)]); |
354 | 431 | return mlir::success();
|
355 | 432 | }
|
356 | 433 | };
|
357 | 434 | } // namespace
|
358 | 435 |
|
359 | 436 | void EnforceShapeOp::getCanonicalizationPatterns(
|
360 | 437 | ::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context) {
|
361 |
| - results.insert<EnforceShapeDim>(context); |
| 438 | + results.insert<EnforceShapeDim<mlir::tensor::DimOp>, |
| 439 | + EnforceShapeDim<mlir::memref::DimOp>>(context); |
362 | 440 | }
|
363 | 441 |
|
364 | 442 | mlir::LogicalResult
|
|
0 commit comments