@@ -1097,80 +1097,80 @@ LinearLayout chooseStMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
           {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
-LinearLayout chooseLdMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
-                                                 SharedEncodingAttr shared,
-                                                 DotOperandEncodingAttr dot,
-                                                 ArrayRef<int64_t> shape) {
+LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
+                                     ArrayRef<int64_t> shape, bool needTrans,
+                                     int32_t elemBitWidth) {
+  auto ctx = dot.getContext();
   auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
   auto rank = shape.size();
   auto opIdx = dot.getOpIdx();
-  int kDim = opIdx == 0 ? rank - 1 : rank - 2;
+  int kDim = (opIdx == 0) ? rank - 1 : rank - 2;
 
   StringAttr kReg = S("register");
   StringAttr kLane = S("lane");
   StringAttr kWarp = S("warp");
   StringAttr kBlock = S("block");
-  StringAttr kInner = opIdx == 0 ? S("dim1") : S("dim0");
-  StringAttr kOuter = opIdx == 0 ? S("dim0") : S("dim1");
-
-  std::vector<std::vector<int>> basesReg = {{0, 1}, {0, 2}, {0, 4}};
-  std::vector<std::vector<int>> basesLane;
-  auto numRowsPerTile = 16;
-  auto numColsPerTile = 16;
-  int vecSize = shared.getVec();
-  int perPhase = shared.getPerPhase();
-  int maxPhase = shared.getMaxPhase();
-  auto warpsPerCTA = mma.getWarpsPerCTA();
-  // Construct a 16x16 tile consisting of 4 sub-tiles to use ldmatrix
+  StringAttr kInner = opIdx == 0 ? (needTrans ? S("dim0") : S("dim1"))
+                                 : (needTrans ? S("dim1") : S("dim0"));
+  StringAttr kOuter = opIdx == 0 ? (needTrans ? S("dim1") : S("dim0"))
+                                 : (needTrans ? S("dim0") : S("dim1"));
+
+  std::vector<std::vector<int>> basesReg;
+  for (int logReg = 0; logReg < llvm::Log2_32(8 * 16 / elemBitWidth);
+       logReg++) {
+    auto reg = 1 << logReg;
+    basesReg.push_back({0, reg});
+  }
+  std::vector<std::vector<int>> basesLane = {{1, 0}, {2, 0}, {4, 0}};
+  int numTileCols;
+  // Construct a tile consisting of 4 8x8x16-bit sub-tiles to use ldmatrix
   // efficiently. opIdx=0 and opIdx=1 are handled differently.
   if (opIdx == 0) {
-    // The matrix elements of thread 0 are distributed in the following pattern:
+    // The matrix elements of thread 0 are distributed in the following pattern
+    // (fp16):
     //
     //           col0       col8
     //   row0  reg[0-1]   reg[4-5]
     //   row8  reg[2-3]   reg[6-7]
-    for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile); logRow++) {
-      int row = 1 << logRow;
-      basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
-    }
-    basesLane.push_back({0, numColsPerTile / 2});
-    // Expand the `register` dimension so the size of columns matches `K`.
-    for (int logCol = 0; logCol < llvm::Log2_32(shape[kDim] / numColsPerTile);
-         logCol++) {
-      int col = 1 << logCol;
-      basesReg.push_back({0, numColsPerTile * col});
+    if (needTrans) {
+      assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
+                                   "supported in the transposed mode");
+      basesLane.push_back({0, 8});
+      basesLane.push_back({8, 0});
+    } else {
+      basesLane.push_back({8, 0});
+      basesLane.push_back({0, 8 * 16 / elemBitWidth});
     }
+    numTileCols = 16 * 16 / elemBitWidth;
   } else {
-    // The matrix elements of thread 0 are distributed in the following pattern:
+    // The matrix elements of thread 0 are distributed in the following pattern
+    // (fp16):
     //
     //           col0       col8      col16      col24
     //   row0  reg[0-1]   reg[2-3]  reg[4-5]   reg[6-7]
-    // 8x8
-    for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile / 2); logRow++) {
-      int row = 1 << logRow;
-      basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
-    }
-    // 8x16
-    basesLane.push_back({0, numColsPerTile / 2});
-    // 8x32
-    basesLane.push_back({0, numColsPerTile});
-    // Expand the `register` dimension so the size of columns matches `K`.
-    for (int logCol = 0;
-         logCol < llvm::Log2_32(shape[kDim] / (numColsPerTile * 2)); logCol++) {
-      int col = 1 << logCol;
-      basesReg.push_back({0, (numColsPerTile * 2) * col});
+    if (needTrans) {
+      assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
+                                   "supported in the transposed mode");
+      basesLane.push_back({8, 0});
+      basesLane.push_back({16, 0});
+    } else {
+      basesLane.push_back({0, 8 * 16 / elemBitWidth});
+      basesLane.push_back({0, 16 * 16 / elemBitWidth});
     }
+    numTileCols = 32 * 16 / elemBitWidth;
   }
-  auto layout = LinearLayout(
-      {{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}}, {kOuter, kInner});
+  // Expand the `register` dimension so the size of columns matches `K`.
+  auto layout =
+      LinearLayout({{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}},
+                   {kOuter, kInner}) *
+      LinearLayout::identity1D(shape[kDim] / numTileCols, kReg,
+                               S("dim" + std::to_string(kDim)));
   // Expand the `warp` dimension according to warpsPerCTA.
+  auto warpsPerCTA = mma.getWarpsPerCTA();
   layout *= broadcastedDotOperandLayout(ctx, warpsPerCTA, mma.getWarpOrder(),
                                         kDim, kWarp)
                 .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
-  auto ret = combineCtaCgaWithShape(layout, getCTALayout(dot), shape);
-  return ret.transposeOuts({kInner, kOuter})
-      .reshapeOuts(
-          {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
+  return combineCtaCgaWithShape(layout, getCTALayout(dot), shape);
 }
 
 } // anonymous namespace
@@ -1184,13 +1184,10 @@ LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
   return chooseStMatrixLayoutLeadingOffset(ctx, tensorTy, swizzleByteSize);
 }
 
-LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
-                                  Attribute dotEnc, ArrayRef<int64_t> shape) {
-  auto shared = cast<SharedEncodingAttr>(sharedEnc);
-  auto dot = cast<DotOperandEncodingAttr>(dotEnc);
-  assert(!shared.getHasLeadingOffset() &&
-         "Ldmatrix does not support leading offset yet");
-  return chooseLdMatrixLayoutNoLeadingOffset(ctx, shared, dot, shape);
+LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
+                                  bool needTrans, int32_t elemBitWidth) {
+  auto dot = cast<DotOperandEncodingAttr>(enc);
+  return chooseDotLdMatrixLayout(dot, shape, needTrans, elemBitWidth);
 }
 
 } // namespace mlir::triton::gpu