@@ -50,7 +50,6 @@ using namespace NVVM;
5050
5151// This verifier is shared among the following Ops:
5252// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
53- // CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
5453// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
5554static LogicalResult cpAsyncBulkTensorCommonVerifier (size_t tensorDims,
5655 bool isIm2Col,
@@ -98,11 +97,47 @@ LogicalResult CpAsyncOp::verify() {
9897 return success ();
9998}
10099
100+ // This verify params can be shared across TMA Load and Prefetch Ops.
101+ static LogicalResult verifyTMALoadParams (size_t tensorDims, size_t numIm2colOff,
102+ TMALoadMode mode, Location loc) {
103+ if (tensorDims < 1 || tensorDims > 5 )
104+ return emitError (loc, " expects coordinates between 1 to 5 dimension" );
105+
106+ auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
107+ size_t expectedIm2colOff) -> LogicalResult {
108+ if (isIm2col && (tensorDims < 3 ))
109+ return emitError (loc)
110+ << " to use " << stringifyEnum (mode)
111+ << " mode, the tensor has to be at least 3-dimensional" ;
112+
113+ if (numIm2colOff != expectedIm2colOff)
114+ return emitError (loc) << " im2col offsets expected " << expectedIm2colOff
115+ << " (provided " << numIm2colOff << " )" ;
116+
117+ return success ();
118+ };
119+
120+ switch (mode) {
121+ case TMALoadMode::TILE:
122+ return checkTMALoadParams (mode, false , 0 );
123+ case TMALoadMode::IM2COL:
124+ return checkTMALoadParams (mode, true , tensorDims - 2 );
125+ case TMALoadMode::IM2COL_W:
126+ case TMALoadMode::IM2COL_W_128:
127+ return checkTMALoadParams (mode, true , 2 );
128+ case TMALoadMode::TILE_GATHER4:
129+ return (tensorDims == 5 )
130+ ? checkTMALoadParams (mode, false , 0 )
131+ : emitError (loc, " Gather4 mode expects 5 coordinates" );
132+ default :
133+ return emitError (loc, " Invalid LoadMode in CpAsyncBulkTensorPrefetchOp." );
134+ }
135+ return success ();
136+ }
137+
101138LogicalResult CpAsyncBulkTensorPrefetchOp::verify () {
102- size_t numIm2ColOffsets = getIm2colOffsets ().size ();
103- bool isIm2Col = numIm2ColOffsets > 0 ;
104- return cpAsyncBulkTensorCommonVerifier (getCoordinates ().size (), isIm2Col,
105- numIm2ColOffsets, getLoc ());
139+ return verifyTMALoadParams (getCoordinates ().size (), getIm2colOffsets ().size (),
140+ getMode (), getLoc ());
106141}
107142
108143LogicalResult CpAsyncBulkTensorReduceOp::verify () {
@@ -1435,28 +1470,57 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
14351470 return {id, std::move (args)};
14361471}
14371472
1438- llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID (int tensorDims,
1439- bool isIm2Col) {
1440- switch (tensorDims) {
1441- case 1 :
1442- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
1443- case 2 :
1444- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
1445- case 3 :
1446- return isIm2Col
1447- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
1448- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
1449- case 4 :
1450- return isIm2Col
1451- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
1452- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
1453- case 5 :
1454- return isIm2Col
1455- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
1456- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
1457- default :
1458- llvm_unreachable (" Invalid TensorDim in CpAsyncBulkTensorPrefetchOp." );
1459- }
1473+ mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs (
1474+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1475+ auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
1476+ llvm::SmallVector<llvm::Value *> args;
1477+
1478+ // Fill the Intrinsic Args
1479+ args.push_back (mt.lookupValue (thisOp.getTmaDescriptor ()));
1480+
1481+ for (auto v : thisOp.getCoordinates ())
1482+ args.push_back (mt.lookupValue (v));
1483+ for (auto v : thisOp.getIm2colOffsets ())
1484+ args.push_back (mt.lookupValue (v));
1485+
1486+ mlir::Value cacheHint = thisOp.getL2CacheHint ();
1487+ const bool hasCacheHint = static_cast <bool >(cacheHint);
1488+ llvm::Value *i64Unused =
1489+ llvm::ConstantInt::get (llvm::Type::getInt64Ty (mt.getLLVMContext ()), 0 );
1490+ args.push_back (hasCacheHint ? mt.lookupValue (cacheHint) : i64Unused);
1491+ args.push_back (builder.getInt1 (hasCacheHint));
1492+
1493+ const unsigned NI = llvm::Intrinsic::not_intrinsic;
1494+ static constexpr llvm::Intrinsic::ID IDTable[][6 ] = {
1495+ {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
1496+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
1497+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
1498+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
1499+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
1500+ {NI, NI, NI,
1501+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
1502+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
1503+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
1504+ {NI, NI, NI,
1505+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
1506+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
1507+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
1508+ {NI, NI, NI,
1509+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
1510+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
1511+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
1512+ {NI, NI, NI, NI, NI,
1513+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
1514+
1515+ static_assert (getMaxEnumValForTMALoadMode () == std::size (IDTable) - 1 ,
1516+ " TMALoadModes must match number of rows in IDTable" );
1517+ size_t mode = static_cast <size_t >(thisOp.getMode ());
1518+ size_t dim = thisOp.getCoordinates ().size ();
1519+ llvm::Intrinsic::ID id = IDTable[mode][dim];
1520+ if (id == llvm::Intrinsic::not_intrinsic)
1521+ llvm_unreachable (" Invalid intrinsic for CpAsyncBulkTensorPrefetchOp." );
1522+
1523+ return {id, std::move (args)};
14601524}
14611525
14621526#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, dim, mode ) \
0 commit comments