@@ -1157,6 +1157,7 @@ struct AsyncCopyGlobalToLocalOpConversion
    auto srcTy = op.getSrc().getType();
    auto dstTy = op.getResult().getType();
    auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());
+    auto srcLayout = srcTy.getEncoding();

    Value llDst = adaptor.getResult();
    Value llSrc = adaptor.getSrc();
@@ -1166,40 +1167,27 @@ struct AsyncCopyGlobalToLocalOpConversion
    // %src
    auto srcElems = unpackLLElements(loc, llSrc, rewriter);

+    // %dst
+    auto smemObj =
+        getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
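+    // smemObj wraps the destination's shared-memory base pointer and layout
+    // metadata; it is consumed by emitTransferBetweenRegistersAndShared below.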
    // %mask
    SmallVector<Value> maskElems;
    if (llMask) {
      maskElems = unpackLLElements(loc, llMask, rewriter);
      assert(srcElems.size() == maskElems.size());
    }

-    // We assume other = 0, see XXX(Keren) below
    // %other
-    // SmallVector<Value> otherElems;
-    // if (llOther) {
-    //   otherElems = unpackLLElements(loc, llOther, rewriter);
-    //   assert(srcElems.size() == otherElems.size());
-    // }
-
-    // zip(src, mask)
-    SmallVector<Value> vals;
-    auto ptrTy = srcElems[0].getType();
-    auto structTy =
-        LLVM::LLVMStructType::getLiteral(ctx, ArrayRef<Type>{ptrTy, i1_ty});
-    for (int i = 0; i < srcElems.size(); i++) {
-      Value packedArr = rewriter.create<LLVM::UndefOp>(loc, structTy);
-      packedArr = b.insert_val(packedArr, srcElems[i], 0);
-      auto maskElem = llMask ? maskElems[i] : b.false_val();
-      packedArr = b.insert_val(packedArr, maskElem, 1);
-      vals.push_back(packedArr);
+    SmallVector<Value> otherElems;
+    if (llOther) {
+      // FIXME(Keren): assume other is 0 for now.
+      //
+      // It's not necessary for now because the pipeline pass will skip
+      // generating insert_slice_async if the load op has any "other" tensor.
+      otherElems = unpackLLElements(loc, llOther, rewriter);
+      assert(srcElems.size() == otherElems.size());
    }

-    // Remove broadcasted registers
-    auto srcLayout = ttg::toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
-    auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout);
-    srcLayout = removeBroadcastSrc.apply(srcLayout);
-    vals = removeBroadcastSrc.apply(vals);
-
    // We can load N elements at a time if:
    // 1. Every group of N source pointers are contiguous. For example, if
    //    N=2, then the pointers should be [x, x+1, y, y+1, ...].
@@ -1210,16 +1198,25 @@ struct AsyncCopyGlobalToLocalOpConversion
    if (mask) {
      maxVec = std::min(maxVec, getMaskAlignment(mask));
    }
-    // The maximum vector size is 128 bits on NVIDIA GPUs.
-    maxVec = std::min(maxVec, 128 / resElemTy.getIntOrFloatBitWidth());

-    int vecBytes = maxVec * resElemTy.getIntOrFloatBitWidth() / 8;
+    // Addresses to store into, one per `vecTy`.
+    VectorType vecTy;
+    SmallVector<Value> shmemAddrs;
+    bool ok = emitTransferBetweenRegistersAndShared(
+        srcTy, dstTy, resElemTy, maxVec, smemObj, loc, rewriter, targetInfo,
+        [&](VectorType vecTy_, Value shmemAddr) {
+          vecTy = vecTy_;
+          shmemAddrs.push_back(shmemAddr);
+        });
+    assert(ok);
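+    // Each callback invocation corresponds to one vector-sized transfer; we
+    // record its shared-memory address and issue the cp.asyncs below.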
+
+    int vecBytes = vecTy.getNumElements() * vecTy.getElementTypeBitWidth() / 8;
+    assert(llvm::isPowerOf2_32(vecBytes));
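+    // e.g. vecTy = vector<8xf16> gives vecBytes = 8 * 16 / 8 = 16.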
    if (vecBytes < 4) {
      return emitError(loc, "cp.async does not support transfers smaller than "
                            "4 bytes; calculated this as ")
             << vecBytes << " bytes";
    }
-    assert(vecBytes == 16 || vecBytes == 8 || vecBytes == 4);

    auto freeVarMasks = getFreeVariableMasks(srcTy);
    // NOTE(@peterbell10): We load redundant data on different CTAs, so the data
@@ -1228,63 +1225,52 @@ struct AsyncCopyGlobalToLocalOpConversion
    freeVarMasks[str_attr("block")] = 0;
    Value threadPred =
        emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo);
+    uint32_t regMask = freeVarMasks[str_attr("reg")];
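+    // Register indices with a bit set in regMask hold duplicated values;
+    // isCanonicalIndex below keeps only one copy per duplicate group.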

-    auto emitCpAsync = [&b, threadPred, ptrTy, hasMask = bool(llMask)](
-                           ConversionPatternRewriter &rewriter, Location loc,
-                           ArrayRef<Value> vals, Value shmemAddr, int startIdx,
-                           VectorType vecTy) -> SmallVector<Value> {
-      assert(isa<VectorType>(vecTy));
-      auto *ctx = rewriter.getContext();
-      auto elemTy = vecTy.getElementType();
-      auto nBytes = vecTy.getNumElements() * elemTy.getIntOrFloatBitWidth() / 8;
-      assert(nBytes == 16 || nBytes == 8 || nBytes == 4);
-      // Tune CG and CA.
-      CacheModifier srcCacheModifier =
-          nBytes == 16 ? CacheModifier::CG : CacheModifier::CA;
-
-      auto structElem = vals[startIdx];
-      auto srcElem = b.extract_val(ptrTy, structElem, 0);
-      auto maskElem = b.extract_val(i1_ty, structElem, 1);
+    for (int i = 0; i < shmemAddrs.size(); i++) {
+      // It's possible that vecTy is larger than 128 bits, in which case we have
+      // to use multiple cp.async instructions.
+      int wordBytes = std::min(vecBytes, 16);
+      int wordElems = wordBytes * 8 / vecTy.getElementTypeBitWidth();
+      int numWordsInVec = std::max(1, vecBytes / wordBytes);
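+      // e.g. a 32-byte vector of f16: wordBytes = 16, wordElems = 8,
+      // numWordsInVec = 2, i.e. two cp.async instructions per vector.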
+      for (int j = 0; j < numWordsInVec; j++) {
+        int elemIdx = i * vecTy.getNumElements() + j * wordElems;
+
+        if (!isCanonicalIndex(elemIdx, regMask)) {
+          continue; // Skip redundant registers
+        }

-      PTXBuilder ptxBuilder;
-      auto &copyAsyncOp =
-          *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
-      auto *dstOperand = ptxBuilder.newAddrOperand(shmemAddr, "r");
-      auto *srcOperand = ptxBuilder.newAddrOperand(srcElem, "l");
-      auto *copySize = ptxBuilder.newConstantOperand(nBytes);
-      auto *srcSize = copySize;
-      if (hasMask) {
-        // We don't use predicate in this case, setting src-size to 0
-        // if there's any mask. cp.async will automatically fill the
-        // remaining slots with 0 if cp-size > src-size.
-        // XXX(Keren): Always assume other = 0 for now.
-        // When 'other != 0' is supported, we will need to fold the
-        // op.getMask() and redundantDataMask() into the same predicate, the
-        // way it is done for LoadOp.
-        auto selectOp = b.select(maskElem, b.i32_val(nBytes), b.i32_val(0));
-        srcSize = ptxBuilder.newOperand(selectOp, "r");
-      }
-      copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
-          .maybePredicate(threadPred);
-      ptxBuilder.launch(rewriter, loc, void_ty(ctx));
-      return {};
-    };
+        // Tune CG and CA.
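+        // .cg caches only in L2, which suits full 16-byte copies; .ca also
+        // caches in L1.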
+        CacheModifier srcCacheModifier =
+            wordBytes == 16 ? CacheModifier::CG : CacheModifier::CA;
+        assert(wordBytes == 16 || wordBytes == 8 || wordBytes == 4);
+
+        PTXBuilder ptxBuilder;
+        auto &copyAsyncOp =
+            *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
+        auto *dstOperand = ptxBuilder.newAddrOperand(shmemAddrs[i], "r",
+                                                     /*offset=*/j * wordBytes);
+        auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[elemIdx], "l");
+        auto *copySize = ptxBuilder.newConstantOperand(wordBytes);
+        auto *srcSize = copySize;
+        if (op.getMask()) {
+          // We don't use predicate in this case, setting src-size to 0
+          // if there's any mask. cp.async will automatically fill the
+          // remaining slots with 0 if cp-size > src-size.
+          // XXX(Keren): Always assume other = 0 for now.
+          // When 'other != 0' is supported, we will need to fold the
+          // op.getMask() and redundantDataMask() into the same predicate, the
+          // way it is done for LoadOp.
+          auto selectOp =
+              b.select(maskElems[elemIdx], b.i32_val(wordBytes), b.i32_val(0));
+          srcSize = ptxBuilder.newOperand(selectOp, "r");
+        }
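+        // The emitted PTX has the form, e.g.:
+        //   cp.async.cg.shared.global [dst], [src], 16, %srcSize;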

-    // %dst
-    auto smemObj =
-        getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
-    auto smemLayout =
-        ttg::toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
-    auto cvt = srcLayout.invertAndCompose(smemLayout);
-    if (!cvt.isTrivialOver({str_attr("block")})) {
-      return emitError(loc,
-                       "cp.async does not support non-trivial block dimension");
+        copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
+            .maybePredicate(threadPred);
+        ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
+      }
    }
-    cvt = cvt.sublayout(
-        {str_attr("register"), str_attr("lane"), str_attr("warp")},
-        {str_attr("offset")});
-    lowerLdSt(loc, ctx, cvt, vals, resElemTy, smemObj.getBase(), rewriter,
-              targetInfo, maxVec, emitCpAsync);

    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(