@@ -10,6 +10,7 @@
 #include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
 #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Utils/Permutation.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
@@ -282,6 +283,17 @@ static VectorValue projectVector(RewriterBase &rewriter, Location loc,
   return cast<VectorValue>(sliced.getResult());
 }
 
+static VectorValue extractSliceAsVector(RewriterBase &rewriter, Location loc,
+                                        Value src, ArrayRef<int64_t> offsets) {
+  Value slice = rewriter.create<vector::ExtractOp>(loc, src, offsets);
+  // Promote the slice to a 0-d vector if it is a scalar.
+  if (!isa<VectorType>(slice.getType())) {
+    auto promotedType = VectorType::get({}, getElementTypeOrSelf(slice));
+    slice = rewriter.create<vector::BroadcastOp>(loc, promotedType, slice);
+  }
+  return cast<VectorValue>(slice);
+}
+
 namespace {
 
 /// Pattern to distribute `vector.transfer_read` ops with nested layouts.
@@ -476,16 +488,9 @@ struct DistributeTransferWrite final
       // dimensions are either unrolled or distributed such that this is a
       // contiguous slice.
       ArrayRef<int64_t> offsetArray(offsets);
-      Value slicedVector = rewriter.create<vector::ExtractOp>(
-          writeOp.getLoc(), distributedVector,
-          offsetArray.take_front(rank * 2));
-      // Promote the slicedVector to 0-d vector if it is a scalar.
-      if (!isa<VectorType>(slicedVector.getType())) {
-        auto promotedType =
-            VectorType::get({}, getElementTypeOrSelf(slicedVector));
-        slicedVector = rewriter.create<vector::BroadcastOp>(
-            writeOp.getLoc(), promotedType, slicedVector);
-      }
+      VectorValue slicedVector =
+          extractSliceAsVector(rewriter, writeOp.getLoc(), distributedVector,
+                               offsetArray.take_front(rank * 2));
 
       VectorValue slicedMask = nullptr;
       if (mask) {
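
For intuition about the `offsets` consumed here: the enclosing loop (shown explicitly in the `DistributeMapScatter` hunk below) walks a `StaticTileOffsetRange` over the distributed shape, visiting one element-vector tile at a time. Below is a rough standalone model of that walk, assuming row-major order with tile-sized steps and tile sizes that evenly divide the shape; `enumerateTileOffsets` is a hypothetical name, not an MLIR API.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Standalone model (hypothetical, not the MLIR class) of iterating tile
// offsets over a distributed shape: advance like an odometer, stepping each
// dimension by its tile size, innermost dimension varying fastest.
static std::vector<std::vector<int64_t>>
enumerateTileOffsets(const std::vector<int64_t> &shape,
                     const std::vector<int64_t> &tile) {
  std::vector<std::vector<int64_t>> result;
  std::vector<int64_t> offsets(shape.size(), 0);
  while (true) {
    result.push_back(offsets);
    int64_t dim = static_cast<int64_t>(shape.size()) - 1;
    for (; dim >= 0; --dim) {
      offsets[dim] += tile[dim];
      if (offsets[dim] < shape[dim])
        break;            // This dimension has tiles left; stop carrying.
      offsets[dim] = 0;   // Wrap and carry into the next-outer dimension.
    }
    if (dim < 0)
      return result;      // Carried past the outermost dimension: done.
  }
}

int main() {
  // Distributed shape {2, 1, 4} with element tile {1, 1, 4} yields two tiles,
  // at offsets {0,0,0} and {1,0,0}; each is a contiguous 4-element slice.
  for (const auto &o : enumerateTileOffsets({2, 1, 4}, {1, 1, 4}))
    std::printf("{%lld, %lld, %lld}\n", (long long)o[0], (long long)o[1],
                (long long)o[2]);
}
```
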
@@ -676,6 +681,104 @@ struct DistributeTransferGather final
   int64_t subgroupSize;
 };
 
+/// Pattern to distribute `iree_linalg_ext.map_scatter` ops with nested layouts.
+/// Only the input is distributed, since the output is never a vector. The
+/// distribution of the input is similar to that of a vector.transfer_write.
+struct DistributeMapScatter final
+    : OpDistributionPattern<IREE::LinalgExt::MapScatterOp> {
+  using OpDistributionPattern::OpDistributionPattern;
+
+  DistributeMapScatter(MLIRContext *context, Value threadId,
+                       int64_t subgroupSize)
+      : OpDistributionPattern(context), threadId(threadId),
+        subgroupSize(subgroupSize) {}
+
+  LogicalResult matchAndRewrite(IREE::LinalgExt::MapScatterOp mapScatterOp,
+                                DistributionSignature &signature,
+                                PatternRewriter &rewriter) const override {
+    auto input = dyn_cast<VectorValue>(mapScatterOp.getInput());
+    if (!input) {
+      return rewriter.notifyMatchFailure(mapScatterOp, "input is not a vector");
+    }
+    NestedLayoutAttr vectorLayout =
+        dyn_cast<NestedLayoutAttr>(signature[input]);
+    if (!vectorLayout) {
+      return rewriter.notifyMatchFailure(mapScatterOp,
+                                         "non-nested map_scatter layout");
+    }
+    if (!isa<MemRefType>(mapScatterOp.getOutput().getType())) {
+      return rewriter.notifyMatchFailure(mapScatterOp,
+                                         "distribution expects memrefs");
+    }
+    SmallVector<Value> warpIndices, threadIndices;
+    if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
+                                            vectorLayout, warpIndices,
+                                            threadIndices))) {
+      return rewriter.notifyMatchFailure(
+          mapScatterOp, "warp or thread tiles have overlapping strides");
+    }
+
+    Value distributedVector = getDistributed(rewriter, input, vectorLayout);
+
+    Location loc = mapScatterOp.getLoc();
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
+    SmallVector<int64_t> tileShape = getElementVectorTileShape(vectorLayout);
+    for (auto [idx, offsets] :
+         llvm::enumerate(StaticTileOffsetRange(distShape, tileShape))) {
+      // Extract the "element vector" from the innermost dimensions. All outer
+      // dimensions are either unrolled or distributed such that this is a
+      // contiguous slice.
+      ArrayRef<int64_t> offsetArray(offsets);
+      VectorValue distributedInput = extractSliceAsVector(
+          rewriter, loc, distributedVector,
+          offsetArray.take_front(vectorLayout.getRank() * 2));
+
+      // Clone the map_scatter op with the "element vector" as the input, and
+      // adjust the transformation region to account for the distributed
+      // offsets.
+      AffineMap permutationMap =
+          rewriter.getMultiDimIdentityMap(input.getType().getRank());
+      SmallVector<Value> indices(input.getType().getRank(), zero);
+      SmallVector<Value> distributedOffsets =
+          getTransferIndicesFromNestedLayout(rewriter, indices, offsets,
+                                             vectorLayout, permutationMap,
+                                             warpIndices, threadIndices);
+      IREE::LinalgExt::MapScatterOp distributedMapScatter =
+          clone(rewriter, mapScatterOp, mapScatterOp.getResultTypes(),
+                {distributedInput, mapScatterOp.getOutput()});
+      int64_t sliceRank = distributedInput.getType().getRank();
+      int64_t rankDiff = input.getType().getRank() - sliceRank;
+      // Add the distributed offsets in the map_scatter transformation body.
+      auto transformationBuilder = [&](ArrayRef<BlockArgument> newIndices) {
+        SmallVector<Value> replacementIndices(distributedOffsets);
+        for (auto [i, replacementIdx] : llvm::enumerate(replacementIndices)) {
+          // Rank-reduced dimensions can be directly replaced by the distributed
+          // index, since their size is 1 in the new map_scatter input.
+          if (i < rankDiff) {
+            continue;
+          }
+          // Otherwise, the dimension is a contiguous element dimension, so
+          // the mapping is achieved by adding the corresponding block argument
+          // to the sliced index.
+          BlockArgument newTransformationIdx = newIndices[i - rankDiff];
+          replacementIdx = rewriter.create<arith::AddIOp>(
+              loc, newTransformationIdx, replacementIdx);
+        }
+        return replacementIndices;
+      };
+      distributedMapScatter.insertTransformationAtStart(
+          rewriter, transformationBuilder, sliceRank);
+    }
+
+    rewriter.eraseOp(mapScatterOp);
+    return success();
+  }
+
+  Value threadId;
+  int64_t subgroupSize;
+};
+
 static VectorValue broadcastToShape(RewriterBase &rewriter, Value source,
                                     ArrayRef<int64_t> shape,
                                     ArrayRef<bool> broadcastedDims) {
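
The index rewrite installed by `transformationBuilder` above reduces to simple arithmetic: the leading `rankDiff` dimensions (rank-reduced to size 1 in the sliced input) keep the distributed offset as-is, and each trailing element dimension adds the intra-slice index to its offset. Here is a minimal sketch with plain integers standing in for the SSA values; `rewriteIndices` is hypothetical, and the real code emits `arith.addi` ops instead.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Model of the index rewrite DistributeMapScatter prepends to the
// transformation region, with plain integers standing in for SSA values.
std::vector<int64_t>
rewriteIndices(const std::vector<int64_t> &distributedOffsets,
               const std::vector<int64_t> &sliceIndices) {
  assert(sliceIndices.size() <= distributedOffsets.size());
  // Start from the distributed offsets; rank-reduced leading dimensions keep
  // them unchanged because their extent in the sliced input is 1.
  std::vector<int64_t> result = distributedOffsets;
  size_t rankDiff = distributedOffsets.size() - sliceIndices.size();
  // Trailing element dimensions add the intra-slice index to the offset.
  for (size_t i = rankDiff; i < result.size(); ++i)
    result[i] += sliceIndices[i - rankDiff];
  return result;
}
```

For example, with distributed offsets {4, 8, 0} and a rank-1 slice index {3}, the rewritten indices are {4, 8, 3}.
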
@@ -2030,8 +2133,8 @@ void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
                                                    int64_t subgroupSize,
                                                    int64_t maxBitsPerShuffle) {
   patterns.add<DistributeTransferRead, DistributeTransferWrite,
-               DistributeTransferGather>(patterns.getContext(), threadId,
-                                         subgroupSize);
+               DistributeTransferGather, DistributeMapScatter>(
+      patterns.getContext(), threadId, subgroupSize);
   patterns.add<DistributeBroadcast, DistributeTranspose>(patterns.getContext());
   patterns.add<DistributeMultiReduction>(patterns.getContext(), subgroupSize,
                                          maxBitsPerShuffle);
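
`populateWarpAndThreadIndices`, called near the top of the new pattern, recovers per-dimension subgroup and thread indices from the flat thread ID. A plausible model of the underlying math, assuming plain row-major delinearization over the layout's per-dimension tile sizes (an assumption; the real helper builds index ops and, as the match-failure message above notes, can fail when warp or thread tile strides overlap):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical model: row-major delinearization of a flat ID over a basis of
// per-dimension tile sizes, with the innermost dimension varying fastest.
std::vector<int64_t> delinearize(int64_t linearId,
                                 const std::vector<int64_t> &basis) {
  std::vector<int64_t> indices(basis.size());
  for (int64_t i = static_cast<int64_t>(basis.size()) - 1; i >= 0; --i) {
    indices[i] = linearId % basis[i]; // Index within this dimension.
    linearId /= basis[i];             // Remaining ID for outer dimensions.
  }
  return indices;
}
```
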