@@ -1157,7 +1157,6 @@ struct AsyncCopyGlobalToLocalOpConversion
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getResult().getType();
     auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());
-    auto srcLayout = srcTy.getEncoding();
 
     Value llDst = adaptor.getResult();
     Value llSrc = adaptor.getSrc();
@@ -1167,27 +1166,40 @@ struct AsyncCopyGlobalToLocalOpConversion
     // %src
     auto srcElems = unpackLLElements(loc, llSrc, rewriter);
 
-    // %dst
-    auto smemObj =
-        getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
     // %mask
     SmallVector<Value> maskElems;
     if (llMask) {
      maskElems = unpackLLElements(loc, llMask, rewriter);
      assert(srcElems.size() == maskElems.size());
     }
 
+    // We assume other = 0, see XXX(Keren) below
     // %other
-    SmallVector<Value> otherElems;
-    if (llOther) {
-      // FIXME(Keren): assume other is 0 for now.
-      //
-      // It's not necessary for now because the pipeline pass will skip
-      // generating insert_slice_async if the load op has any "other" tensor.
-      otherElems = unpackLLElements(loc, llOther, rewriter);
-      assert(srcElems.size() == otherElems.size());
+    // SmallVector<Value> otherElems;
+    // if (llOther) {
+    //   otherElems = unpackLLElements(loc, llOther, rewriter);
+    //   assert(srcElems.size() == otherElems.size());
+    // }
+
+    // zip(src, mask)
+    SmallVector<Value> vals;
+    auto ptrTy = srcElems[0].getType();
+    auto structTy =
+        LLVM::LLVMStructType::getLiteral(ctx, ArrayRef<Type>{ptrTy, i1_ty});
+    for (int i = 0; i < srcElems.size(); i++) {
+      Value packedArr = rewriter.create<LLVM::UndefOp>(loc, structTy);
+      packedArr = b.insert_val(packedArr, srcElems[i], 0);
+      auto maskElem = llMask ? maskElems[i] : b.false_val();
+      packedArr = b.insert_val(packedArr, maskElem, 1);
+      vals.push_back(packedArr);
     }
 
+    // Remove broadcasted registers
+    auto srcLayout = ttg::toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+    auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout);
+    srcLayout = removeBroadcastSrc.apply(srcLayout);
+    vals = removeBroadcastSrc.apply(vals);
+
     // We can load N elements at a time if:
     // 1. Every group of N source pointers are contiguous. For example, if
     //    N=2, then the pointers should be [x, x+1, y, y+1, ...].
@@ -1198,25 +1210,16 @@ struct AsyncCopyGlobalToLocalOpConversion
     if (mask) {
       maxVec = std::min(maxVec, getMaskAlignment(mask));
     }
+    // The maximum vector size is 128 bits on NVIDIA GPUs.
+    maxVec = std::min(maxVec, 128 / resElemTy.getIntOrFloatBitWidth());
 
-    // Addresses to store into, one per `vecTy`.
-    VectorType vecTy;
-    SmallVector<Value> shmemAddrs;
-    bool ok = emitTransferBetweenRegistersAndShared(
-        srcTy, dstTy, resElemTy, maxVec, smemObj, loc, rewriter, targetInfo,
-        [&](VectorType vecTy_, Value shmemAddr) {
-          vecTy = vecTy_;
-          shmemAddrs.push_back(shmemAddr);
-        });
-    assert(ok);
-
-    int vecBytes = vecTy.getNumElements() * vecTy.getElementTypeBitWidth() / 8;
-    assert(llvm::isPowerOf2_32(vecBytes));
+    int vecBytes = maxVec * resElemTy.getIntOrFloatBitWidth() / 8;
     if (vecBytes < 4) {
       return emitError(loc, "cp.async does not support transfers smaller than "
                             "4 bytes; calculated this as ")
             << vecBytes << " bytes";
     }
+    assert(vecBytes == 16 || vecBytes == 8 || vecBytes == 4);
 
     auto freeVarMasks = getFreeVariableMasks(srcTy);
     // NOTE(@peterbell10): We load redundant data on different CTAs, so the data
@@ -1225,52 +1228,63 @@ struct AsyncCopyGlobalToLocalOpConversion
     freeVarMasks[str_attr("block")] = 0;
     Value threadPred =
         emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
-    uint32_t regMask = freeVarMasks[str_attr("reg")];
 
-    for (int i = 0; i < shmemAddrs.size(); i++) {
-      // It's possible that vecTy is larger than 128 bits, in which case we have
-      // to use multiple cp.async instructions.
-      int wordBytes = std::min(vecBytes, 16);
-      int wordElems = wordBytes * 8 / vecTy.getElementTypeBitWidth();
-      int numWordsInVec = std::max(1, vecBytes / wordBytes);
-      for (int j = 0; j < numWordsInVec; j++) {
-        int elemIdx = i * vecTy.getNumElements() + j * wordElems;
-
-        if (!isCanonicalIndex(elemIdx, regMask)) {
-          continue; // Skip redundant registers
-        }
+    auto emitCpAsync = [&b, threadPred, ptrTy, hasMask = bool(llMask)](
+                           ConversionPatternRewriter &rewriter, Location loc,
+                           ArrayRef<Value> vals, Value shmemAddr, int startIdx,
+                           VectorType vecTy) -> SmallVector<Value> {
+      assert(isa<VectorType>(vecTy));
+      auto *ctx = rewriter.getContext();
+      auto elemTy = vecTy.getElementType();
+      auto nBytes = vecTy.getNumElements() * elemTy.getIntOrFloatBitWidth() / 8;
+      assert(nBytes == 16 || nBytes == 8 || nBytes == 4);
+      // Tune CG and CA.
+      CacheModifier srcCacheModifier =
+          nBytes == 16 ? CacheModifier::CG : CacheModifier::CA;
+
+      auto structElem = vals[startIdx];
+      auto srcElem = b.extract_val(ptrTy, structElem, 0);
+      auto maskElem = b.extract_val(i1_ty, structElem, 1);
 
-        // Tune CG and CA.
-        CacheModifier srcCacheModifier =
-            wordBytes == 16 ? CacheModifier::CG : CacheModifier::CA;
-        assert(wordBytes == 16 || wordBytes == 8 || wordBytes == 4);
-
-        PTXBuilder ptxBuilder;
-        auto &copyAsyncOp =
-            *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
-        auto *dstOperand = ptxBuilder.newAddrOperand(shmemAddrs[i], "r",
-                                                     /*offset=*/j * wordBytes);
-        auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[elemIdx], "l");
-        auto *copySize = ptxBuilder.newConstantOperand(wordBytes);
-        auto *srcSize = copySize;
-        if (op.getMask()) {
-          // We don't use predicate in this case, setting src-size to 0
-          // if there's any mask. cp.async will automatically fill the
-          // remaining slots with 0 if cp-size > src-size.
-          // XXX(Keren): Always assume other = 0 for now.
-          // When 'other != 0' is supported, we will need to fold the
-          // op.getMask() and redundantDataMask() into the same predicate, the
-          // way it is done for LoadOp.
-          auto selectOp =
-              b.select(maskElems[elemIdx], b.i32_val(wordBytes), b.i32_val(0));
-          srcSize = ptxBuilder.newOperand(selectOp, "r");
-        }
-
-        copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
-            .maybePredicate(threadPred);
-        ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
+      PTXBuilder ptxBuilder;
+      auto &copyAsyncOp =
+          *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
+      auto *dstOperand = ptxBuilder.newAddrOperand(shmemAddr, "r");
+      auto *srcOperand = ptxBuilder.newAddrOperand(srcElem, "l");
+      auto *copySize = ptxBuilder.newConstantOperand(nBytes);
+      auto *srcSize = copySize;
+      if (hasMask) {
+        // We don't use predicate in this case, setting src-size to 0
+        // if there's any mask. cp.async will automatically fill the
+        // remaining slots with 0 if cp-size > src-size.
+        // XXX(Keren): Always assume other = 0 for now.
+        // When 'other != 0' is supported, we will need to fold the
+        // op.getMask() and redundantDataMask() into the same predicate, the
+        // way it is done for LoadOp.
+        auto selectOp = b.select(maskElem, b.i32_val(nBytes), b.i32_val(0));
+        srcSize = ptxBuilder.newOperand(selectOp, "r");
       }
+      copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
+          .maybePredicate(threadPred);
+      ptxBuilder.launch(rewriter, loc, void_ty(ctx));
+      return {};
+    };
+
+    // %dst
+    auto smemObj =
+        getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
+    auto smemLayout =
+        ttg::toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+    auto cvt = srcLayout.invertAndCompose(smemLayout);
+    if (!cvt.isTrivialOver({str_attr("block")})) {
+      return emitError(loc,
+                       "cp.async does not support non-trivial block dimension");
     }
+    cvt = cvt.sublayout(
+        {str_attr("register"), str_attr("lane"), str_attr("warp")},
+        {str_attr("offset")});
+    lowerLdSt(loc, ctx, cvt, vals, resElemTy, smemObj.getBase(), rewriter,
+              targetInfo, maxVec, emitCpAsync);
 
     // Drop the result token.
     Value zero = rewriter.create<LLVM::ConstantOp>(
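
The new vector-width math above is easy to check by hand. Below is a small, hypothetical plain C++ sketch (not part of the commit; the example values are assumptions): with 16-bit elements and contiguity allowing 8 elements per access, the 128-bit cap keeps maxVec at 8, vecBytes comes out to 16, and the 16-byte case picks the CG cache modifier; anything below 4 bytes would be rejected, as in the emitError branch.

#include <algorithm>
#include <cassert>
#include <cstdio>

int main() {
  int elemBits = 16; // e.g. f16 elements (assumed example value)
  int maxVec = 8;    // pointer contiguity / mask alignment allow 8 elements
  // Cap at 128 bits, the largest single cp.async transfer.
  maxVec = std::min(maxVec, 128 / elemBits);
  int vecBytes = maxVec * elemBits / 8;
  assert(vecBytes == 16 || vecBytes == 8 || vecBytes == 4);
  const char *cacheModifier = (vecBytes == 16) ? "cg" : "ca";
  std::printf("maxVec=%d elements, vecBytes=%d -> cp.async.%s\n", maxVec,
              vecBytes, cacheModifier);
  return 0;
}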
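The heart of the lowering is the cp-size / src-size trick in the emitCpAsync lambda: a masked-off element is not predicated away, its src-size is simply set to 0, and cp.async zero-fills the destination, which is why the code can assume other == 0. A minimal standalone CUDA sketch of that behaviour follows; it is not taken from the commit, the kernel name, shapes, and per-thread predicate are made up, and it requires sm_80 or newer.

#include <cuda_runtime.h>
#include <cstdint>

__global__ void maskedCpAsync(const float4 *gmem, float4 *out, int numValid) {
  __shared__ float4 smem[128];
  int tid = threadIdx.x;
  bool pred = tid < numValid; // per-thread "mask"
  // cp.async takes a 32-bit shared-memory address as its destination.
  uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(&smem[tid]));
  // cp-size is 16 bytes; src-size is 16 for unmasked elements and 0 for
  // masked-off ones, so the hardware zero-fills the 16-byte destination.
  uint32_t srcSize = pred ? 16u : 0u;
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n" ::"r"(dst),
               "l"((const void *)(gmem + tid)), "r"(srcSize));
  asm volatile("cp.async.commit_group;\n");
  asm volatile("cp.async.wait_group 0;\n");
  __syncthreads();
  out[tid] = smem[tid]; // masked-off lanes read back zeros
}

Launched as, say, maskedCpAsync<<<1, 128>>>(src, dst, 100), threads 100 through 127 still issue the copy but with src-size 0, so their shared-memory slots end up holding zeros, which mirrors the "assume other = 0" behaviour the comments in the diff describe.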