Skip to content

Commit 4a5b051

Browse files
authored
[MLIR][NVVM] Update TMA Store Op (llvm#155435)
This patch adds im2col and scatter mode support to the TMA Store Op. The lowering is also updated to use intrinsics, except when a predicate is given. This completes the Blackwell additions on this Op. * lit tests are added for all combinations. * Move the TMA reduce invalid tests to their own file. Signed-off-by: Durgadoss R <[email protected]>
1 parent 562c27e commit 4a5b051

File tree

8 files changed

+301
-59
lines changed

8 files changed

+301
-59
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2353,6 +2353,20 @@ def TMALoadModeAttr : EnumAttr<NVVM_Dialect, TMALoadMode, "tma_load_mode"> {
23532353
let assemblyFormat = "`<` $value `>`";
23542354
}
23552355

2356+
// List of modes supported for TMA Store and Reduction Ops
2357+
def TMAStoreModeTile : I32EnumAttrCase<"TILE", 0, "tile">;
2358+
def TMAStoreModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
2359+
def TMAStoreModeTileScatter4 : I32EnumAttrCase<"TILE_SCATTER4", 2, "tile_scatter4">;
2360+
2361+
def TMAStoreMode : I32EnumAttr<"TMAStoreMode", "NVVM TMA Store Mode",
2362+
[TMAStoreModeTile, TMAStoreModeIm2Col, TMAStoreModeTileScatter4]> {
2363+
let genSpecializedAttr = 0;
2364+
let cppNamespace = "::mlir::NVVM";
2365+
}
2366+
def TMAStoreModeAttr : EnumAttr<NVVM_Dialect, TMAStoreMode, "tma_store_mode"> {
2367+
let assemblyFormat = "`<` $value `>`";
2368+
}
2369+
23562370
def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
23572371
Arguments<(ins )> {
23582372
let assemblyFormat = "attr-dict";
@@ -2479,20 +2493,43 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
24792493
}
24802494

24812495
def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
2482-
NVVM_Op<"cp.async.bulk.tensor.global.shared.cta",
2483-
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
2496+
NVVM_PTXBuilder_Op<"cp.async.bulk.tensor.global.shared.cta",
2497+
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
24842498
AttrSizedOperandSegments]>,
2485-
Arguments<(ins LLVM_AnyPointer:$tmaDescriptor,
2486-
LLVM_PointerShared:$srcMem,
2487-
Variadic<I32>:$coordinates,
2488-
PtxPredicate:$predicate)> {
2499+
Arguments<(ins LLVM_PointerGeneric:$tmaDescriptor,
2500+
LLVM_PointerShared:$srcMem,
2501+
Variadic<I32>:$coordinates,
2502+
Optional<I64>:$l2CacheHint,
2503+
DefaultValuedAttr<TMAStoreModeAttr, "TMAStoreMode::TILE">:$mode,
2504+
PtxPredicate:$predicate)> {
2505+
let description = [{
2506+
Initiates an asynchronous copy of the tensor data from shared::cta
2507+
memory to global memory. This Op supports all the store modes specified in
2508+
`TMAStoreMode`.
2509+
2510+
The `l2CacheHint` operand is optional, and it is used to specify cache
2511+
eviction policy that may be used during the memory access.
2512+
2513+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
2514+
}];
2515+
24892516
let assemblyFormat = [{
24902517
$tmaDescriptor `,`
24912518
$srcMem `,`
24922519
`box` `[`$coordinates `]`
2493-
(`,` `predicate` `=` $predicate^)?
2494-
attr-dict `:` type(operands)
2520+
(`l2_cache_hint` `=` $l2CacheHint^ )?
2521+
(`,` `predicate` `=` $predicate^)?
2522+
attr-dict `:` type($tmaDescriptor) `,` type($srcMem)
2523+
}];
2524+
2525+
let extraClassDeclaration = [{
2526+
bool hasIntrinsic() { return !getPredicate(); }
2527+
2528+
static mlir::NVVM::IDArgPair
2529+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
2530+
llvm::IRBuilderBase& builder);
24952531
}];
2532+
24962533
let extraClassDefinition = [{
24972534
std::string $cppClass::getPtx() {
24982535
int dim = getCoordinates().size();
@@ -2508,6 +2545,12 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
25082545
}
25092546
}];
25102547
let hasVerifier = 1;
2548+
2549+
string llvmBuilder = [{
2550+
auto [id, args] = NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
2551+
*op, moduleTranslation, builder);
2552+
createIntrinsicCall(builder, id, args);
2553+
}];
25112554
}
25122555

25132556
//===----------------------------------------------------------------------===//
@@ -2661,19 +2704,6 @@ def NVVM_CpAsyncBulkTensorPrefetchOp :
26612704
}];
26622705
}
26632706

2664-
// List of modes supported for TMA Store and Reduction Ops
2665-
def TMAStoreModeTile : I32EnumAttrCase<"TILE", 0, "tile">;
2666-
def TMAStoreModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
2667-
2668-
def TMAStoreMode : I32EnumAttr<"TMAStoreMode", "NVVM TMA Store Mode",
2669-
[TMAStoreModeTile, TMAStoreModeIm2Col]> {
2670-
let genSpecializedAttr = 0;
2671-
let cppNamespace = "::mlir::NVVM";
2672-
}
2673-
def TMAStoreModeAttr : EnumAttr<NVVM_Dialect, TMAStoreMode, "tma_store_mode"> {
2674-
let assemblyFormat = "`<` $value `>`";
2675-
}
2676-
26772707
// List of Reduction Ops supported with TMA Store
26782708
def TMAReduxKindAdd : I32EnumAttrCase<"ADD", 0, "add">;
26792709
def TMAReduxKindMin : I32EnumAttrCase<"MIN", 1, "min">;

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1026,8 +1026,10 @@ struct NVGPUTmaAsyncStoreOpLowering
10261026
coords[index] = truncToI32(b, value);
10271027
}
10281028

1029+
// TODO: Enhance the NVGPU Op for other modes too
10291030
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1030-
op, adaptor.getTensorMapDescriptor(), dest, coords,
1031+
op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
1032+
NVVM::TMAStoreMode::TILE, // default is TILE mode
10311033
adaptor.getPredicate());
10321034
return success();
10331035
}

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,27 @@ LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
8181
}
8282

8383
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
84-
if (getCoordinates().size() > 5)
85-
return emitError("Maximum 5 coordinates and dimension is supported.");
84+
TMAStoreMode mode = getMode();
85+
// We lower through inline-ptx when getPredicate() is true.
86+
// a) Only TILE mode is supported
87+
// b) Cache-hint is not supported
88+
if (getPredicate()) {
89+
if (mode != TMAStoreMode::TILE)
90+
return emitError("Inline-ptx lowering supported only for Tile mode.");
91+
if (getL2CacheHint())
92+
return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
93+
}
94+
95+
size_t dims = getCoordinates().size();
96+
switch (mode) {
97+
case TMAStoreMode::TILE:
98+
return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
99+
case TMAStoreMode::IM2COL:
100+
return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
101+
case TMAStoreMode::TILE_SCATTER4:
102+
if (dims != 5)
103+
return emitError("Scatter4 mode expects 5 coordinates");
104+
}
86105
return success();
87106
}
88107

@@ -139,9 +158,17 @@ LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
139158
}
140159

141160
LogicalResult CpAsyncBulkTensorReduceOp::verify() {
142-
bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
143-
return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
144-
getLoc());
161+
TMAStoreMode mode = getMode();
162+
size_t dims = getCoordinates().size();
163+
switch (mode) {
164+
case TMAStoreMode::TILE:
165+
return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc());
166+
case TMAStoreMode::IM2COL:
167+
return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc());
168+
case TMAStoreMode::TILE_SCATTER4:
169+
return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
170+
}
171+
return success();
145172
}
146173

147174
LogicalResult ConvertFloatToTF32Op::verify() {
@@ -1521,6 +1548,51 @@ mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
15211548
return {id, std::move(args)};
15221549
}
15231550

1551+
mlir::NVVM::IDArgPair
1552+
CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
1553+
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1554+
auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
1555+
llvm::SmallVector<llvm::Value *> args;
1556+
1557+
// Fill the Intrinsic Args
1558+
args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1559+
args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
1560+
1561+
for (auto v : thisOp.getCoordinates())
1562+
args.push_back(mt.lookupValue(v));
1563+
1564+
mlir::Value cacheHint = thisOp.getL2CacheHint();
1565+
const bool hasCacheHint = static_cast<bool>(cacheHint);
1566+
llvm::Value *i64Unused =
1567+
llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1568+
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1569+
args.push_back(builder.getInt1(hasCacheHint));
1570+
1571+
const unsigned NI = llvm::Intrinsic::not_intrinsic;
1572+
static constexpr llvm::Intrinsic::ID IDTable[][6] = {
1573+
{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
1574+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
1575+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
1576+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
1577+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
1578+
{NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
1579+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
1580+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
1581+
{NI, NI, NI, NI, NI,
1582+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
1583+
1584+
static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
1585+
"TMAStoreModes must match number of rows in IDTable");
1586+
size_t mode = static_cast<size_t>(thisOp.getMode());
1587+
size_t dim = thisOp.getCoordinates().size();
1588+
llvm::Intrinsic::ID id = IDTable[mode][dim];
1589+
if (id == llvm::Intrinsic::not_intrinsic)
1590+
llvm_unreachable(
1591+
"Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
1592+
1593+
return {id, std::move(args)};
1594+
}
1595+
15241596
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \
15251597
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
15261598

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -214,47 +214,36 @@ func.func @tma_load_multicast5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>,
214214

215215
// CHECK-LABEL: @tma_store_1d
216216
func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %p : i1) {
217-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"
218-
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0] : !llvm.ptr, !llvm.ptr<3>, i32
219217
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$3 cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r,b"
220-
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i1
218+
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0], predicate=%p : !llvm.ptr, !llvm.ptr<3>
221219
return
222220
}
223221

224222
// CHECK-LABEL: @tma_store_2d
225223
func.func @tma_store_2d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
226-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r"
227-
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1] : !llvm.ptr, !llvm.ptr<3>, i32, i32
228224
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r,b"
229-
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i1
225+
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1], predicate=%p : !llvm.ptr, !llvm.ptr<3>
230226
return
231227
}
232228

233229
// CHECK-LABEL: @tma_store_3d
234230
func.func @tma_store_3d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
235-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r"
236-
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32
237231
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r,b"
238-
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i1
232+
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2], predicate=%p : !llvm.ptr, !llvm.ptr<3>
239233
return
240234
}
241235

242236
// CHECK-LABEL: @tma_store_4d
243237
func.func @tma_store_4d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
244-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r"
245-
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
246238
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r,b"
247-
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i1
239+
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3], predicate=%p : !llvm.ptr, !llvm.ptr<3>
248240
return
249241
}
250242

251243
// CHECK-LABEL: @tma_store_5d
252244
func.func @tma_store_5d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
253-
// CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r"
254-
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32
255-
256245
// CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r,b"
257-
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32, i1
246+
nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p : !llvm.ptr, !llvm.ptr<3>
258247
return
259248
}
260249

0 commit comments

Comments
 (0)