Commit 36dc614

[MLIR][NVVM] Update TMA tensor prefetch Op (#153464)
This patch updates the TMA Tensor prefetch Op to add support for the im2col_w/w128 and tile_gather4 modes. This completes support for all modes available in Blackwell.

* lit tests are added for all possible combinations.
* The invalid tests are moved to a separate file with more coverage.

Signed-off-by: Durgadoss R <[email protected]>
1 parent 5050da7 commit 36dc614

File tree

5 files changed (+292, -117 lines)


mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 60 additions & 34 deletions
@@ -2302,6 +2302,56 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
 
+// List of modes supported for TMA Load and Prefetch Ops
+def TMALoadModeTile : I32EnumAttrCase<"TILE", 0, "tile">;
+def TMALoadModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
+def TMALoadModeIm2ColW : I32EnumAttrCase<"IM2COL_W", 2, "im2col_w">;
+def TMALoadModeIm2ColW128 : I32EnumAttrCase<"IM2COL_W_128", 3, "im2col_w_128">;
+def TMALoadModeTileGather4 : I32EnumAttrCase<"TILE_GATHER4", 4, "tile_gather4">;
+
+def TMALoadMode : I32EnumAttr<"TMALoadMode", "NVVM TMA Load Mode",
+                              [TMALoadModeTile, TMALoadModeIm2Col,
+                               TMALoadModeIm2ColW, TMALoadModeIm2ColW128,
+                               TMALoadModeTileGather4]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def TMALoadModeAttr : EnumAttr<NVVM_Dialect, TMALoadMode, "tma_load_mode"> {
+  let summary = "List of Load-Modes supported for TMA Tensor Ops";
+  let description = [{
+    TMA Tensor Ops support the following modes when copying data from
+    global memory to shared memory (i.e. load):
+
+    Tile Mode: This is the default mode. The source multi-dimensional tensor
+    layout is preserved at the destination.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-mode)
+
+    Im2col Mode: This mode is used when the `im2colOffsets` operands are present.
+    The elements in the Bounding Box of the source tensor are rearranged into
+    columns at the destination. In this mode, the tensor has to be at least
+    3-dimensional. The number of `im2colOffsets` is `dims - 2`, where `dims`
+    is the dimension of the tensor.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode)
+
+    Im2col_W Mode: This mode is similar to Im2col mode, with the restriction that
+    elements are accessed across the W dimension only. The number of `im2colOffsets`
+    is always two, referred to as `wHalo` and `wOffset`.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-w-w128-modes)
+
+    Im2col_W_128 Mode: This mode is similar to Im2col_W mode, except that the number
+    of elements accessed across the W dimension is always 128.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-w-w128-modes)
+
+    Tile_Gather4 Mode: This mode is similar to Tile mode but works only on 2D tensors.
+    In gather4 mode, four rows of the source 2D tensor are combined to form a single
+    2D tensor at the destination. This mode requires five coordinates: the first one
+    represents the column index, followed by four row indices.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-scatter4-gather4-modes)
+  }];
+
+  let assemblyFormat = "`<` $value `>`";
+}
 
 def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
   Arguments<(ins )> {
   let assemblyFormat = "attr-dict";
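As a concrete illustration of the tile_gather4 description above (not part of the patch): the five coordinates are one column index plus four row indices, and the four selected rows are stacked into a single 2D destination. The sketch below models that in plain C++, ignoring box sizes, out-of-bounds handling, and swizzling, which the real TMA descriptor governs.

```cpp
// Simplified model of tile_gather4 coordinates: column index first, then the
// four gathered rows; the gathered rows form a single 2D destination tile.
#include <array>
#include <cstdio>

int main() {
  constexpr int Rows = 8, Cols = 8, BoxCols = 4;
  std::array<std::array<int, Cols>, Rows> src{};
  for (int r = 0; r < Rows; ++r)
    for (int c = 0; c < Cols; ++c)
      src[r][c] = r * Cols + c;

  // Five "coordinates": one column index followed by four row indices.
  const int col = 2;
  const std::array<int, 4> rows = {1, 3, 4, 7};

  // Destination is a 4 x BoxCols tile built from the gathered rows.
  int dst[4][BoxCols];
  for (int i = 0; i < 4; ++i)
    for (int c = 0; c < BoxCols; ++c)
      dst[i][c] = src[rows[i]][col + c];

  for (int i = 0; i < 4; ++i) {
    for (int c = 0; c < BoxCols; ++c)
      std::printf("%3d ", dst[i][c]);
    std::printf("\n");
  }
  return 0;
}
```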
@@ -2570,23 +2620,16 @@ def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {
 def NVVM_CpAsyncBulkTensorPrefetchOp :
   NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
   let arguments = (ins
-    LLVM_AnyPointer:$tmaDescriptor,
+    LLVM_PointerGeneric:$tmaDescriptor,
     Variadic<I32>:$coordinates,
     Variadic<I16>:$im2colOffsets,
+    DefaultValuedAttr<TMALoadModeAttr, "TMALoadMode::TILE">:$mode,
     Optional<I64>:$l2CacheHint);
 
   let description = [{
     Initiates an asynchronous prefetch operation on the tensor data from global
-    memory to L2 cache.
-
-    The Op has two modes:
-    1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
-    layout is preserved at the destination.
-
-    2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
-    the elements in the Bounding Box of the source tensor are rearranged into
-    columns at the destination. In this mode, the tensor has to be at least
-    3-dimensional.
+    memory to L2 cache. This Op supports all the load modes specified in
+    `TMALoadMode`.
 
     The `l2CacheHint` operand is optional, and it is used to specify cache
     eviction policy that may be used during the memory access.
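One way to read the optional `l2CacheHint`: elsewhere in this commit (the llvmBuilder and `getIntrinsicIDAndArgs` changes below), the lowering keeps the intrinsic's argument list at a fixed shape by passing a placeholder 0 plus an i1 flag when no hint is given. The following is a standalone C++ sketch of that packing pattern, with hypothetical names; the real code appends `llvm::Value *` operands instead of integers.

```cpp
// Sketch of the cache-hint tail packing: the hint slot is always present, and
// a trailing boolean flag tells the intrinsic whether the hint is valid.
#include <cstdint>
#include <cstdio>
#include <optional>
#include <vector>

std::vector<std::int64_t> packCacheHintTail(std::optional<std::int64_t> l2CacheHint) {
  std::vector<std::int64_t> tail;
  const bool hasCacheHint = l2CacheHint.has_value();
  tail.push_back(hasCacheHint ? *l2CacheHint : 0); // placeholder 0 when absent
  tail.push_back(hasCacheHint ? 1 : 0);            // i1-style validity flag
  return tail;
}

int main() {
  auto withHint = packCacheHintTail(0x10);
  auto withoutHint = packCacheHintTail(std::nullopt);
  std::printf("%lld %lld\n", (long long)withHint[0], (long long)withHint[1]);       // 16 1
  std::printf("%lld %lld\n", (long long)withoutHint[0], (long long)withoutHint[1]); // 0 0
  return 0;
}
```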
@@ -2603,34 +2646,17 @@ def NVVM_CpAsyncBulkTensorPrefetchOp :
   }];
 
   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase& builder);
   }];
 
   let hasVerifier = 1;
 
   string llvmBuilder = [{
-    // Arguments to the intrinsic:
-    // tmaDesc, tensorDims, im2colOffsets
-    // cache_hint(if applicable) and flag(boolean)
-    llvm::SmallVector<llvm::Value *> translatedOperands;
-    translatedOperands.push_back($tmaDescriptor);
-
-    for (auto v : op.getCoordinates())
-      translatedOperands.push_back(moduleTranslation.lookupValue(v));
-
-    for (auto v : op.getIm2colOffsets())
-      translatedOperands.push_back(moduleTranslation.lookupValue(v));
-
-    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
-    auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
-
-    bool isCacheHint = op.getL2CacheHint() ? true : false;
-    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
-    translatedOperands.push_back(builder.getInt1(isCacheHint));
-
-    auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
-        op.getCoordinates().size(), op.getIm2colOffsets().size() > 0);
-    createIntrinsicCall(builder, intId, translatedOperands);
+    auto [id, args] = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+        *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, id, args);
   }];
 }
 

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 91 additions & 27 deletions
@@ -50,7 +50,6 @@ using namespace NVVM;
 
 // This verifier is shared among the following Ops:
 // CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
-// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
 // CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
 static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                      bool isIm2Col,
@@ -98,11 +97,47 @@ LogicalResult CpAsyncOp::verify() {
   return success();
 }
 
+// These parameter checks can be shared across the TMA Load and Prefetch Ops.
+static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
+                                         TMALoadMode mode, Location loc) {
+  if (tensorDims < 1 || tensorDims > 5)
+    return emitError(loc, "expects coordinates between 1 to 5 dimension");
+
+  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
+                                size_t expectedIm2colOff) -> LogicalResult {
+    if (isIm2col && (tensorDims < 3))
+      return emitError(loc)
+             << "to use " << stringifyEnum(mode)
+             << " mode, the tensor has to be at least 3-dimensional";
+
+    if (numIm2colOff != expectedIm2colOff)
+      return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
+                            << " (provided " << numIm2colOff << ")";
+
+    return success();
+  };
+
+  switch (mode) {
+  case TMALoadMode::TILE:
+    return checkTMALoadParams(mode, false, 0);
+  case TMALoadMode::IM2COL:
+    return checkTMALoadParams(mode, true, tensorDims - 2);
+  case TMALoadMode::IM2COL_W:
+  case TMALoadMode::IM2COL_W_128:
+    return checkTMALoadParams(mode, true, 2);
+  case TMALoadMode::TILE_GATHER4:
+    return (tensorDims == 5)
+               ? checkTMALoadParams(mode, false, 0)
+               : emitError(loc, "Gather4 mode expects 5 coordinates");
+  default:
+    return emitError(loc, "Invalid LoadMode in CpAsyncBulkTensorPrefetchOp.");
+  }
+  return success();
+}
+
 LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
-  size_t numIm2ColOffsets = getIm2colOffsets().size();
-  bool isIm2Col = numIm2ColOffsets > 0;
-  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
-                                         numIm2ColOffsets, getLoc());
+  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+                             getMode(), getLoc());
 }
 
 LogicalResult CpAsyncBulkTensorReduceOp::verify() {
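For reference, the rules `verifyTMALoadParams` enforces can be restated compactly. The sketch below is a standalone, simplified restatement in plain C++ (not the dialect code): 1 to 5 coordinates overall, rank of at least 3 for the im2col-family modes, a fixed `im2colOffsets` count per mode, and exactly 5 coordinates for tile_gather4.

```cpp
// Standalone restatement of the TMA load/prefetch mode constraints.
#include <cassert>
#include <cstddef>

enum class TMALoadMode { TILE, IM2COL, IM2COL_W, IM2COL_W_128, TILE_GATHER4 };

bool isValidTMALoad(TMALoadMode mode, std::size_t dims, std::size_t offsets) {
  if (dims < 1 || dims > 5)
    return false;
  switch (mode) {
  case TMALoadMode::TILE:
    return offsets == 0;
  case TMALoadMode::IM2COL:
    return dims >= 3 && offsets == dims - 2;
  case TMALoadMode::IM2COL_W:
  case TMALoadMode::IM2COL_W_128:
    return dims >= 3 && offsets == 2; // wHalo and wOffset
  case TMALoadMode::TILE_GATHER4:
    return dims == 5 && offsets == 0;
  }
  return false;
}

int main() {
  assert(isValidTMALoad(TMALoadMode::TILE, 2, 0));
  assert(isValidTMALoad(TMALoadMode::IM2COL, 4, 2));
  assert(!isValidTMALoad(TMALoadMode::IM2COL_W, 2, 2));     // rank too small
  assert(!isValidTMALoad(TMALoadMode::TILE_GATHER4, 4, 0)); // needs 5 coordinates
  return 0;
}
```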
@@ -1435,28 +1470,57 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }
 
-llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
-                                                                bool isIm2Col) {
-  switch (tensorDims) {
-  case 1:
-    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
-  case 2:
-    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
-  case 3:
-    return isIm2Col
-               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
-               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
-  case 4:
-    return isIm2Col
-               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
-               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
-  case 5:
-    return isIm2Col
-               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
-               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
-  default:
-    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
-  }
+mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+
+  // Fill the Intrinsic Args
+  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+  for (auto v : thisOp.getCoordinates())
+    args.push_back(mt.lookupValue(v));
+  for (auto v : thisOp.getIm2colOffsets())
+    args.push_back(mt.lookupValue(v));
+
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Unused =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+  args.push_back(builder.getInt1(hasCacheHint));
+
+  const unsigned NI = llvm::Intrinsic::not_intrinsic;
+  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
+      {NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
+      {NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
+      {NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
+      {NI, NI, NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
+
+  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
+                "TMALoadModes must match number of rows in IDTable");
+  size_t mode = static_cast<size_t>(thisOp.getMode());
+  size_t dim = thisOp.getCoordinates().size();
+  llvm::Intrinsic::ID id = IDTable[mode][dim];
+  if (id == llvm::Intrinsic::not_intrinsic)
+    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
+
+  return {id, std::move(args)};
 }
 
 #define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \
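The `IDTable` lookup above follows a simple pattern: index by `[mode][number of coordinates]` and treat `not_intrinsic` as "no such combination". Below is a standalone miniature of that pattern with fake integer IDs (hypothetical values, not the real intrinsic enumerators): im2col-family rows only have entries from 3 dimensions onward, and the tile_gather4 row has a single 2D entry selected via 5 coordinates.

```cpp
// Miniature of the [mode][dims] -> intrinsic lookup with a sentinel value.
#include <cstddef>
#include <cstdio>

enum class Mode { TILE, IM2COL, IM2COL_W, IM2COL_W_128, TILE_GATHER4 };
constexpr int kNone = 0; // stand-in for llvm::Intrinsic::not_intrinsic

// Fake IDs; the real table stores llvm::Intrinsic::ID values.
constexpr int kTable[5][6] = {
    {kNone, 11, 12, 13, 14, 15},              // tile_1d .. tile_5d
    {kNone, kNone, kNone, 23, 24, 25},        // im2col_3d .. im2col_5d
    {kNone, kNone, kNone, 33, 34, 35},        // im2col_w_3d .. im2col_w_5d
    {kNone, kNone, kNone, 43, 44, 45},        // im2col_w_128_3d .. _5d
    {kNone, kNone, kNone, kNone, kNone, 52}}; // tile_gather4_2d (5 coordinates)

int lookup(Mode mode, std::size_t dims) {
  return (dims <= 5) ? kTable[static_cast<std::size_t>(mode)][dims] : kNone;
}

int main() {
  std::printf("%d\n", lookup(Mode::IM2COL_W, 4));     // 34
  std::printf("%d\n", lookup(Mode::TILE_GATHER4, 5)); // 52
  std::printf("%d\n", lookup(Mode::IM2COL, 2));       // 0 -> invalid combination
  return 0;
}
```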
