@@ -1157,7 +1157,6 @@ struct AsyncCopyGlobalToLocalOpConversion
    auto srcTy = op.getSrc().getType();
    auto dstTy = op.getResult().getType();
    auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());
-    auto srcLayout = srcTy.getEncoding();

    Value llDst = adaptor.getResult();
    Value llSrc = adaptor.getSrc();
@@ -1167,27 +1166,40 @@ struct AsyncCopyGlobalToLocalOpConversion
    // %src
    auto srcElems = unpackLLElements(loc, llSrc, rewriter);

-    // %dst
-    auto smemObj =
-        getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
    // %mask
    SmallVector<Value> maskElems;
    if (llMask) {
      maskElems = unpackLLElements(loc, llMask, rewriter);
      assert(srcElems.size() == maskElems.size());
    }

+    // We assume other = 0, see XXX(Keren) below
    // %other
-    SmallVector<Value> otherElems;
-    if (llOther) {
-      // FIXME(Keren): assume other is 0 for now.
-      //
-      // It's not necessary for now because the pipeline pass will skip
-      // generating insert_slice_async if the load op has any "other" tensor.
-      otherElems = unpackLLElements(loc, llOther, rewriter);
-      assert(srcElems.size() == otherElems.size());
+    // SmallVector<Value> otherElems;
+    // if (llOther) {
+    //   otherElems = unpackLLElements(loc, llOther, rewriter);
+    //   assert(srcElems.size() == otherElems.size());
+    // }
+
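+    // Pack each source pointer with its mask bit into a {ptr, i1} struct so
+    // that the register-level rewrites below (e.g. removing broadcasted
+    // registers) keep every pointer and its mask together.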
+    // zip(src, mask)
+    SmallVector<Value> vals;
+    auto ptrTy = srcElems[0].getType();
+    auto structTy =
+        LLVM::LLVMStructType::getLiteral(ctx, ArrayRef<Type>{ptrTy, i1_ty});
+    for (int i = 0; i < srcElems.size(); i++) {
+      Value packedArr = rewriter.create<LLVM::UndefOp>(loc, structTy);
+      packedArr = b.insert_val(packedArr, srcElems[i], 0);
+      auto maskElem = llMask ? maskElems[i] : b.false_val();
+      packedArr = b.insert_val(packedArr, maskElem, 1);
+      vals.push_back(packedArr);
    }

+    // Remove broadcasted registers
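+    // (i.e. registers that map to the same tensor element within a thread),
+    // so no duplicate cp.async is issued for the same data.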
+    auto srcLayout = ttg::toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+    auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout);
+    srcLayout = removeBroadcastSrc.apply(srcLayout);
+    vals = removeBroadcastSrc.apply(vals);
+
    // We can load N elements at a time if:
    // 1. Every group of N source pointers are contiguous. For example, if
    //    N=2, then the pointers should be [x, x+1, y, y+1, ...].
@@ -1198,25 +1210,16 @@ struct AsyncCopyGlobalToLocalOpConversion
    if (mask) {
      maxVec = std::min(maxVec, getMaskAlignment(mask));
    }
+    // The maximum vector size is 128 bits on NVIDIA GPUs.
+    maxVec = std::min(maxVec, 128 / resElemTy.getIntOrFloatBitWidth());
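+    // E.g. for 16-bit elements this caps maxVec at 128 / 16 = 8 elements,
+    // i.e. one 16-byte cp.async per vector.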

-    // Addresses to store into, one per `vecTy`.
-    VectorType vecTy;
-    SmallVector<Value> shmemAddrs;
-    bool ok = emitTransferBetweenRegistersAndShared(
-        srcTy, dstTy, resElemTy, maxVec, smemObj, loc, rewriter, targetInfo,
-        [&](VectorType vecTy_, Value shmemAddr) {
-          vecTy = vecTy_;
-          shmemAddrs.push_back(shmemAddr);
-        });
-    assert(ok);
-
-    int vecBytes = vecTy.getNumElements() * vecTy.getElementTypeBitWidth() / 8;
-    assert(llvm::isPowerOf2_32(vecBytes));
+    int vecBytes = maxVec * resElemTy.getIntOrFloatBitWidth() / 8;
    if (vecBytes < 4) {
      return emitError(loc, "cp.async does not support transfers smaller than "
                            "4 bytes; calculated this as ")
             << vecBytes << " bytes";
    }
+    assert(vecBytes == 16 || vecBytes == 8 || vecBytes == 4);
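+    // maxVec is a power of two capped at 128 bits, and transfers below 4
+    // bytes were rejected above, so only the cp.async sizes 4, 8, 16 remain.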

    auto freeVarMasks = getFreeVariableMasks(srcTy);
    // NOTE(@peterbell10): We load redundant data on different CTAs, so the data
@@ -1225,52 +1228,63 @@ struct AsyncCopyGlobalToLocalOpConversion
    freeVarMasks[str_attr("block")] = 0;
    Value threadPred =
        emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
-    uint32_t regMask = freeVarMasks[str_attr("reg")];

-    for (int i = 0; i < shmemAddrs.size(); i++) {
-      // It's possible that vecTy is larger than 128 bits, in which case we have
-      // to use multiple cp.async instructions.
-      int wordBytes = std::min(vecBytes, 16);
-      int wordElems = wordBytes * 8 / vecTy.getElementTypeBitWidth();
-      int numWordsInVec = std::max(1, vecBytes / wordBytes);
-      for (int j = 0; j < numWordsInVec; j++) {
-        int elemIdx = i * vecTy.getNumElements() + j * wordElems;
-
-        if (!isCanonicalIndex(elemIdx, regMask)) {
-          continue; // Skip redundant registers
-        }
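+    // Emit a single cp.async that copies one `vecTy` vector from the packed
+    // {ptr, mask} value at vals[startIdx] into shmemAddr. It loads nothing
+    // back into registers, so it returns an empty vector.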
+    auto emitCpAsync = [&b, threadPred, ptrTy, hasMask = bool(llMask)](
+                           ConversionPatternRewriter &rewriter, Location loc,
+                           ArrayRef<Value> vals, Value shmemAddr, int startIdx,
+                           VectorType vecTy) -> SmallVector<Value> {
+      assert(isa<VectorType>(vecTy));
+      auto *ctx = rewriter.getContext();
+      auto elemTy = vecTy.getElementType();
+      auto nBytes = vecTy.getNumElements() * elemTy.getIntOrFloatBitWidth() / 8;
+      assert(nBytes == 16 || nBytes == 8 || nBytes == 4);
+      // Tune CG and CA.
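+      // .cg caches only in L2 (bypassing L1) and PTX restricts it to 16-byte
+      // copies; the smaller 4- and 8-byte copies fall back to .ca.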
+      CacheModifier srcCacheModifier =
+          nBytes == 16 ? CacheModifier::CG : CacheModifier::CA;
+
+      auto structElem = vals[startIdx];
+      auto srcElem = b.extract_val(ptrTy, structElem, 0);
+      auto maskElem = b.extract_val(i1_ty, structElem, 1);

-      // Tune CG and CA.
-      CacheModifier srcCacheModifier =
-          wordBytes == 16 ? CacheModifier::CG : CacheModifier::CA;
-      assert(wordBytes == 16 || wordBytes == 8 || wordBytes == 4);
-
-      PTXBuilder ptxBuilder;
-      auto &copyAsyncOp =
-          *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
-      auto *dstOperand = ptxBuilder.newAddrOperand(shmemAddrs[i], "r",
-                                                   /*offset=*/j * wordBytes);
-      auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[elemIdx], "l");
-      auto *copySize = ptxBuilder.newConstantOperand(wordBytes);
-      auto *srcSize = copySize;
-      if (op.getMask()) {
-        // We don't use predicate in this case, setting src-size to 0
-        // if there's any mask. cp.async will automatically fill the
-        // remaining slots with 0 if cp-size > src-size.
-        // XXX(Keren): Always assume other = 0 for now.
-        // When 'other != 0' is supported, we will need to fold the
-        // op.getMask() and redundantDataMask() into the same predicate, the
-        // way it is done for LoadOp.
-        auto selectOp =
-            b.select(maskElems[elemIdx], b.i32_val(wordBytes), b.i32_val(0));
-        srcSize = ptxBuilder.newOperand(selectOp, "r");
-      }
-
-      copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
-          .maybePredicate(threadPred);
-      ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
+      PTXBuilder ptxBuilder;
+      auto &copyAsyncOp =
+          *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
+      auto *dstOperand = ptxBuilder.newAddrOperand(shmemAddr, "r");
+      auto *srcOperand = ptxBuilder.newAddrOperand(srcElem, "l");
+      auto *copySize = ptxBuilder.newConstantOperand(nBytes);
+      auto *srcSize = copySize;
+      if (hasMask) {
+        // We don't use predicate in this case, setting src-size to 0
+        // if there's any mask. cp.async will automatically fill the
+        // remaining slots with 0 if cp-size > src-size.
+        // XXX(Keren): Always assume other = 0 for now.
+        // When 'other != 0' is supported, we will need to fold the
+        // op.getMask() and redundantDataMask() into the same predicate, the
+        // way it is done for LoadOp.
+        auto selectOp = b.select(maskElem, b.i32_val(nBytes), b.i32_val(0));
+        srcSize = ptxBuilder.newOperand(selectOp, "r");
      }
+      copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
+          .maybePredicate(threadPred);
+      ptxBuilder.launch(rewriter, loc, void_ty(ctx));
+      return {};
+    };
+
+    // %dst
+    auto smemObj =
+        getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
+    auto smemLayout =
+        ttg::toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+    auto cvt = srcLayout.invertAndCompose(smemLayout);
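+    // cvt maps each source element's (register, lane, warp, block) index to
+    // its offset in the shared-memory destination.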
+    if (!cvt.isTrivialOver({str_attr("block")})) {
+      return emitError(loc,
+                       "cp.async does not support non-trivial block dimension");
    }
+    cvt = cvt.sublayout(
+        {str_attr("register"), str_attr("lane"), str_attr("warp")},
+        {str_attr("offset")});
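+    // lowerLdSt is expected to walk cvt in chunks of up to maxVec elements,
+    // computing one shared-memory address per chunk and calling emitCpAsync
+    // for each.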
+    lowerLdSt(loc, ctx, cvt, vals, resElemTy, smemObj.getBase(), rewriter,
+              targetInfo, maxVec, emitCpAsync);

    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(