@@ -961,10 +961,9 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
 }
 
 namespace {
-LinearLayout chooseStMatrixLayoutLeadingOffset(
-    MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
-    int swizzleByteSize) {
+LinearLayout chooseStMatrixLayoutLeadingOffset(MLIRContext *ctx,
+                                               RankedTensorType tensorTy,
+                                               int swizzleByteSize) {
   int perPhase;
   int maxPhase;
   if (swizzleByteSize == 32) {
@@ -1064,9 +1063,9 @@ LinearLayout chooseStMatrixLayoutLeadingOffset(
       {{S("offset"), layout.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
-LinearLayout chooseStMatrixLayoutNoLeadingOffset(
-    MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order) {
+LinearLayout chooseStMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
+                                                 Attribute encoding,
+                                                 ArrayRef<int64_t> shape) {
   StringAttr kReg = S("register");
   StringAttr kLane = S("lane");
   StringAttr kWarp = S("warp");
@@ -1081,17 +1080,16 @@ LinearLayout chooseStMatrixLayoutNoLeadingOffset(
       LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow});
 
   // Expand the `register` dimension so the size of columns matches `n`.
-  auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
+  auto mma = cast<NvidiaMmaEncodingAttr>(encoding);
   int n = mma.getInstrShape()[1];
   layout *=
       LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol);
 
   // Expand the `warp` dimension according to warpsPerCTA.
   layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
                 .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
-  auto ret =
-      combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
-  auto tensorShapePerCTA = getShapePerCTA(mma, tensorTy.getShape());
+  auto ret = combineCtaCgaWithShape(layout, mma.getCTALayout(), shape);
+  auto tensorShapePerCTA = getShapePerCTA(mma, shape);
   llvm::SmallDenseMap<StringAttr, int64_t> namedTensorShape;
   namedTensorShape[kRow] = tensorShapePerCTA[0];
   namedTensorShape[kCol] = tensorShapePerCTA[1];
@@ -1102,19 +1100,100 @@ LinearLayout chooseStMatrixLayoutNoLeadingOffset(
           {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
+LinearLayout chooseLdMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
+                                                 SharedEncodingAttr shared,
+                                                 DotOperandEncodingAttr dot,
+                                                 ArrayRef<int64_t> shape) {
+  auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
+  auto rank = shape.size();
+  auto opIdx = dot.getOpIdx();
+  int kDim = opIdx == 0 ? rank - 1 : rank - 2;
+
+  StringAttr kReg = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+  StringAttr kBlock = S("block");
+  StringAttr kInner = opIdx == 0 ? S("dim1") : S("dim0");
+  StringAttr kOuter = opIdx == 0 ? S("dim0") : S("dim1");
+
+  std::vector<std::vector<int>> basesReg = {{0, 1}, {0, 2}, {0, 4}};
+  std::vector<std::vector<int>> basesLane;
+  auto numRowsPerTile = 16;
+  auto numColsPerTile = 16;
+  int vecSize = shared.getVec();
+  int perPhase = shared.getPerPhase();
+  int maxPhase = shared.getMaxPhase();
+  auto warpsPerCTA = mma.getWarpsPerCTA();
+  // Construct a 16x16 tile consisting of 4 sub-tiles to use ldmatrix
+  // efficiently. opIdx=0 and opIdx=1 are handled differently.
+  if (opIdx == 0) {
+    // The matrix elements of thread 0 are distributed in the following pattern:
+    //
+    //           col0       col8
+    //   row0  reg[0-1]   reg[4-5]
+    //   row8  reg[2-3]   reg[6-7]
+    for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile); logRow++) {
+      int row = 1 << logRow;
+      basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
+    }
+    basesLane.push_back({0, numColsPerTile / 2});
+    // Expand the `register` dimension so the size of columns matches `K`.
+    for (int logCol = 0; logCol < llvm::Log2_32(shape[kDim] / numColsPerTile);
+         logCol++) {
+      int col = 1 << logCol;
+      basesReg.push_back({0, numColsPerTile * col});
+    }
+  } else {
+    // The matrix elements of thread 0 are distributed in the following pattern:
+    //
+    //           col0       col8      col16      col24
+    //   row0  reg[0-1]   reg[2-3]   reg[4-5]   reg[6-7]
+    // 8x8
+    for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile / 2); logRow++) {
+      int row = 1 << logRow;
+      basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
+    }
+    // 8x16
+    basesLane.push_back({0, numColsPerTile / 2});
+    // 8x32
+    basesLane.push_back({0, numColsPerTile});
+    // Expand the `register` dimension so the size of columns matches `K`.
+    for (int logCol = 0;
+         logCol < llvm::Log2_32(shape[kDim] / (numColsPerTile * 2)); logCol++) {
+      int col = 1 << logCol;
+      basesReg.push_back({0, (numColsPerTile * 2) * col});
+    }
+  }
+  auto layout = LinearLayout(
+      {{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}}, {kOuter, kInner});
+  // Expand the `warp` dimension according to warpsPerCTA.
+  layout *= broadcastedDotOperandLayout(ctx, warpsPerCTA, mma.getWarpOrder(),
+                                        kDim, kWarp)
+                .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
+  auto ret = combineCtaCgaWithShape(layout, getCTALayout(dot), shape);
+  return ret.transposeOuts({kInner, kOuter})
+      .reshapeOuts(
+          {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
+}
+
 } // anonymous namespace
 
 LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
-                                  ArrayRef<unsigned> repShape,
-                                  ArrayRef<unsigned> paddedRepShape,
-                                  ArrayRef<unsigned> order,
                                   int swizzleByteSize) {
   if (swizzleByteSize == 0)
-    return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape,
-                                               paddedRepShape, order);
+    return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy.getEncoding(),
+                                               tensorTy.getShape());
   else
-    return chooseStMatrixLayoutLeadingOffset(
-        ctx, tensorTy, repShape, paddedRepShape, order, swizzleByteSize);
+    return chooseStMatrixLayoutLeadingOffset(ctx, tensorTy, swizzleByteSize);
+}
+
+LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
+                                  Attribute dotEnc, ArrayRef<int64_t> shape) {
+  auto shared = cast<SharedEncodingAttr>(sharedEnc);
+  auto dot = cast<DotOperandEncodingAttr>(dotEnc);
+  assert(!shared.getHasLeadingOffset() &&
+         "Ldmatrix does not support leading offset yet");
+  return chooseLdMatrixLayoutNoLeadingOffset(ctx, shared, dot, shape);
 }
 
 } // namespace mlir::triton::gpu
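
Note: the lane bases added in chooseLdMatrixLayoutNoLeadingOffset follow the swizzle formula {row, vecSize * ((row / perPhase) % maxPhase)}. As a minimal standalone sketch (not part of this change), the snippet below evaluates that formula for one assumed shared-memory encoding (vecSize = 8, perPhase = 1, maxPhase = 8, 16 rows per tile); the real values come from the SharedEncodingAttr at runtime.

#include <cstdio>

// Prints the ldmatrix lane bases produced by the swizzle formula above
// for an assumed configuration (vecSize = 8, perPhase = 1, maxPhase = 8).
int main() {
  const int vecSize = 8, perPhase = 1, maxPhase = 8, numRowsPerTile = 16;
  for (int logRow = 0; (1 << logRow) < numRowsPerTile; ++logRow) {
    int row = 1 << logRow;                              // rows 1, 2, 4, 8
    int col = vecSize * ((row / perPhase) % maxPhase);  // swizzled column
    std::printf("lane basis: {%d, %d}\n", row, col);
  }
  return 0;
}

With these assumed parameters the bases come out as {1, 8}, {2, 16}, {4, 32}, and {8, 0}: each doubling of the row index shifts the column by the vector size, wrapping around once the phase reaches maxPhase.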