Skip to content

Commit dbbd454

Browse files
committed
[flang][cuda] Add interface and lowering for tma_bulk_g2s
1 parent a611074 commit dbbd454

File tree

4 files changed

+56
-10
lines changed

4 files changed

+56
-10
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,7 @@ struct IntrinsicLibrary {
458458
mlir::Value genTanpi(mlir::Type, llvm::ArrayRef<mlir::Value>);
459459
mlir::Value genTime(mlir::Type, llvm::ArrayRef<mlir::Value>);
460460
void genTMABulkCommitGroup(llvm::ArrayRef<fir::ExtendedValue>);
461+
void genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue>);
461462
void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
462463
mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);
463464
fir::ExtendedValue genTransfer(mlir::Type,

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,13 @@ static constexpr IntrinsicHandler handlers[]{
10201020
&I::genTMABulkCommitGroup,
10211021
{{}},
10221022
/*isElemental=*/false},
1023+
{"tma_bulk_g2s",
1024+
&I::genTMABulkG2S,
1025+
{{{"barrier", asAddr},
1026+
{"src", asAddr},
1027+
{"dst", asAddr},
1028+
{"nbytes", asValue}}},
1029+
/*isElemental=*/false},
10231030
{"tma_bulk_wait_group",
10241031
&I::genTMABulkWaitGroup,
10251032
{{}},
@@ -3200,17 +3207,17 @@ IntrinsicLibrary::genAssociated(mlir::Type resultType,
32003207
return fir::runtime::genAssociated(builder, loc, pointerBox, targetBox);
32013208
}
32023209

3203-
static mlir::Value convertBarrierToLLVM(fir::FirOpBuilder &builder,
3204-
mlir::Location loc,
3205-
mlir::Value barrier) {
3210+
static mlir::Value convertPtrToNVVMSpace(fir::FirOpBuilder &builder,
3211+
mlir::Location loc,
3212+
mlir::Value barrier,
3213+
mlir::NVVM::NVVMMemorySpace space) {
32063214
mlir::Value llvmPtr = fir::ConvertOp::create(
32073215
builder, loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()),
32083216
barrier);
32093217
mlir::Value addrCast = mlir::LLVM::AddrSpaceCastOp::create(
32103218
builder, loc,
3211-
mlir::LLVM::LLVMPointerType::get(
3212-
builder.getContext(),
3213-
static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Shared)),
3219+
mlir::LLVM::LLVMPointerType::get(builder.getContext(),
3220+
static_cast<unsigned>(space)),
32143221
llvmPtr);
32153222
return addrCast;
32163223
}
@@ -3220,7 +3227,8 @@ mlir::Value
32203227
IntrinsicLibrary::genBarrierArrive(mlir::Type resultType,
32213228
llvm::ArrayRef<mlir::Value> args) {
32223229
assert(args.size() == 1);
3223-
mlir::Value barrier = convertBarrierToLLVM(builder, loc, args[0]);
3230+
mlir::Value barrier = convertPtrToNVVMSpace(
3231+
builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared);
32243232
return mlir::NVVM::MBarrierArriveSharedOp::create(builder, loc, resultType,
32253233
barrier)
32263234
.getResult();
@@ -3231,7 +3239,8 @@ mlir::Value
32313239
IntrinsicLibrary::genBarrierArriveCnt(mlir::Type resultType,
32323240
llvm::ArrayRef<mlir::Value> args) {
32333241
assert(args.size() == 2);
3234-
mlir::Value barrier = convertBarrierToLLVM(builder, loc, args[0]);
3242+
mlir::Value barrier = convertPtrToNVVMSpace(
3243+
builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared);
32353244
mlir::Value token = fir::AllocaOp::create(builder, loc, resultType);
32363245
// TODO: the MBarrierArriveExpectTxOp is not taking the state argument and
32373246
// currently just the sink symbol `_`.
@@ -3244,8 +3253,8 @@ IntrinsicLibrary::genBarrierArriveCnt(mlir::Type resultType,
32443253
// BARRIER_INIT (CUDA)
32453254
void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {
32463255
assert(args.size() == 2);
3247-
mlir::Value barrier =
3248-
convertBarrierToLLVM(builder, loc, fir::getBase(args[0]));
3256+
mlir::Value barrier = convertPtrToNVVMSpace(
3257+
builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared);
32493258
mlir::NVVM::MBarrierInitSharedOp::create(builder, loc, barrier,
32503259
fir::getBase(args[1]), {});
32513260
auto kind = mlir::NVVM::ProxyKindAttr::get(
@@ -9204,6 +9213,20 @@ void IntrinsicLibrary::genTMABulkCommitGroup(
92049213
mlir::NVVM::CpAsyncBulkCommitGroupOp::create(builder, loc);
92059214
}
92069215

9216+
// TMA_BULK_G2S (CUDA)
9217+
void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) {
9218+
assert(args.size() == 4);
9219+
mlir::Value barrier = convertPtrToNVVMSpace(
9220+
builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared);
9221+
mlir::Value dst =
9222+
convertPtrToNVVMSpace(builder, loc, fir::getBase(args[2]),
9223+
mlir::NVVM::NVVMMemorySpace::SharedCluster);
9224+
mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]),
9225+
mlir::NVVM::NVVMMemorySpace::Global);
9226+
mlir::NVVM::CpAsyncBulkGlobalToSharedClusterOp::create(
9227+
builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {});
9228+
}
9229+
92079230
// TMA_BULK_WAIT_GROUP (CUDA)
92089231
void IntrinsicLibrary::genTMABulkWaitGroup(
92099232
llvm::ArrayRef<fir::ExtendedValue> args) {

flang/module/cudadevice.f90

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2023,6 +2023,17 @@ attributes(device) subroutine tma_bulk_wait_group()
20232023
end subroutine
20242024
end interface
20252025

2026+
! Generic load, count is in bytes
2027+
interface
2028+
attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes)
2029+
!dir$ ignore_tkr src, dst
2030+
integer(8), shared :: barrier
2031+
integer(4), device :: src(*)
2032+
integer(4), shared :: dst(*)
2033+
integer(4), value :: nbytes
2034+
end subroutine
2035+
end interface
2036+
20262037
contains
20272038

20282039
attributes(device) subroutine syncthreads()

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,3 +437,14 @@ end subroutine
437437
! CHECK-LABEL: func.func @_QPtest_tma()
438438
! CHECK: nvvm.cp.async.bulk.commit.group
439439
! CHECK: nvvm.cp.async.bulk.wait_group 0
440+
441+
attributes(global) subroutine test_bulk_g2s(c, a, b, n)
442+
real(8), device :: a(*)
443+
real(8), shared :: tmpa(1024)
444+
integer(8), shared :: barrier1
445+
integer(4) :: tx_count
446+
call tma_bulk_g2s(barrier1, a(j), tmpa, tx_count)
447+
end subroutine
448+
449+
! CHECK-LABEL: func.func @_QPtest_bulk_g2s
450+
! CHECK: nvvm.cp.async.bulk.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : <7>, <1>

0 commit comments

Comments
 (0)