Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Builder/IntrinsicCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ struct IntrinsicLibrary {
mlir::Value genTime(mlir::Type, llvm::ArrayRef<mlir::Value>);
void genTMABulkCommitGroup(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genTransfer(mlir::Type,
Expand Down
15 changes: 15 additions & 0 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,10 @@ static constexpr IntrinsicHandler handlers[]{
{"dst", asAddr},
{"nbytes", asValue}}},
/*isElemental=*/false},
{"tma_bulk_s2g",
&I::genTMABulkS2G,
{{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},
/*isElemental=*/false},
{"tma_bulk_wait_group",
&I::genTMABulkWaitGroup,
{{}},
Expand Down Expand Up @@ -9227,6 +9231,17 @@ void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) {
builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {});
}

// TMA_BULK_S2G (CUDA)
void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 3);
mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[0]),
mlir::NVVM::NVVMMemorySpace::Shared);
mlir::Value dst = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]),
mlir::NVVM::NVVMMemorySpace::Global);
mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(
builder, loc, dst, src, fir::getBase(args[2]), {}, {});
}

// TMA_BULK_WAIT_GROUP (CUDA)
void IntrinsicLibrary::genTMABulkWaitGroup(
llvm::ArrayRef<fir::ExtendedValue> args) {
Expand Down
9 changes: 9 additions & 0 deletions flang/module/cudadevice.f90
Original file line number Diff line number Diff line change
Expand Up @@ -2034,6 +2034,15 @@ attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes)
end subroutine
end interface

interface
attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
!dir$ ignore_tkr src, dst
integer(4), shared :: src(*)
integer(4), device :: dst(*)
integer(4), value :: nbytes
end subroutine
end interface

contains

attributes(device) subroutine syncthreads()
Expand Down
10 changes: 10 additions & 0 deletions flang/test/Lower/CUDA/cuda-device-proc.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,13 @@ end subroutine

! CHECK-LABEL: func.func @_QPtest_bulk_g2s
! CHECK: nvvm.cp.async.bulk.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : <7>, <1>

attributes(global) subroutine test_bulk_s2g(c, a, b, n)
real(8), device :: a(*)
real(8), shared :: tmpa(1024)
integer(4) :: tx_count
call tma_bulk_s2g(tmpa, a(j), tx_count)
end subroutine

! CHECK-LABEL: func.func @_QPtest_bulk_s2g
! CHECL: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>