Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Builder/IntrinsicCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ struct IntrinsicLibrary {
mlir::Value genTime(mlir::Type, llvm::ArrayRef<mlir::Value>);
void genTMABulkCommitGroup(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genTransfer(mlir::Type,
Expand Down
15 changes: 15 additions & 0 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,10 @@ static constexpr IntrinsicHandler handlers[]{
{"dst", asAddr},
{"nbytes", asValue}}},
/*isElemental=*/false},
{"tma_bulk_s2g",
&I::genTMABulkS2G,
{{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},
/*isElemental=*/false},
{"tma_bulk_wait_group",
&I::genTMABulkWaitGroup,
{{}},
Expand Down Expand Up @@ -9227,6 +9231,17 @@ void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) {
builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {});
}

// TMA_BULK_S2G (CUDA)
void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 3);
mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[0]),
mlir::NVVM::NVVMMemorySpace::Shared);
mlir::Value dst = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]),
mlir::NVVM::NVVMMemorySpace::Global);
mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(
builder, loc, dst, src, fir::getBase(args[2]), {}, {});
}

// TMA_BULK_WAIT_GROUP (CUDA)
void IntrinsicLibrary::genTMABulkWaitGroup(
llvm::ArrayRef<fir::ExtendedValue> args) {
Expand Down
9 changes: 9 additions & 0 deletions flang/module/cudadevice.f90
Original file line number Diff line number Diff line change
Expand Up @@ -2034,6 +2034,15 @@ attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes)
end subroutine
end interface

interface
attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
!dir$ ignore_tkr src, dst
integer(4), shared :: src(*)
integer(4), device :: dst(*)
integer(4), value :: nbytes
end subroutine
end interface

contains

attributes(device) subroutine syncthreads()
Expand Down
12 changes: 11 additions & 1 deletion flang/test/Lower/CUDA/cuda-device-proc.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ end subroutine
! CHECK: nvvm.cp.async.bulk.commit.group
! CHECK: nvvm.cp.async.bulk.wait_group 0

attributes(global) subroutine test_bulk_g2s(c, a, b, n)
attributes(global) subroutine test_bulk_g2s(a)
real(8), device :: a(*)
real(8), shared :: tmpa(1024)
integer(8), shared :: barrier1
Expand All @@ -448,3 +448,13 @@ end subroutine

! CHECK-LABEL: func.func @_QPtest_bulk_g2s
! CHECK: nvvm.cp.async.bulk.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : <7>, <1>

attributes(global) subroutine test_bulk_s2g(a)
real(8), device :: a(*)
real(8), shared :: tmpa(1024)
integer(4) :: tx_count
call tma_bulk_s2g(tmpa, a(j), tx_count)
end subroutine

! CHECK-LABEL: func.func @_QPtest_bulk_s2g
! CHECL: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>