|
36 | 36 | #include "llvm/ADT/STLExtras.h" |
37 | 37 | #include "llvm/ADT/SmallVector.h" |
38 | 38 | #include "llvm/ADT/TypeSwitch.h" |
| 39 | +#include "llvm/ADT/bit.h" |
39 | 40 | #include "llvm/Support/Casting.h" |
40 | 41 | #include "llvm/Support/LogicalResult.h" |
41 | 42 | #include "llvm/Support/raw_ostream.h" |
@@ -781,30 +782,27 @@ namespace { |
781 | 782 | /// | 2x32x16 | [1, 16] | 2x32x1 | |
782 | 783 | FailureOr<VectorType> getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout, |
783 | 784 | VectorType originalType) { |
784 | | - llvm::SmallVector<int64_t, 2> distributedShape; |
785 | 785 | if (!layout) |
786 | 786 | return failure(); |
787 | 787 |
|
788 | | - auto laneLayout = layout.getLaneLayout(); |
789 | | - assert((originalType.getRank() == 2 || originalType.getRank() == 3) && |
790 | | - "expecting 2D or 3D shape for the original vector type"); |
791 | | - assert(laneLayout.size() == 2 && "expecting 2D shape for the wi layout"); |
792 | | - // Original type can be 2D or 3D (array_length > 1), the last two dims are the |
793 | | - // block shape. |
794 | | - auto blockShape = originalType.getShape().take_back(2); |
795 | | - // Check if the block vector shape can be distributed evenly. |
796 | | - if (blockShape[0] % laneLayout[0] != 0 || blockShape[1] % laneLayout[1] != 0) |
797 | | - return failure(); |
798 | | - |
799 | | - if (originalType.getRank() == 3) { |
800 | | - distributedShape.push_back(originalType.getShape()[0]); |
801 | | - } |
802 | | - for (unsigned i = 0; i < 2; ++i) { |
803 | | - distributedShape.push_back(blockShape[i] / laneLayout[i]); |
| 788 | + auto laneLayout = layout.getLaneLayout().asArrayRef(); |
| 789 | + assert(originalType.getShape().size() >= laneLayout.size() && |
| 790 | + "Rank of the original vector type should be greater or equal to the " |
| 791 | + "size of the lane layout to distribute the vector type."); |
| 792 | + SmallVector<int64_t> distributedShape(originalType.getShape()); |
| 793 | + /// Only distribute the last `laneLayout.size()` dimensions. The remaining |
| 794 | + /// dimensions are not distributed. |
| 795 | + unsigned distributionStart = originalType.getRank() - laneLayout.size(); |
| 796 | + for (auto [i, dim] : llvm::enumerate(originalType.getShape())) { |
| 797 | + if (i < distributionStart) { |
| 798 | + continue; |
| 799 | + } |
| 800 | + /// Check if the dimension can be distributed evenly. |
| 801 | + if (dim % laneLayout[i - distributionStart] != 0) |
| 802 | + return failure(); |
| 803 | + distributedShape[i] = dim / laneLayout[i - distributionStart]; |
804 | 804 | } |
805 | | - auto newVectorType = |
806 | | - VectorType::get(distributedShape, originalType.getElementType()); |
807 | | - return newVectorType; |
| 805 | + return VectorType::get(distributedShape, originalType.getElementType()); |
808 | 806 | } |
809 | 807 |
|
810 | 808 | static VectorType getDistributedVectorType(xegpu::LayoutAttr layout, |
@@ -1028,15 +1026,14 @@ struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern { |
// NOTE(review): this hunk is the interior of SubgroupOpStoreNd::matchAndRewrite;
// the method's head is outside this diff view, so only review annotations are
// added here — code lines are unchanged.
1028 | 1026 | return rewriter.notifyMatchFailure( |
1029 | 1027 | storeOp, "the source tensor descriptor lacks sg_map attribute"); |
1030 | 1028 |
|
1031 | | - if (storeOp.getTensorDescType().getShape().size() != 2) |
1032 | | - return rewriter.notifyMatchFailure(storeOp, "unsupported shape"); |
1033 | | - |
1034 | | - auto distriburtedTypeByWarpOp = |
| 1029 | + auto distributedTypeByWarpOpOrFailure = |
1035 | 1030 | getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType()); |
1036 | | - if (failed(distriburtedTypeByWarpOp)) |
| 1031 | + if (failed(distributedTypeByWarpOpOrFailure)) |
1037 | 1032 | return rewriter.notifyMatchFailure(storeOp, |
1038 | 1033 | "Failed to distribute the type"); |
1039 | | - VectorType distributedTypeByWarpOp = distriburtedTypeByWarpOp.value(); |
| 1034 | + VectorType distributedTypeByWarpOp = |
| 1035 | + distributedTypeByWarpOpOrFailure.value(); |
// NOTE(review): leftover debug print added by this diff — llvm::errs() dumps
// should not land in the pattern; remove before committing.
| 1036 | + llvm::errs() << "distributed type: " << distributedTypeByWarpOp << "\n"; |
1040 | 1037 |
|
1041 | 1038 | SmallVector<size_t> newRetIndices; |
1042 | 1039 | gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
@@ -1066,7 +1063,8 @@ struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern |
1066 | 1063 | newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[1])); |
1067 | 1064 |
|
// NOTE(review): the diff stops forwarding storeOp->getAttrs() to the new
// StoreNdOp, and the added setDialectAttrs(getDialectAttrs()) call copies
// storeOp's own dialect attrs back onto itself — a no-op on an op erased on
// the next line. Presumably the dialect attrs were meant to be set on the
// newly created store op — TODO confirm and fix before landing.
1068 | 1065 | rewriter.create<xegpu::StoreNdOp>(newWarpOp.getLoc(), TypeRange{}, |
1069 | | - newStoreOperands, storeOp->getAttrs()); |
| 1066 | + newStoreOperands); |
| 1067 | + storeOp->setDialectAttrs(storeOp->getDialectAttrs()); |
1070 | 1068 | rewriter.eraseOp(storeOp); |
1071 | 1069 | return success(); |
1072 | 1070 | } |
|
0 commit comments