Commit fa366b4

[MLIR][NVVM] Update TMA Load Op (#156347)
This patch adds im2col and gather mode support to the TMA Load Op. The lowering is also updated to intrinsics, except when a predicate is given. This completes the Blackwell additions on this Op.

* The NVVM dialect now supports the Shared::Cluster address space, so this patch also updates the Op to use AS(7) instead of AS(3). The corresponding inline-PTX based unit tests are updated accordingly.
* lit tests are added for all combinations.

Signed-off-by: Durgadoss R <[email protected]>
1 parent f354ca2 commit fa366b4
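
For illustration, a minimal sketch of the address-space change (SSA names are illustrative and the custom assembly for the optional operands is abbreviated): the destination of a cluster-scoped load is now a shared::cluster pointer rather than a plain shared pointer.

  // Before this patch: destination in AS(3).
  nvvm.cp.async.bulk.tensor.shared.cluster.global %dst, %tmaDesc, %mbar, box[%crd0, %crd1] : !llvm.ptr<3>, !llvm.ptr
  // After this patch: destination in AS(7) for the shared::cluster form.
  nvvm.cp.async.bulk.tensor.shared.cluster.global %dst, %tmaDesc, %mbar, box[%crd0, %crd1] : !llvm.ptr<7>, !llvm.ptr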

File tree

12 files changed: +1025 -119 lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 28 additions & 12 deletions
@@ -2827,26 +2827,21 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
   NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
          [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
           AttrSizedOperandSegments, NVVMRequiresSM<90>]>,
-  Arguments<(ins LLVM_PointerShared:$dstMem,
-                 LLVM_AnyPointer:$tmaDescriptor,
+  Arguments<(ins AnyTypeOf<[LLVM_PointerShared, LLVM_PointerSharedCluster]>:$dstMem,
+                 LLVM_PointerGeneric:$tmaDescriptor,
                  Variadic<I32>:$coordinates,
                  LLVM_PointerShared:$mbar,
                  Variadic<I16>:$im2colOffsets,
                  Optional<I16>:$multicastMask,
                  Optional<I64>:$l2CacheHint,
+                 DefaultValuedAttr<TMALoadModeAttr, "TMALoadMode::TILE">:$mode,
+                 DefaultValuedAttr<BoolAttr, "false">:$isCTAOnly,
+                 OptionalAttr<CTAGroupKindAttr>:$group,
                  PtxPredicate:$predicate)> {
   let description = [{
     Initiates an asynchronous copy operation on the tensor data from global
-    memory to shared memory.
-
-    The Op operates has two load modes:
-    1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
-    layout is preserved at the destination.
-
-    2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
-    the elements in the Bounding Box of the source tensor are rearranged into
-    columns at the destination. In this mode, the tensor has to be at least
-    3-dimensional.
+    memory to shared::cluster (or) shared::cta memory. This Op supports all
+    the load modes specified in `TMALoadMode`.

     The `multicastMask` operand is optional. When it is present, the Op copies
     data from global memory to shared memory of multiple CTAs in the cluster.
@@ -2857,6 +2852,10 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
     The `l2CacheHint` operand is optional, and it is used to specify cache
     eviction policy that may be used during the memory access.

+    When the `isCTAOnly` attribute is set to true, the destination is
+    shared::cta only. Hence, `multicastMask` and `CTAGroup` are not applicable
+    when `isCTAOnly` is true.
+
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
   }];

@@ -2904,6 +2903,23 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
     }
   }];
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool hasIntrinsic() { return !getPredicate(); }
+
+    bool getAsmValues(RewriterBase &rewriter,
+      llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);
+
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase& builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+      *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, id, args);
+  }];
 }

 def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
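
The new `mode`, `isCTAOnly`, and `group` arguments select the lowering variant. As a rough sketch of the shared::cta form described above (the attribute spelling and assembly details are assumptions, not taken from this patch), the destination stays in AS(3) and the multicast/cta-group features are omitted:

  // Hypothetical CTA-only load: AS(3) destination, no multicastMask, no group.
  nvvm.cp.async.bulk.tensor.shared.cluster.global %dstCta, %tmaDesc, %mbar, box[%crd0, %crd1] {isCTAOnly = true} : !llvm.ptr<3>, !llvm.ptr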

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 13 additions & 0 deletions
@@ -993,6 +993,14 @@ struct NVGPUTmaAsyncLoadOpLowering
     auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
     Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
                                       adaptor.getDst(), {});
+    // Intrinsics takes a shared-cluster pointer so we need an
+    // address space cast from 3 to 7.
+    // TODO: Introduce AS(7) in NVGPU.
+    auto ptrSharedClusterType = LLVM::LLVMPointerType::get(
+        op->getContext(),
+        static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster));
+    dest = LLVM::AddrSpaceCastOp::create(b, ptrSharedClusterType, dest);
+
     Value barrier =
         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
@@ -1001,9 +1009,14 @@ struct NVGPUTmaAsyncLoadOpLowering
     for (auto [index, value] : llvm::enumerate(coords)) {
       coords[index] = truncToI32(b, value);
     }
+
+    // TODO: Enhance the NVGPU Op for other modes too
     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
         op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
         ValueRange{}, adaptor.getMulticastMask(), Value{},
+        NVVM::TMALoadMode::TILE, // default is TILE mode
+        false,                   // default is cluster-scope
+        nullptr,                 // default is no cta-group
        adaptor.getPredicate());
     return success();
   }

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 159 additions & 8 deletions
@@ -45,12 +45,14 @@ using namespace NVVM;
 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

+static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
+
 //===----------------------------------------------------------------------===//
 // Verifier methods
 //===----------------------------------------------------------------------===//

 // This verifier is shared among the following Ops:
-// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
+// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
 // CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
 static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                      bool isIm2Col,
@@ -74,13 +76,6 @@ static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
   return success();
 }

-LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-  size_t numIm2ColOffsets = getIm2colOffsets().size();
-  bool isIm2Col = numIm2ColOffsets > 0;
-  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
-                                         numIm2ColOffsets, getLoc());
-}
-
 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
   TMAStoreMode mode = getMode();
   // We lower through inline-ptx when getPredicate() is true.
@@ -158,6 +153,38 @@ LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
                              getMode(), getLoc());
 }

+LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
+  TMALoadMode mode = getMode();
+  bool isCTAOnly = getIsCTAOnly();
+  if (getPredicate()) { // Inline-asm based lowering
+    if (isCTAOnly)
+      return emitError("Predicate is supported only for shared::cluster mode.");
+    if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
+      return emitError(
+          "Predicate is supported only for Tile and Im2col modes.");
+  } else { // Intrinsics-based lowering
+    NVVMMemorySpace expectedAS =
+        isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
+    unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
+                      .getAddressSpace();
+    if (AS != expectedAS)
+      return emitError()
+             << (isCTAOnly
+                     ? "Shared::cta destination requires address-space 3."
+                     : "Shared::cluster destination requires address-space 7.");
+    // Checks specific to shared::cta mode
+    if (isCTAOnly) {
+      if (getMulticastMask())
+        return emitError("Multicast is not supported with shared::cta mode.");
+      if (getGroup())
+        return emitError("CTAGroup is not supported with shared::cta mode.");
+    }
+  }
+
+  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+                             getMode(), getLoc());
+}
+
 LogicalResult CpAsyncBulkTensorReduceOp::verify() {
   TMAStoreMode mode = getMode();
   size_t dims = getCoordinates().size();
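
The verifier above ties `isCTAOnly` to the destination address space and limits the predicated (inline-PTX) path to the Tile and Im2col modes. A hedged sketch of one shape it rejects (assembly spelling is illustrative):

  // Rejected: isCTAOnly = true with a shared::cluster (AS 7) destination;
  // the verifier emits "Shared::cta destination requires address-space 3."
  nvvm.cp.async.bulk.tensor.shared.cluster.global %dst, %tmaDesc, %mbar, box[%crd0, %crd1] {isCTAOnly = true} : !llvm.ptr<7>, !llvm.ptr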
@@ -1553,6 +1580,130 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }

+bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
+    RewriterBase &rewriter,
+    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+        &asmValues) {
+  // Add all the operands but not the attrs to the asmValues list.
+  // The attrs here are used to generate the right variants for
+  // intrinsics-lowering. So, we ignore them while generating inline-PTX.
+  for (auto val : getOperands())
+    asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+
+  return false;
+}
+
+mlir::NVVM::IDArgPair
+CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
+  const bool isCTAOnly = thisOp.getIsCTAOnly();
+  llvm::SmallVector<llvm::Value *> args;
+
+  // Fill the Intrinsic Args
+  args.push_back(mt.lookupValue(thisOp.getDstMem()));
+  args.push_back(mt.lookupValue(thisOp.getMbar()));
+  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+  // Coordinates and im2col-offsets
+  for (mlir::Value v : thisOp.getCoordinates())
+    args.push_back(mt.lookupValue(v));
+  for (mlir::Value v : thisOp.getIm2colOffsets())
+    args.push_back(mt.lookupValue(v));
+
+  // MulticastMask, if available
+  mlir::Value mcMask = thisOp.getMulticastMask();
+  const bool hasMC = static_cast<bool>(mcMask);
+  llvm::Value *i16Zero =
+      llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
+
+  // CacheHint, if available
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Zero =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+
+  // Flag argument CTAGroup
+  // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
+  // Hence, the +1 to getGroup().
+  const int32_t val =
+      thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
+  llvm::Value *cg =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
+
+  if (!isCTAOnly) {
+    // For shared::cluster, all the arguments that we build are applicable.
+    args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
+    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
+    args.push_back(builder.getInt1(hasMC));
+    args.push_back(builder.getInt1(hasCacheHint));
+    args.push_back(cg);
+  } else {
+    // For shared::cta, only cache-hint is applicable.
+    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
+    args.push_back(builder.getInt1(hasCacheHint));
+  }
+
+  constexpr size_t numDims = 5;  // 1D to 5D
+  constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
+  using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
+  using TableTy = std::array<rowTy, numModes>;
+  static constexpr TableTy IDTable{
+      {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
+
+  static constexpr TableTy IDTableCTA{
+      {{notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
+
+  static_assert(
+      (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
+          (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
+      "TMALoadModes must match number of rows in IDTable and IDTableCTA");
+  size_t mode = static_cast<size_t>(thisOp.getMode());
+  size_t dim = thisOp.getCoordinates().size();
+  auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
+  assert(id != notIntrinsic &&
+         "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
+
+  return {id, std::move(args)};
+}
+
 mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
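
The intrinsic is selected by indexing these tables with the load mode (row) and the coordinate rank (column), switching to IDTableCTA when `isCTAOnly` is set. As a hedged illustration (assembly spelling is approximate), a 3-D tile-mode load maps as follows:

  // Cluster-scoped 3-D tile load -> nvvm_cp_async_bulk_tensor_g2s_tile_3d;
  // with isCTAOnly = true and an AS(3) destination it would instead map to
  // nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d.
  nvvm.cp.async.bulk.tensor.shared.cluster.global %dst, %tmaDesc, %mbar, box[%d0, %d1, %d2] : !llvm.ptr<7>, !llvm.ptr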

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 2 additions & 1 deletion
@@ -854,7 +854,8 @@ module @mymodule {
   // CHECK: %[[desc:.+]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[c8192:.+]] = llvm.mlir.constant(8192 : index) : i64
   // CHECK: %[[shmemOfset:.+]] = llvm.getelementptr %[[desc]][%[[c8192]]] : (!llvm.ptr<3>, i64)
-  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[shmemOfset]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
+  // CHECK: %[[dest:.+]] = llvm.addrspacecast %[[shmemOfset]] : !llvm.ptr<3> to !llvm.ptr<7>
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[dest]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
   nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
   return
 }
