Skip to content

Commit 95aa70c

Browse files
authored
[MLIR][NVVM] Add support for shared::cta destination (#168056)
This patch adds support for shared::cta as destination space in the TMA non-tensor copy Op (from global to shared::cta). * Appropriate verifier checks are added. * Unit tests are added to verify the lowering. The related intrinsic changes were merged through PR #167508. Signed-off-by: Durgadoss R <[email protected]>
1 parent 54fdf67 commit 95aa70c

File tree

5 files changed

+49
-12
lines changed

5 files changed

+49
-12
lines changed

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ end subroutine
477477
! CHECK: %[[DST_7:.*]] = llvm.addrspacecast %[[DST_PTR]] : !llvm.ptr to !llvm.ptr<7>
478478
! CHECK: %[[SRC_PTR:.*]] = fir.convert %[[SRC]] : (!fir.ref<f64>) -> !llvm.ptr
479479
! CHECK: %[[SRC_3:.*]] = llvm.addrspacecast %[[SRC_PTR]] : !llvm.ptr to !llvm.ptr<1>
480-
! CHECK: nvvm.cp.async.bulk.shared.cluster.global %[[DST_7]], %[[SRC_3]], %[[BARRIER_3]], %[[COUNT_LOAD]] : <7>, <1>
480+
! CHECK: nvvm.cp.async.bulk.shared.cluster.global %[[DST_7]], %[[SRC_3]], %[[BARRIER_3]], %[[COUNT_LOAD]] : !llvm.ptr<7>, <1>
481481

482482
attributes(global) subroutine test_bulk_s2g(a)
483483
real(8), device :: a(*)

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3342,16 +3342,17 @@ def NVVM_CpAsyncBulkTensorReduceOp :
33423342

33433343
def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
33443344
NVVM_Op<"cp.async.bulk.shared.cluster.global", [AttrSizedOperandSegments]> {
3345-
let summary = "Async bulk copy from global memory to Shared cluster memory";
3345+
let summary = "Async bulk copy from global to Shared {cta or cluster} memory";
33463346
let description = [{
3347-
Initiates an asynchronous copy operation from global memory to cluster's
3348-
shared memory.
3347+
Initiates an asynchronous copy operation from global memory to shared
3348+
memory or shared_cluster memory.
33493349

3350-
The `multicastMask` operand is optional. When it is present, the Op copies
3350+
The `multicastMask` operand is optional and can be used only when the
3351+
destination is shared::cluster memory. When it is present, this Op copies
33513352
data from global memory to shared memory of multiple CTAs in the cluster.
33523353
Operand `multicastMask` specifies the destination CTAs in the cluster such
33533354
that each bit position in the 16-bit `multicastMask` operand corresponds to
3354-
the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
3355+
the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
33553356

33563357
The `l2CacheHint` operand is optional, and it is used to specify cache
33573358
eviction policy that may be used during the memory access.
@@ -3360,7 +3361,7 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
33603361
}];
33613362

33623363
let arguments = (ins
3363-
LLVM_PointerSharedCluster:$dstMem,
3364+
AnyTypeOf<[LLVM_PointerShared, LLVM_PointerSharedCluster]>:$dstMem,
33643365
LLVM_PointerGlobal:$srcMem,
33653366
LLVM_PointerShared:$mbar,
33663367
I32:$size,
@@ -3374,6 +3375,8 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
33743375
attr-dict `:` type($dstMem) `,` type($srcMem)
33753376
}];
33763377

3378+
let hasVerifier = 1;
3379+
33773380
let extraClassDeclaration = [{
33783381
static mlir::NVVM::IDArgPair
33793382
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,14 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
212212
return success();
213213
}
214214

215+
LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
216+
bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
217+
if (isSharedCTA && getMulticastMask())
218+
return emitError("Multicast is not supported with shared::cta mode.");
219+
220+
return success();
221+
}
222+
215223
LogicalResult ConvertFloatToTF32Op::verify() {
216224
using RndMode = NVVM::FPRoundingMode;
217225
switch (getRnd()) {
@@ -1980,11 +1988,15 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
19801988
args.push_back(mt.lookupValue(thisOp.getSrcMem()));
19811989
args.push_back(mt.lookupValue(thisOp.getSize()));
19821990

1983-
// Multicast mask, if available.
1991+
// Multicast mask for shared::cluster only, if available.
19841992
mlir::Value multicastMask = thisOp.getMulticastMask();
19851993
const bool hasMulticastMask = static_cast<bool>(multicastMask);
1986-
llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
1987-
args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
1994+
const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
1995+
if (!isSharedCTA) {
1996+
llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
1997+
args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
1998+
: i16Unused);
1999+
}
19882000

19892001
// Cache hint, if available.
19902002
mlir::Value cacheHint = thisOp.getL2CacheHint();
@@ -1993,11 +2005,14 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
19932005
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
19942006

19952007
// Flag arguments for multicast and cachehint.
1996-
args.push_back(builder.getInt1(hasMulticastMask));
2008+
if (!isSharedCTA)
2009+
args.push_back(builder.getInt1(hasMulticastMask));
19972010
args.push_back(builder.getInt1(hasCacheHint));
19982011

19992012
llvm::Intrinsic::ID id =
2000-
llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
2013+
isSharedCTA
2014+
? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
2015+
: llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
20012016

20022017
return {id, std::move(args)};
20032018
}

mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,17 @@ llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cluster(%dst : !llvm.ptr<7>,
1616
llvm.return
1717
}
1818

19+
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_global_to_shared_cta
20+
llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cta(%dst : !llvm.ptr<3>, %src : !llvm.ptr<1>, %mbar : !llvm.ptr<3>, %size : i32, %ch : i64) {
21+
// CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST:.*]], ptr addrspace(3) %[[MBAR:.*]], ptr addrspace(1) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
22+
// CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true)
23+
nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<1>
24+
25+
nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>
26+
27+
llvm.return
28+
}
29+
1930
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster
2031
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr<7>, %src : !llvm.ptr<3>, %mbar : !llvm.ptr<3>, %size : i32) {
2132
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.cluster(ptr addrspace(7) %0, ptr addrspace(3) %2, ptr addrspace(3) %1, i32 %3)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
2+
3+
llvm.func @tma_bulk_copy_g2s_mc(%src : !llvm.ptr<1>, %dest : !llvm.ptr<3>, %bar : !llvm.ptr<3>, %size : i32, %ctamask : i16) {
4+
// expected-error @below {{Multicast is not supported with shared::cta mode.}}
5+
nvvm.cp.async.bulk.shared.cluster.global %dest, %src, %bar, %size multicast_mask = %ctamask : !llvm.ptr<3>, !llvm.ptr<1>
6+
7+
llvm.return
8+
}

0 commit comments

Comments
 (0)