Skip to content

Commit 886e797

Browse files
authored
[flang][cuda] Add interfaces and lowering for tma_bulk_[commit|wait]_group subroutine (#163012)
https://docs.nvidia.com/hpc-sdk/compilers/cuda-fortran-prog-guide/#load-and-store-functions-using-bulk-tma-operations
1 parent 56c6049 commit 886e797

File tree

4 files changed

+44
-0
lines changed

4 files changed

+44
-0
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,8 @@ struct IntrinsicLibrary {
456456
mlir::Value genTand(mlir::Type, llvm::ArrayRef<mlir::Value>);
457457
mlir::Value genTanpi(mlir::Type, llvm::ArrayRef<mlir::Value>);
458458
mlir::Value genTime(mlir::Type, llvm::ArrayRef<mlir::Value>);
459+
void genTMABulkCommitGroup(llvm::ArrayRef<fir::ExtendedValue>);
460+
void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
459461
mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);
460462
fir::ExtendedValue genTransfer(mlir::Type,
461463
llvm::ArrayRef<fir::ExtendedValue>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,14 @@ static constexpr IntrinsicHandler handlers[]{
10121012
{"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
10131013
{"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false},
10141014
{"time", &I::genTime, {}, /*isElemental=*/false},
1015+
{"tma_bulk_commit_group",
1016+
&I::genTMABulkCommitGroup,
1017+
{{}},
1018+
/*isElemental=*/false},
1019+
{"tma_bulk_wait_group",
1020+
&I::genTMABulkWaitGroup,
1021+
{{}},
1022+
/*isElemental=*/false},
10151023
{"trailz", &I::genTrailz},
10161024
{"transfer",
10171025
&I::genTransfer,
@@ -9169,6 +9177,21 @@ mlir::Value IntrinsicLibrary::genTime(mlir::Type resultType,
91699177
fir::runtime::genTime(builder, loc));
91709178
}
91719179

9180+
// TMA_BULK_COMMIT_GROUP (CUDA)
9181+
void IntrinsicLibrary::genTMABulkCommitGroup(
9182+
llvm::ArrayRef<fir::ExtendedValue> args) {
9183+
assert(args.size() == 0);
9184+
mlir::NVVM::CpAsyncBulkCommitGroupOp::create(builder, loc);
9185+
}
9186+
9187+
// TMA_BULK_WAIT_GROUP (CUDA)
9188+
void IntrinsicLibrary::genTMABulkWaitGroup(
9189+
llvm::ArrayRef<fir::ExtendedValue> args) {
9190+
assert(args.size() == 0);
9191+
auto group = builder.getIntegerAttr(builder.getI32Type(), 0);
9192+
mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc, group, {});
9193+
}
9194+
91729195
// TRIM
91739196
fir::ExtendedValue
91749197
IntrinsicLibrary::genTrim(mlir::Type resultType,

flang/module/cudadevice.f90

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2008,6 +2008,16 @@ attributes(device) function barrier_arrive_cnt(barrier, count) result(token)
20082008
end function
20092009
end interface
20102010

2011+
interface
2012+
attributes(device) subroutine tma_bulk_commit_group()
2013+
end subroutine
2014+
end interface
2015+
2016+
interface
2017+
attributes(device) subroutine tma_bulk_wait_group()
2018+
end subroutine
2019+
end interface
2020+
20112021
contains
20122022

20132023
attributes(device) subroutine syncthreads()

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,3 +419,12 @@ end subroutine
419419
! CHECK: %[[LLVM_PTR:.*]] = fir.convert %[[DECL_SHARED]]#0 : (!fir.ref<i64>) -> !llvm.ptr
420420
! CHECK: %[[SHARED_PTR:.*]] = llvm.addrspacecast %[[LLVM_PTR]] : !llvm.ptr to !llvm.ptr<3>
421421
! CHECK: nvvm.mbarrier.arrive.expect_tx %[[SHARED_PTR]], %{{.*}} : !llvm.ptr<3>, i32
422+
423+
attributes(global) subroutine test_tma()
424+
call tma_bulk_commit_group()
425+
call tma_bulk_wait_group()
426+
end subroutine
427+
428+
! CHECK-LABEL: func.func @_QPtest_tma()
429+
! CHECK: nvvm.cp.async.bulk.commit.group
430+
! CHECK: nvvm.cp.async.bulk.wait_group 0

0 commit comments

Comments
 (0)