@@ -1157,6 +1157,7 @@ struct AsyncCopyGlobalToLocalOpConversion
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getResult().getType();
     auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());
+    auto srcLayout = srcTy.getEncoding();
 
     Value llDst = adaptor.getResult();
     Value llSrc = adaptor.getSrc();
@@ -1166,40 +1167,27 @@ struct AsyncCopyGlobalToLocalOpConversion
     // %src
     auto srcElems = unpackLLElements(loc, llSrc, rewriter);
 
+    // %dst
+    auto smemObj =
+        getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
     // %mask
     SmallVector<Value> maskElems;
     if (llMask) {
       maskElems = unpackLLElements(loc, llMask, rewriter);
       assert(srcElems.size() == maskElems.size());
     }
 
-    // We assume other = 0, see XXX(Keren) below
     // %other
-    // SmallVector<Value> otherElems;
-    // if (llOther) {
-    //   otherElems = unpackLLElements(loc, llOther, rewriter);
-    //   assert(srcElems.size() == otherElems.size());
-    // }
-
-    // zip(src, mask)
-    SmallVector<Value> vals;
-    auto ptrTy = srcElems[0].getType();
-    auto structTy =
-        LLVM::LLVMStructType::getLiteral(ctx, ArrayRef<Type>{ptrTy, i1_ty});
-    for (int i = 0; i < srcElems.size(); i++) {
-      Value packedArr = rewriter.create<LLVM::UndefOp>(loc, structTy);
-      packedArr = b.insert_val(packedArr, srcElems[i], 0);
-      auto maskElem = llMask ? maskElems[i] : b.false_val();
-      packedArr = b.insert_val(packedArr, maskElem, 1);
-      vals.push_back(packedArr);
+    SmallVector<Value> otherElems;
+    if (llOther) {
+      // FIXME(Keren): assume other is 0 for now.
+      //
+      // It's not necessary for now because the pipeline pass will skip
+      // generating insert_slice_async if the load op has any "other" tensor.
+      otherElems = unpackLLElements(loc, llOther, rewriter);
+      assert(srcElems.size() == otherElems.size());
     }
 
-    // Remove broadcasted registers
-    auto srcLayout = ttg::toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
-    auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout);
-    srcLayout = removeBroadcastSrc.apply(srcLayout);
-    vals = removeBroadcastSrc.apply(vals);
-
     // We can load N elements at a time if:
     // 1. Every group of N source pointers are contiguous. For example, if
     //    N=2, then the pointers should be [x, x+1, y, y+1, ...].
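
To make the contiguity rule above concrete, here is a minimal standalone sketch of how such a vector width could be computed over flat integer addresses. `computeMaxVec` and the example pointer values are invented for illustration; this is not the helper Triton actually uses, and in the real pass the result is further clamped by `getMaskAlignment`, as the next hunk shows.

```cpp
#include <cstdint>
#include <vector>

// Invented helper for illustration: find the largest power-of-2 group size N
// such that every aligned group of N pointers is contiguous, e.g. for N=2 the
// pointers must look like [x, x+1, y, y+1, ...].
static int computeMaxVec(const std::vector<int64_t> &ptrs) {
  int maxVec = 1;
  for (int n = 2; n <= (int)ptrs.size(); n *= 2) {
    bool contiguous = true;
    for (size_t g = 0; g + n <= ptrs.size(); g += n)
      for (int k = 1; k < n; k++)
        contiguous &= (ptrs[g + k] == ptrs[g] + k);
    if (!contiguous)
      break;
    maxVec = n;
  }
  return maxVec;
}

// computeMaxVec({100, 101, 200, 201}) == 2: each pair is contiguous but the
// run of four is not, so at most 2 elements can go into one cp.async.
```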
@@ -1210,16 +1198,25 @@ struct AsyncCopyGlobalToLocalOpConversion
     if (mask) {
       maxVec = std::min(maxVec, getMaskAlignment(mask));
     }
-    // The maximum vector size is 128 bits on NVIDIA GPUs.
-    maxVec = std::min(maxVec, 128 / resElemTy.getIntOrFloatBitWidth());
 
-    int vecBytes = maxVec * resElemTy.getIntOrFloatBitWidth() / 8;
+    // Addresses to store into, one per `vecTy`.
+    VectorType vecTy;
+    SmallVector<Value> shmemAddrs;
+    bool ok = emitTransferBetweenRegistersAndShared(
+        srcTy, dstTy, resElemTy, maxVec, smemObj, loc, rewriter, targetInfo,
+        [&](VectorType vecTy_, Value shmemAddr) {
+          vecTy = vecTy_;
+          shmemAddrs.push_back(shmemAddr);
+        });
+    assert(ok);
+
+    int vecBytes = vecTy.getNumElements() * vecTy.getElementTypeBitWidth() / 8;
+    assert(llvm::isPowerOf2_32(vecBytes));
     if (vecBytes < 4) {
       return emitError(loc, "cp.async does not support transfers smaller than "
                             "4 bytes; calculated this as ")
              << vecBytes << " bytes";
     }
-    assert(vecBytes == 16 || vecBytes == 8 || vecBytes == 4);
 
     auto freeVarMasks = getFreeVariableMasks(srcTy);
     // NOTE(@peterbell10): We load redundant data on different CTAs, so the data
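
As a worked example of the sizing logic in this hunk (the element counts are hypothetical, not taken from the source): with f16 elements, `emitTransferBetweenRegistersAndShared` might hand back a 16-element `vecTy`, giving `vecBytes = 32`. That exceeds the 16-byte `cp.async` maximum, so the loop in the next hunk splits each vector into two 16-byte words:

```cpp
#include <algorithm>
#include <cassert>

int main() {
  // Hypothetical sizes for illustration: f16 elements, 16 per vector.
  int elemBits = 16;                      // vecTy.getElementTypeBitWidth()
  int numElems = 16;                      // vecTy.getNumElements()
  int vecBytes = numElems * elemBits / 8; // 32, a power of 2
  assert(vecBytes >= 4); // 1 x f16 would give 2 bytes and hit the error above
  // The next hunk splits oversized vectors into <= 16-byte cp.async words:
  int wordBytes = std::min(vecBytes, 16);                // 16 -> CacheModifier::CG
  int wordElems = wordBytes * 8 / elemBits;              // 8 elements per word
  int numWordsInVec = std::max(1, vecBytes / wordBytes); // 2 cp.async per vector
  return (wordElems == 8 && numWordsInVec == 2) ? 0 : 1;
}
```

A 2-element f16 vector instead gives `vecBytes == 4`, the smallest transfer `cp.async` supports, which is exactly what the `vecBytes < 4` check above guards.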
@@ -1228,63 +1225,52 @@ struct AsyncCopyGlobalToLocalOpConversion
     freeVarMasks[str_attr("block")] = 0;
     Value threadPred =
         emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
+    uint32_t regMask = freeVarMasks[str_attr("reg")];
 
-    auto emitCpAsync = [&b, threadPred, ptrTy, hasMask = bool(llMask)](
-                           ConversionPatternRewriter &rewriter, Location loc,
-                           ArrayRef<Value> vals, Value shmemAddr, int startIdx,
-                           VectorType vecTy) -> SmallVector<Value> {
-      assert(isa<VectorType>(vecTy));
-      auto *ctx = rewriter.getContext();
-      auto elemTy = vecTy.getElementType();
-      auto nBytes = vecTy.getNumElements() * elemTy.getIntOrFloatBitWidth() / 8;
-      assert(nBytes == 16 || nBytes == 8 || nBytes == 4);
-      // Tune CG and CA.
-      CacheModifier srcCacheModifier =
-          nBytes == 16 ? CacheModifier::CG : CacheModifier::CA;
-
-      auto structElem = vals[startIdx];
-      auto srcElem = b.extract_val(ptrTy, structElem, 0);
-      auto maskElem = b.extract_val(i1_ty, structElem, 1);
+    for (int i = 0; i < shmemAddrs.size(); i++) {
+      // It's possible that vecTy is larger than 128 bits, in which case we have
+      // to use multiple cp.async instructions.
+      int wordBytes = std::min(vecBytes, 16);
+      int wordElems = wordBytes * 8 / vecTy.getElementTypeBitWidth();
+      int numWordsInVec = std::max(1, vecBytes / wordBytes);
+      for (int j = 0; j < numWordsInVec; j++) {
+        int elemIdx = i * vecTy.getNumElements() + j * wordElems;
+
+        if (!isCanonicalIndex(elemIdx, regMask)) {
+          continue; // Skip redundant registers
+        }
 
-      PTXBuilder ptxBuilder;
-      auto &copyAsyncOp =
-          *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
-      auto *dstOperand = ptxBuilder.newAddrOperand(shmemAddr, "r");
-      auto *srcOperand = ptxBuilder.newAddrOperand(srcElem, "l");
-      auto *copySize = ptxBuilder.newConstantOperand(nBytes);
-      auto *srcSize = copySize;
-      if (hasMask) {
-        // We don't use predicate in this case, setting src-size to 0
-        // if there's any mask. cp.async will automatically fill the
-        // remaining slots with 0 if cp-size > src-size.
-        // XXX(Keren): Always assume other = 0 for now.
-        // When 'other != 0' is supported, we will need to fold the
-        // op.getMask() and redundantDataMask() into the same predicate, the
-        // way it is done for LoadOp.
-        auto selectOp = b.select(maskElem, b.i32_val(nBytes), b.i32_val(0));
-        srcSize = ptxBuilder.newOperand(selectOp, "r");
-      }
-      copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
-          .maybePredicate(threadPred);
-      ptxBuilder.launch(rewriter, loc, void_ty(ctx));
-      return {};
-    };
+        // Tune CG and CA.
+        CacheModifier srcCacheModifier =
+            wordBytes == 16 ? CacheModifier::CG : CacheModifier::CA;
+        assert(wordBytes == 16 || wordBytes == 8 || wordBytes == 4);
+
+        PTXBuilder ptxBuilder;
+        auto &copyAsyncOp =
+            *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
+        auto *dstOperand = ptxBuilder.newAddrOperand(shmemAddrs[i], "r",
+                                                     /*offset=*/j * wordBytes);
+        auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[elemIdx], "l");
+        auto *copySize = ptxBuilder.newConstantOperand(wordBytes);
+        auto *srcSize = copySize;
+        if (op.getMask()) {
+          // We don't use predicate in this case, setting src-size to 0
+          // if there's any mask. cp.async will automatically fill the
+          // remaining slots with 0 if cp-size > src-size.
+          // XXX(Keren): Always assume other = 0 for now.
+          // When 'other != 0' is supported, we will need to fold the
+          // op.getMask() and redundantDataMask() into the same predicate, the
+          // way it is done for LoadOp.
+          auto selectOp =
+              b.select(maskElems[elemIdx], b.i32_val(wordBytes), b.i32_val(0));
+          srcSize = ptxBuilder.newOperand(selectOp, "r");
+        }
 
-    // %dst
-    auto smemObj =
-        getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
-    auto smemLayout =
-        ttg::toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
-    auto cvt = srcLayout.invertAndCompose(smemLayout);
-    if (!cvt.isTrivialOver({str_attr("block")})) {
-      return emitError(loc,
-                       "cp.async does not support non-trivial block dimension");
+        copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
+            .maybePredicate(threadPred);
+        ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
+      }
     }
-    cvt = cvt.sublayout(
-        {str_attr("register"), str_attr("lane"), str_attr("warp")},
-        {str_attr("offset")});
-    lowerLdSt(loc, ctx, cvt, vals, resElemTy, smemObj.getBase(), rewriter,
-              targetInfo, maxVec, emitCpAsync);
 
     // Drop the result token.
     Value zero = rewriter.create<LLVM::ConstantOp>(
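
For each word, the masked path above emits a PTX `cp.async.cg.shared.global [dst], [src], 16, src_size`, where `src_size` comes from the `b.select`. A behavioral sketch of why this makes the `other == 0` assumption safe follows; `cpAsyncModel` is an invented name modeling only the data movement, not how the instruction is actually invoked:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstring>

// Model of one masked cp.async word: the hardware copies min(cp-size,
// src-size) bytes and zero-fills the rest of the cp-size window, so setting
// src-size = mask ? wordBytes : 0 turns a masked-off load into zeros without
// predicating the instruction itself.
void cpAsyncModel(uint8_t *dst, const uint8_t *src, int cpSize, int srcSize) {
  srcSize = std::min(srcSize, cpSize);
  std::memcpy(dst, src, srcSize);
  std::memset(dst + srcSize, 0, cpSize - srcSize); // implicit "other" == 0
}
```

This zero-fill is also why supporting a nonzero `other` later will require folding `op.getMask()` and the redundancy predicate into one predicate, as the XXX comment in the hunk notes.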