@@ -820,8 +820,8 @@ namespace {
 // stmatrix. These restrictions are retained from legacy code, and we could
 // relax some of them in the future.
 bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-                    ArrayRef<unsigned> paddedRepShape,
-                    ArrayRef<unsigned> order) {
+                    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
+                    int swizzleByteSize) {
   auto mmaLayout =
       mlir::dyn_cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
   if (!mmaLayout || !mmaLayout.isHopper())
@@ -840,17 +840,87 @@ bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
     return false;
   if (paddedRepShape[1] % 8 != 0)
     return false;
+  if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 &&
+      swizzleByteSize != 128)
+    return false;
   return true;
 }
 
-} // anonymous namespace
+std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
+    MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
+    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
+    int swizzleByteSize) {
+  StringAttr kReg = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+  StringAttr kCol = S("dim1");
+  StringAttr kRow = S("dim0");
+  StringAttr kOffset = S("offset");
+
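+  // Swizzling parameters. In every case perPhase * maxPhase == 8 and
+  // perPhase == 128 / swizzleByteSize, i.e. rows falling in the same
+  // 128-byte span share a phase, and the phase rotates the columns.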
+  int perPhase;
+  int maxPhase;
+  if (swizzleByteSize == 32) {
+    perPhase = 4;
+    maxPhase = 2;
+  } else if (swizzleByteSize == 64) {
+    perPhase = 2;
+    maxPhase = 4;
+  } else if (swizzleByteSize == 128) {
+    perPhase = 1;
+    maxPhase = 8;
+  } else {
+    llvm::errs() << "Illegal swizzleByteSize: " << swizzleByteSize << "\n";
+    llvm::report_fatal_error("Illegal swizzleByteSize");
+  }
+
+  // stmatrix only supports 16-bit elements, and each vector has 8 elements
+  int elemBitWidth = 16;
+  int vecSize = 8;
+  int numRows = 16;
+  int numCols = 8 * swizzleByteSize / elemBitWidth;
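+  // numCols is the width of one swizzle chunk in 16-bit elements: a chunk
+  // spans numCols * 2 == swizzleByteSize bytes per row.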
+
+  // Construct a single stmatrix.x4 (16x16) tile
+  std::vector<std::vector<int>> basesReg = {{1, 0}, {2, 0}, {4, 0}};
+  std::vector<std::vector<int>> basesLane;
+  for (int logRow = 0; logRow < llvm::Log2_32(numRows); logRow++) {
+    int row = 1 << logRow;
+    basesLane.push_back({vecSize * ((row / perPhase) % maxPhase), row});
+  }
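+  // The fifth lane bit does not advance the row; it shifts by 8 columns,
+  // selecting the right half of the 16x16 stmatrix.x4 tile.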
+  basesLane.push_back({8, 0});
+
+  // Expand the tile's register dimension to cover one full swizzle "chunk",
+  // i.e. numCols columns (swizzleByteSize bytes).
+  for (int logChunk = 0; logChunk < llvm::Log2_32(numCols / 16); logChunk++) {
+    int chunk = 1 << logChunk;
+    basesReg.push_back({16 * chunk, 0});
+  }
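+  // Together with the lane bases, the register bases now address all numCols
+  // columns of the chunk.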
+
+  // Construct the layout for a single chunk
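+  // (LinearLayouts are linear over GF(2): each basis vector is the
+  // (dim1, dim0) image of one input bit, and other inputs map to the XOR of
+  // the bases of their set bits.)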
+  LinearLayout layout =
+      LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow});
 
-std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
+  // Expand the `warp` dimension according to warpsPerCTA.
+  auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
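+  // identityND distributes warps along kRow first, then kCol (order {0, 1});
+  // transposeOuts matches the factor's output-dimension order to `layout`'s.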
+  layout *=
+      identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
+          .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
+
+  // Expand the `register` dimension so that the size of the column dimension
+  // matches `n`, the N of the MMA instruction shape.
+  int n = mma.getInstrShape()[1];
+  int numWarpRows = layout.getOutDimSize(kRow);
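+  // Flattening to a single `offset` dimension and multiplying by
+  // identity1D(n / numCols) appends the extra chunks contiguously; the final
+  // reshape reinterprets them as extra columns.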
+  layout = (layout.reshapeOuts({{kOffset, layout.getTotalOutDimSize()}}) *
+            LinearLayout::identity1D(n / numCols, kReg, kOffset))
+               .reshapeOuts({{kCol, n}, {kRow, numWarpRows}});
+
+  auto ret =
+      combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
+  return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames()))
+      .reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}});
+}
+
+std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
     MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
     ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order) {
-  if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order))
-    return std::nullopt;
-
   StringAttr kReg = S("register");
   StringAttr kLane = S("lane");
   StringAttr kWarp = S("warp");
@@ -880,4 +950,23 @@ std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
       {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
+} // anonymous namespace
+
+std::optional<LinearLayout>
+chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
+                     ArrayRef<unsigned> repShape,
+                     ArrayRef<unsigned> paddedRepShape,
+                     ArrayRef<unsigned> order, int swizzleByteSize) {
+  if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order,
+                      swizzleByteSize))
+    return std::nullopt;
+
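+  // A swizzleByteSize of 0 selects the unswizzled path (no leading offset
+  // in the shared-memory layout); 32, 64, and 128 use the swizzled path.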
+  if (swizzleByteSize == 0)
+    return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape,
+                                               paddedRepShape, order);
+  else
+    return chooseStMatrixLayoutLeadingOffset(
+        ctx, tensorTy, repShape, paddedRepShape, order, swizzleByteSize);
+}
+
883972} // namespace mlir::triton::gpu