|
6 | 6 |
|
7 | 7 | #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h" |
8 | 8 | #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h" |
| 9 | +#include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 10 | +#include "mlir/Dialect/Utils/IndexingUtils.h" |
9 | 11 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
10 | 12 |
|
11 | 13 | using namespace mlir; |
@@ -38,9 +40,9 @@ OpFoldResult ToSIMTOp::fold(FoldAdaptor) { |
38 | 40 | return {}; |
39 | 41 | } |
40 | 42 |
|
41 | | -// |
| 43 | +//===----------------------------------------------------------------------===// |
42 | 44 | // TransferGatherOp |
43 | | -// |
| 45 | +//===----------------------------------------------------------------------===// |
44 | 46 |
|
45 | 47 | Speculation::Speculatability TransferGatherOp::getSpeculatability() { |
46 | 48 | if (isa<RankedTensorType>(getSource().getType())) { |
@@ -438,13 +440,221 @@ ParseResult TransferGatherOp::parse(OpAsmParser &parser, |
438 | 440 | return parser.addTypeToList(vectorType, result.types); |
439 | 441 | } |
440 | 442 |
|
441 | | -void TransferGatherOp::getCanonicalizationPatterns(RewritePatternSet &results, |
442 | | - MLIRContext *ctx) {} |
| 443 | +static int64_t getVectorRank(Type type) { |
| 444 | + if (auto vecType = dyn_cast<VectorType>(type)) { |
| 445 | + return vecType.getRank(); |
| 446 | + } |
| 447 | + return 0; |
| 448 | +} |
| 449 | + |
| 450 | +struct IndexVecFoldResult { |
| 451 | + Value indexVec; |
| 452 | + AffineMap indexMap; |
| 453 | + bool changed; |
| 454 | +}; |
| 455 | + |
| 456 | +static Value foldTransferGatherIndexVecs( |
| 457 | + TransferGatherOp gatherOp, |
| 458 | + function_ref<IndexVecFoldResult(Value, AffineMap, int64_t)> |
| 459 | + indexVecFolder) { |
| 460 | + bool changed = false; |
| 461 | + SmallVector<Value> newIndexVecs; |
| 462 | + SmallVector<AffineMap> newIndexedMaps; |
| 463 | + SmallVector<bool> indexed(gatherOp.getIndexed().getAsValueRange<BoolAttr>()); |
| 464 | + int64_t currIndexVec = 0; |
| 465 | + for (auto i : llvm::seq<int64_t>(gatherOp.getIndices().size())) { |
| 466 | + if (!indexed[i]) { |
| 467 | + continue; |
| 468 | + } |
| 469 | + Value operand = gatherOp.getIndexVecs()[currIndexVec]; |
| 470 | + AffineMap map = gatherOp.getIndexedMapsArray()[currIndexVec]; |
| 471 | + ++currIndexVec; |
| 472 | + |
| 473 | + auto [indexVec, indexMap, vecChanged] = indexVecFolder(operand, map, i); |
| 474 | + changed |= vecChanged; |
| 475 | + |
| 476 | + if (indexVec) { |
| 477 | + newIndexVecs.push_back(indexVec); |
| 478 | + newIndexedMaps.push_back(indexMap); |
| 479 | + indexed[i] = true; |
| 480 | + } else { |
| 481 | + indexed[i] = false; |
| 482 | + } |
| 483 | + } |
| 484 | + |
| 485 | + if (!changed) { |
| 486 | + return Value(); |
| 487 | + } |
| 488 | + |
| 489 | + OpBuilder b(gatherOp); |
| 490 | + |
| 491 | + SmallVector<Value> operands; |
| 492 | + SmallVector<int32_t> operandSegmentSizes; |
| 493 | + |
| 494 | + // Source. |
| 495 | + operands.push_back(gatherOp.getSource()); |
| 496 | + operandSegmentSizes.push_back(1); |
| 497 | + // Indices. |
| 498 | + SmallVector<Value> indices = gatherOp.getIndices(); |
| 499 | + operands.append(indices); |
| 500 | + operandSegmentSizes.push_back(indices.size()); |
| 501 | + // IndexVecs. |
| 502 | + operands.append(newIndexVecs); |
| 503 | + operandSegmentSizes.push_back(newIndexVecs.size()); |
| 504 | + // Padding. |
| 505 | + operands.push_back(gatherOp.getPadding()); |
| 506 | + operandSegmentSizes.push_back(1); |
| 507 | + // Mask. |
| 508 | + if (gatherOp.getMask()) { |
| 509 | + operands.push_back(gatherOp.getMask()); |
| 510 | + operandSegmentSizes.push_back(1); |
| 511 | + } else { |
| 512 | + operandSegmentSizes.push_back(0); |
| 513 | + } |
| 514 | + |
| 515 | + gatherOp.setIndexedMapsAttr(b.getAffineMapArrayAttr(newIndexedMaps)); |
| 516 | + gatherOp->setOperands(operands); |
| 517 | + gatherOp.setIndexedAttr(b.getBoolArrayAttr(indexed)); |
| 518 | + gatherOp.getProperties().setOperandSegmentSizes(operandSegmentSizes); |
| 519 | + |
| 520 | + return gatherOp.getResult(); |
| 521 | +} |
| 522 | + |
| 523 | +static Value foldTransferGatherFromBroadcast(TransferGatherOp gatherOp) { |
| 524 | + return foldTransferGatherIndexVecs( |
| 525 | + gatherOp, |
| 526 | + [](Value operand, AffineMap map, int64_t) -> IndexVecFoldResult { |
| 527 | + auto broadcast = operand.getDefiningOp<vector::BroadcastOp>(); |
| 528 | + if (!broadcast) { |
| 529 | + return {operand, map, false}; |
| 530 | + } |
| 531 | + |
| 532 | + int64_t sourceRank = getVectorRank(broadcast.getSourceType()); |
| 533 | + int64_t operandRank = getVectorRank(broadcast.getResultVectorType()); |
| 534 | + AffineMap newMap = |
| 535 | + map.getSliceMap(operandRank - sourceRank, sourceRank); |
| 536 | + return {broadcast.getSource(), newMap, true}; |
| 537 | + }); |
| 538 | +} |
| 539 | + |
| 540 | +static Value foldTransferGatherFromTranspose(TransferGatherOp gatherOp) { |
| 541 | + return foldTransferGatherIndexVecs( |
| 542 | + gatherOp, |
| 543 | + [](Value operand, AffineMap map, int64_t) -> IndexVecFoldResult { |
| 544 | + auto transpose = operand.getDefiningOp<vector::TransposeOp>(); |
| 545 | + if (!transpose) { |
| 546 | + return {operand, map, false}; |
| 547 | + } |
| 548 | + |
| 549 | + AffineMap newMap = |
| 550 | + AffineMap::getPermutationMap( |
| 551 | + invertPermutationVector(transpose.getPermutation()), |
| 552 | + transpose.getContext()) |
| 553 | + .compose(map); |
| 554 | + return {transpose.getVector(), newMap, true}; |
| 555 | + }); |
| 556 | +} |
| 557 | + |
| 558 | +static Value foldTransferGatherFromStep(TransferGatherOp gatherOp) { |
| 559 | + return foldTransferGatherIndexVecs( |
| 560 | + gatherOp, |
| 561 | + [](Value operand, AffineMap map, int64_t index) -> IndexVecFoldResult { |
| 562 | + auto step = operand.getDefiningOp<vector::StepOp>(); |
| 563 | + if (!step) { |
| 564 | + return {operand, map, false}; |
| 565 | + } |
| 566 | + |
| 567 | + assert(map.getNumResults() == 1); |
| 568 | + int64_t resultDim = cast<AffineDimExpr>(map.getResult(0)).getPosition(); |
| 569 | + |
| 570 | + // If the map is indexing along the memory dimension, and the vector is |
| 571 | + // contigious, this is a contigious load on this dimension. |
| 572 | + if (resultDim == index) { |
| 573 | + return {Value(), AffineMap(), true}; |
| 574 | + } |
| 575 | + |
| 576 | + return {operand, map, false}; |
| 577 | + }); |
| 578 | +} |
443 | 579 |
|
444 | 580 | OpFoldResult TransferGatherOp::fold(FoldAdaptor adaptor) { |
| 581 | + if (auto res = foldTransferGatherFromBroadcast(*this)) { |
| 582 | + return res; |
| 583 | + } |
| 584 | + if (auto res = foldTransferGatherFromTranspose(*this)) { |
| 585 | + return res; |
| 586 | + } |
| 587 | + if (auto res = foldTransferGatherFromStep(*this)) { |
| 588 | + return res; |
| 589 | + } |
445 | 590 | return OpFoldResult(); |
446 | 591 | } |
447 | 592 |
|
| 593 | +struct FoldSingleElementIndexVec final : OpRewritePattern<TransferGatherOp> { |
| 594 | + using OpRewritePattern::OpRewritePattern; |
| 595 | + |
| 596 | + LogicalResult matchAndRewrite(TransferGatherOp xferOp, |
| 597 | + PatternRewriter &rewriter) const override { |
| 598 | + |
| 599 | + auto indexVecFolder = [&](Value indexVec, AffineMap map, |
| 600 | + int64_t index) -> IndexVecFoldResult { |
| 601 | + auto vectorTy = cast<VectorType>(indexVec.getType()); |
| 602 | + if (vectorTy.getNumElements() != 1) { |
| 603 | + return {indexVec, map, false}; |
| 604 | + } |
| 605 | + |
| 606 | + // Extract the scalar and add it to the |
| 607 | + // corressponding base. |
| 608 | + OpOperand &base = xferOp.getIndicesMutable()[index]; |
| 609 | + Value extracted = rewriter.create<vector::ExtractOp>( |
| 610 | + xferOp.getLoc(), indexVec, |
| 611 | + SmallVector<int64_t>(vectorTy.getRank(), 0)); |
| 612 | + AffineExpr d0, d1; |
| 613 | + bindDims(xferOp.getContext(), d0, d1); |
| 614 | + Value newIndex = affine::makeComposedAffineApply( |
| 615 | + rewriter, xferOp.getLoc(), d0 + d1, |
| 616 | + ArrayRef<OpFoldResult>{base.get(), extracted}) |
| 617 | + .getResult(); |
| 618 | + base.set(newIndex); |
| 619 | + |
| 620 | + return {Value(), AffineMap(), true}; |
| 621 | + }; |
| 622 | + |
| 623 | + Value newVal = foldTransferGatherIndexVecs(xferOp, indexVecFolder); |
| 624 | + |
| 625 | + if (!newVal) { |
| 626 | + return failure(); |
| 627 | + } |
| 628 | + |
| 629 | + return success(); |
| 630 | + } |
| 631 | +}; |
| 632 | + |
| 633 | +struct FoldContigousGatherToTransferRead final |
| 634 | + : OpRewritePattern<TransferGatherOp> { |
| 635 | + using OpRewritePattern::OpRewritePattern; |
| 636 | + |
| 637 | + LogicalResult matchAndRewrite(TransferGatherOp xferOp, |
| 638 | + PatternRewriter &rewriter) const override { |
| 639 | + if (!xferOp.getIndexVecs().empty()) { |
| 640 | + return failure(); |
| 641 | + } |
| 642 | + |
| 643 | + // Canonicalize to vector.transfer_read. |
| 644 | + rewriter.replaceOpWithNewOp<vector::TransferReadOp>( |
| 645 | + xferOp, xferOp.getVectorType(), xferOp.getSource(), xferOp.getIndices(), |
| 646 | + xferOp.getPermutationMap(), xferOp.getPadding(), xferOp.getMask(), |
| 647 | + xferOp.getInBounds()); |
| 648 | + return success(); |
| 649 | + }; |
| 650 | +}; |
| 651 | + |
| 652 | +void TransferGatherOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 653 | + MLIRContext *ctx) { |
| 654 | + results.add<FoldSingleElementIndexVec, FoldContigousGatherToTransferRead>( |
| 655 | + ctx); |
| 656 | +} |
| 657 | + |
448 | 658 | // clang-format off |
449 | 659 | #define GET_OP_CLASSES |
450 | 660 | #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp.inc" // IWYU pragma: keep |
|
0 commit comments