@@ -992,53 +992,195 @@ static void copyScales(ConversionPatternRewriter &rewriter, Location loc,
   createCopy(repMorN, repK);
 }
 
+static std::optional<std::tuple<int32_t, LinearLayout, LinearLayout,
+                                SmallVector<int64_t>, int32_t>>
+getSwizzling(MemDescType shmemTy, MemDescType tmemTy) {
+  // cvt is a map from Tmem to Shmem
+  auto tmemLl = toLinearLayout(tmemTy);
+  auto shmemLl = toLinearLayout(shmemTy);
+  auto inDimNames = to_vector(tmemLl.getInDimNames());
+  auto *ctx = inDimNames[0].getContext();
+  assert(shmemLl.getInDimSize(str_attr("block")) == 1 && "NYI");
+  auto kOffset = str_attr("offset");
+  auto kRow = str_attr("row");
+  auto kCol = str_attr("col");
+  shmemLl = shmemLl.sublayout({kOffset}, to_vector(shmemLl.getOutDimNames()));
+  auto cvt = tmemLl.invertAndCompose(shmemLl);
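+  // After the composition, cvt's input dims are tmem's (row, col) and its
+  // output dim is the shmem offset: cvt.apply({{kRow, r}, {kCol, c}}) gives
+  // the linear shmem offset (in elements) that holds tmem element (r, c).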
+
+  int32_t bitwidth = tmemTy.getElementType().getIntOrFloatBitWidth();
+
+  // Check that the layout is large enough for the SBO stride to be read off
+  // TODO Move to the verifier
+  if (shmemLl.getOutDimSizeLog2(str_attr("dim0")) < 4) {
+    return std::nullopt;
+  }
+  // TODO We may need to be careful here if we ever want to support fp4 padded
+  // layouts
+  if (!shmemLl.isInvertible()) {
+    return std::nullopt;
+  }
+
+  // This will be the SBO for k-contiguous layouts (like the ones used in
+  // tcgen05.cp)
+  auto sbo =
+      shmemLl.invert().getBasis(str_attr("dim0"), /*log2(8)=*/3, kOffset);
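+  // i.e. the offset image of dim0 = 8: the distance in elements between
+  // consecutive 8-row core matrices of a k-contiguous layout.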
+
+  // TODO hardcoded to 128x256b for now
+  const SmallVector<int64_t> instrShape = {128, 256 / bitwidth};
+  // TODO Move to the verifier perhaps
+  // Can we move the tile?
+  // TODO We should be able to move any descriptor tile with 128x256b
+  // (or 128x128b for unswizzled when it just has one tile)
+  for (auto [inDimName, instrSize] : llvm::zip(inDimNames, instrShape)) {
+    if (cvt.getInDimSize(inDimName) < instrSize) {
+      return std::nullopt;
+    }
+  }
+
+  auto CTALayout = getCTALayout(shmemTy.getEncoding());
+
+  for (int swizzling : {0, 32, 64, 128}) {
+    // r = 0, 1, 2, 3
+    auto shmemEnc =
+        NVMMASharedEncodingAttr::get(ctx, swizzling, /*transposed=*/false,
+                                     bitwidth, /*fp4Padded=*/false, CTALayout);
+    auto shmemTile =
+        getCoreMatrixLinearLayout(shmemEnc, /*disableSwizzle=*/false);
+    // getCoreMatrixLinearLayout gives the k-contiguous tile.
+    // shmemTile is a layout onto a matrix with shape:
+    //   if swizzling != 0: 8 x (8 * swizzling / bitwidth)
+    //   if swizzling == 0: 8 x (8 * 16 / bitwidth)
+    assert(shmemTile.getOutDimSize(str_attr("dim0")) == 8);
+    assert(shmemTile.getOutDimSize(str_attr("dim1")) ==
+           8 * std::max(16, swizzling) / bitwidth);
+    // The shmemTile is mapped identically into the tmem, so we just need to
+    // rename the outDims in shmemTile from dim0, dim1 to row, col
+    auto cvtTileInverted =
+        LinearLayout(shmemTile.getBases(), {str_attr("row"), str_attr("col")});
+    // The tile should be invertible, so we consider it as a map from row, col
+    // to offset.
+    // nb. Working with the map from row, col to offset is important to handle
+    // the tcgen05.cp instructions that do broadcasting
+    auto cvtTile = cvtTileInverted.invert();
+    // The SBO stride must not touch the core tile
+    if (sbo < cvtTile.getOutDimSize(kOffset))
+      continue;
+
+    // As we are copying instrShape[0] rows in one go, to be able to
+    // represent this in the descriptor, we need a constant "stride"
+    // along the row dimension from row=8 up to the last row.
+    auto bases = cvtTile.getBases();
+    for (int i = 1; i < instrShape[0] / 8; i *= 2) {
+      bases[kRow].push_back({sbo * i});
+    }
+    cvtTile = LinearLayout(bases, {{kOffset, sbo * (instrShape[0] / 8)}},
+                           /*requireSurjective=*/false);
+
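+    // If cvt factors as cvtTile (inner) times some quotient layout, that
+    // quotient describes how the 128-row tile repeats across the full
+    // conversion; divideLeft returns it when the factorization exists.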
+    auto quot = divideLeft(cvt, cvtTile);
+    if (quot.has_value()) {
+      if (auto nvmma = dyn_cast<NVMMASharedEncodingAttr>(shmemEnc)) {
+        assert(nvmma.getSwizzlingByteWidth() == swizzling);
+      }
+      return std::make_tuple(swizzling, *quot, cvtTile, instrShape, sbo);
+    }
+  }
+  return std::nullopt;
+}
+
 static void copySharedToTmem(ConversionPatternRewriter &rewriter, Location loc,
                              const TypeConverter *typeConverter,
                              triton::nvidia_gpu::TMEMCopyOp op, Value src,
-                             Value dst, Value pred) {
+                             Value baseDst, Value pred) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
+  auto *ctx = op.getContext();
+  auto kOffset = str_attr("offset");
+  auto kRow = str_attr("row");
+  auto kCol = str_attr("col");
+
   MemDescType srcTy = op.getSrc().getType();
   MemDescType dstTy = op.getDst().getType();
+
+  auto sharedLl = toLinearLayout(srcTy);
+  sharedLl =
+      sharedLl.sublayout({kOffset}, to_vector(sharedLl.getOutDimNames()));
+  auto tmemLl = toLinearLayout(dstTy);
+  auto cvt = tmemLl.invertAndCompose(sharedLl);
+
+  auto bitwidth = srcTy.getElementType().getIntOrFloatBitWidth();
+  // Need to find the shmem tile that matches
+  auto maybeSwizzling = getSwizzling(srcTy, dstTy);
+  assert(maybeSwizzling.has_value());
+  auto [swizzling, quot, tile, tileShape, sbo] = std::move(*maybeSwizzling);
+
+  auto reps = zerosLike(tile) * quot;
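+  // zerosLike(tile) keeps tile's shape but zeroes its bases, so reps maps
+  // (row, col) to the shmem offset at which each repetition of the tile
+  // starts.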
+
+  // Get shmem ptr
+  // TODO We should not allow splitting along the swizzling pattern
   Type elemTy = typeConverter->convertType(srcTy.getElementType());
   auto smemObj =
       LLVM::getSharedMemoryObjectFromStruct(loc, src, elemTy, rewriter);
-  Value baseSrc = smemObj.getShmemAffineBase(loc, rewriter, srcTy);
+  Value baseSrcInt =
+      b.ptrtoint(i32_ty, smemObj.getShmemAffineBase(loc, rewriter, srcTy));
+  // We checked in the verifier that the alignment is at least 16
+  Value baseSrcIntShr4 = b.lshr(baseSrcInt, b.i32_val(4));
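+  // Descriptor address fields are encoded in 16-byte units, hence the shift.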
+
+  // Set common fields in the SMEMDescriptor
+  SMEMDescriptor desc;
+  desc.baseAddress = 0;
+  // For K-contig, leadDimension is assumed to be 1
+  desc.leadDimensionBaseOffset = 1;
+  // SBO is in elements; convert it to bytes and then to 16-byte units (>> 4)
+  desc.strideDimensionBaseOffset = ((sbo * (bitwidth / 8)) >> 4);
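+  // e.g. with a hypothetical sbo of 1024 elements and bitwidth 16, this is
+  // (1024 * 2) >> 4 = 128.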
+  desc.matrixBaseOffset = 0;
+  switch (swizzling) {
+  case 0:
+    desc.swizzlingMode = 0;
+    break;
+  case 32:
+    desc.swizzlingMode = 3;
+    break;
+  case 64:
+    desc.swizzlingMode = 2;
+    break;
+  case 128:
+    desc.swizzlingMode = 1;
+    break;
+  default:
+    llvm::report_fatal_error("Unsupported swizzling size.");
+  }
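+  // i.e. the swizzlingMode field encodes 0 = none, 1 = 128B, 2 = 64B,
+  // 3 = 32B.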
 
-  Value baseDst = dst;
-  assert(srcTy.getElementType().getIntOrFloatBitWidth() == 32);
-
-  int blockN =
-      cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(dstTy.getEncoding())
-          .getBlockN();
-  // Currently, hardcoded to 128x256b message.
-  std::array<int, 2> instShape = {128, 8};
-  int repNPerBlock = blockN / instShape[1];
-  auto createCopy = [&](int repM, int repN) {
-    Value zero = b.i32_val(0);
-    SmallVector<int64_t> shape(op.getSrc().getType().getShape());
-    DotOpMmaV5SmemLoader smemLoader = DotOpMmaV5SmemLoader(
-        op.getSrc(), baseSrc, shape, op.getSrc().getType().getAllocShape(),
-        zero, 1, /*trans=*/false, {128, 8},
-        op.getSrc().getType().getElementType().getIntOrFloatBitWidth(),
-        rewriter, loc);
-    for (int m = 0; m < repM; m++) {
-      for (int n = 0; n < repN; n++) {
-        int colIndx =
-            (n % repNPerBlock) * instShape[1] +
-            m * repNPerBlock * instShape[1] +
-            (n / repNPerBlock) * (srcTy.getDimSize(0) / instShape[0]) * blockN;
-        auto colOffset = b.i32_val(colIndx);
-        auto tmemAddr = b.add(b.ptrtoint(i32_ty, baseDst), colOffset);
-        Value smemDesc = smemLoader.smemLoad(m, n, rewriter, loc);
-        createTcgen05Cp(rewriter, loc, tmemAddr, smemDesc, pred,
-                        /*scales=*/false);
-      }
+  // Make sure we don't have to iterate along the rows
+  assert(tile.getInDimSize(kRow) == cvt.getInDimSize(kRow) && "NYI");
+  assert(tileShape[1] <= tile.getInDimSize(kCol) && "NYI");
+  int elementBytes = bitwidth / 8;
+  for (int col = 0; col < reps.getInDimSize(kCol);
+       col += tile.getInDimSize(kCol)) {
+    // Compute base offset for the swizzling pattern
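+    // The matrix base offset selects the phase of the 8-phase swizzling
+    // pattern this tile starts at: byte offset / 128, modulo 8.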
+    int32_t off = reps.apply({{kRow, 0}, {kCol, col}})[0].second;
+    desc.matrixBaseOffset = (off * elementBytes / 128) & 0x7;
+    uint64_t descBase = desc.descriptor;
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-shared-memory-descriptor
+    descBase |= (1ULL << 46);
+    Value descValBase = b.int_val(64, descBase);
+    for (int offset = 0; offset < tile.getInDimSize(kCol);
+         offset += tileShape[1]) {
+      // Compute total offset of the current message
+      int32_t totalOffElems =
+          cvt.apply({{kRow, 0}, {kCol, col + offset}})[0].second;
+      int32_t smemByteOffset = totalOffElems * elementBytes;
+      int32_t smemByteOffsetShr4 = smemByteOffset >> 4;
+      // We could fold this add into the descBase if we wanted to
+      Value baseAddr = b.add(baseSrcIntShr4, b.i32_val(smemByteOffsetShr4));
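+      // The descriptor's base address field is 14 bits wide, hence the
+      // 0x3FFF mask on the 16-byte granular address below.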
+      Value baseSrcDesc = b.zext(i64_ty, b.and_(baseAddr, b.i32_val(0x3FFF)));
+      // Add the base address to the descriptor
+      Value descVal = b.or_(descValBase, baseSrcDesc, /*disjoint=*/true);
+      auto tmemAddr =
+          b.or_(b.ptrtoint(i32_ty, baseDst), b.i32_val(col + offset),
+                /*disjoint=*/true);
+      createTcgen05Cp(rewriter, loc, tmemAddr, descVal, pred,
+                      /*scales=*/false);
     }
-  };
-
-  int repM = srcTy.getDimSize(0) / instShape[0];
-  int repN = srcTy.getDimSize(1) / instShape[1];
-  createCopy(repM, repN);
+  }
 }
 
 struct TensorMemoryCopyOpConversion
@@ -1048,7 +1190,7 @@ struct TensorMemoryCopyOpConversion
   LogicalResult
   matchAndRewrite(triton::nvidia_gpu::TMEMCopyOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
+    assert(lookupNumCTAs(rewriter) == 1 && "NYI");
     Location loc = op->getLoc();
     Value pred = LLVM::NVIDIA::createElectPredicateWarp0(loc, rewriter);
     if (isa<TensorMemoryScalesEncodingAttr>(