Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions flang/include/flang/Optimizer/Builder/IntrinsicCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,13 @@ struct IntrinsicLibrary {
void genTMABulkLoadR4(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkLoadR8(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkStoreI4(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkStoreI8(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkStoreR2(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkStoreR4(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkStoreR8(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkStoreC4(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkStoreC8(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genTransfer(mlir::Type,
Expand Down
114 changes: 114 additions & 0 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,34 @@ static constexpr IntrinsicHandler handlers[]{
&I::genTMABulkS2G,
{{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},
/*isElemental=*/false},
{"tma_bulk_store_c4",
&I::genTMABulkStoreC4,
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
/*isElemental=*/false},
{"tma_bulk_store_c8",
&I::genTMABulkStoreC8,
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
/*isElemental=*/false},
{"tma_bulk_store_i4",
&I::genTMABulkStoreI4,
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
/*isElemental=*/false},
{"tma_bulk_store_i8",
&I::genTMABulkStoreI8,
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
/*isElemental=*/false},
{"tma_bulk_store_r2",
&I::genTMABulkStoreR2,
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
/*isElemental=*/false},
{"tma_bulk_store_r4",
&I::genTMABulkStoreR4,
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
/*isElemental=*/false},
{"tma_bulk_store_r8",
&I::genTMABulkStoreR8,
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
/*isElemental=*/false},
{"tma_bulk_wait_group",
&I::genTMABulkWaitGroup,
{{}},
Expand Down Expand Up @@ -9430,6 +9458,92 @@ void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {
builder.getI32IntegerAttr(0), {});
}

static void genTMABulkStore(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value src, mlir::Value dst, mlir::Value count,
mlir::Value eleSize) {
mlir::Value size = mlir::arith::MulIOp::create(builder, loc, eleSize, count);
src = convertPtrToNVVMSpace(builder, loc, src,
mlir::NVVM::NVVMMemorySpace::Shared);
dst = convertPtrToNVVMSpace(builder, loc, dst,
mlir::NVVM::NVVMMemorySpace::Global);
mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(builder, loc, dst, src,
size, {}, {});
mlir::NVVM::InlinePtxOp::create(builder, loc, mlir::TypeRange{}, {}, {},
"cp.async.bulk.commit_group", {});
mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc,
builder.getI32IntegerAttr(0), {});
}

// TMA_BULK_STORE_C4 (CUDA)
void IntrinsicLibrary::genTMABulkStoreC4(
llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 3);
mlir::Value eleSize =
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
fir::getBase(args[2]), eleSize);
}

// TMA_BULK_STORE_C8 (CUDA)
void IntrinsicLibrary::genTMABulkStoreC8(
llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 3);
mlir::Value eleSize =
builder.createIntegerConstant(loc, builder.getI32Type(), 16);
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
fir::getBase(args[2]), eleSize);
}

// TMA_BULK_STORE_I4 (CUDA)
void IntrinsicLibrary::genTMABulkStoreI4(
llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 3);
mlir::Value eleSize =
builder.createIntegerConstant(loc, builder.getI32Type(), 4);
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
fir::getBase(args[2]), eleSize);
}

// TMA_BULK_STORE_I8 (CUDA)
void IntrinsicLibrary::genTMABulkStoreI8(
llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 3);
mlir::Value eleSize =
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
fir::getBase(args[2]), eleSize);
}

// TMA_BULK_STORE_R2 (CUDA)
void IntrinsicLibrary::genTMABulkStoreR2(
llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 3);
mlir::Value eleSize =
builder.createIntegerConstant(loc, builder.getI32Type(), 2);
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
fir::getBase(args[2]), eleSize);
}

// TMA_BULK_STORE_R4 (CUDA)
void IntrinsicLibrary::genTMABulkStoreR4(
llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 3);
mlir::Value eleSize =
builder.createIntegerConstant(loc, builder.getI32Type(), 4);
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
fir::getBase(args[2]), eleSize);
}

// TMA_BULK_STORE_R8 (CUDA)
void IntrinsicLibrary::genTMABulkStoreR8(
llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 3);
mlir::Value eleSize =
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
fir::getBase(args[2]), eleSize);
}

// TMA_BULK_WAIT_GROUP (CUDA)
void IntrinsicLibrary::genTMABulkWaitGroup(
llvm::ArrayRef<fir::ExtendedValue> args) {
Expand Down
85 changes: 76 additions & 9 deletions flang/module/cudadevice.f90
Original file line number Diff line number Diff line change
Expand Up @@ -2047,7 +2047,13 @@ attributes(device) subroutine tma_bulk_wait_group()
end subroutine
end interface

! --------------------
! Bulk load functions
! --------------------

! Generic load, count is in bytes
! -------------------------------

interface
attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes)
!dir$ ignore_tkr src, dst
Expand All @@ -2058,17 +2064,9 @@ attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes)
end subroutine
end interface

interface
attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
!dir$ ignore_tkr src, dst
integer(4), shared :: src(*)
integer(4), device :: dst(*)
integer(4), value :: nbytes
end subroutine
end interface

! Load specific types, count is in elements
! -----------------------------------------

interface tma_bulk_load
attributes(device) subroutine tma_bulk_ldc4(barrier, src, dst, nelems)
!dir$ ignore_tkr (r) src, (r) dst
Expand Down Expand Up @@ -2127,6 +2125,75 @@ attributes(device) subroutine tma_bulk_ldr8(barrier, src, dst, nelems)
end subroutine
end interface

! --------------------
! Bulk Store functions
! --------------------

! Generic store, count is in bytes
! --------------------------------

interface
attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
!dir$ ignore_tkr src, dst
integer(4), shared :: src(*)
integer(4), device :: dst(*)
integer(4), value :: nbytes
end subroutine
end interface

! Load specific types, count is in elements
! -----------------------------------------

interface tma_bulk_store
attributes(device) subroutine tma_bulk_store_c4(src, dst, nelems)
!dir$ ignore_tkr (r) src, (r) dst
complex(4), shared :: src(*)
complex(4), device :: dst(*)
integer(4), value :: nelems
end subroutine

attributes(device) subroutine tma_bulk_store_c8(src, dst, nelems)
!dir$ ignore_tkr (r) src, (r) dst
complex(8), shared :: src(*)
complex(8), device :: dst(*)
integer(4), value :: nelems
end subroutine

attributes(device) subroutine tma_bulk_store_i4(src, dst, nelems)
!dir$ ignore_tkr (r) src, (r) dst
integer(4), shared :: src(*)
integer(4), device :: dst(*)
integer(4), value :: nelems
end subroutine

attributes(device) subroutine tma_bulk_store_i8(src, dst, nelems)
!dir$ ignore_tkr (r) src, (r) dst
integer(8), shared :: src(*)
integer(8), device :: dst(*)
integer(4), value :: nelems
end subroutine

attributes(device) subroutine tma_bulk_store_r2(src, dst, nelems)
!dir$ ignore_tkr (r) src, (r) dst
real(2), shared :: src(*)
real(2), device :: dst(*)
integer(4), value :: nelems
end subroutine

attributes(device) subroutine tma_bulk_store_r4(src, dst, nelems)
!dir$ ignore_tkr (r) src, (r) dst
real(4), shared :: src(*)
real(4), device :: dst(*)
integer(4), value :: nelems
end subroutine

attributes(device) subroutine tma_bulk_store_r8(src, dst, nelems)
!dir$ ignore_tkr (r) src, (r) dst
real(8), shared :: src(*)
real(8), device :: dst(*)
integer(4), value :: nelems
end subroutine
end interface

contains

Expand Down
92 changes: 92 additions & 0 deletions flang/test/Lower/CUDA/cuda-device-proc.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -649,3 +649,95 @@ end subroutine
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf64>>, !fir.ref<f64>, i32, !llvm.ptr)
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)

attributes(global) subroutine test_tma_bulk_store_c4(c, n)
integer, value :: n
complex(4), device :: c(n)
complex(4), shared :: tmpa(1024)
integer(4) :: j, elem_count
call tma_bulk_store(tmpa, c(j), elem_count)
end subroutine

! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c4
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
! CHECK: nvvm.cp.async.bulk.wait_group 0

attributes(global) subroutine test_tma_bulk_store_c8(c, n)
integer, value :: n
complex(8), device :: c(n)
complex(8), shared :: tmpa(1024)
integer(4) :: j, elem_count
call tma_bulk_store(tmpa, c(j), elem_count)
end subroutine

! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c8
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
! CHECK: nvvm.cp.async.bulk.wait_group 0

attributes(global) subroutine test_tma_bulk_store_i4(c, n)
integer, value :: n
integer(4), device :: c(n)
integer(4), shared :: tmpa(1024)
integer(4) :: j, elem_count
call tma_bulk_store(tmpa, c(j), elem_count)
end subroutine

! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i4
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
! CHECK: nvvm.cp.async.bulk.wait_group 0

attributes(global) subroutine test_tma_bulk_store_i8(c, n)
integer, value :: n
integer(8), device :: c(n)
integer(8), shared :: tmpa(1024)
integer(4) :: j, elem_count
call tma_bulk_store(tmpa, c(j), elem_count)
end subroutine

! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i8
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
! CHECK: nvvm.cp.async.bulk.wait_group 0


attributes(global) subroutine test_tma_bulk_store_r2(c, n)
integer, value :: n
real(2), device :: c(n)
real(2), shared :: tmpa(1024)
integer(4) :: j, elem_count
call tma_bulk_store(tmpa, c(j), elem_count)
end subroutine

! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r2
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
! CHECK: nvvm.cp.async.bulk.wait_group 0

attributes(global) subroutine test_tma_bulk_store_r4(c, n)
integer, value :: n
real(4), device :: c(n)
real(4), shared :: tmpa(1024)
integer(4) :: j, elem_count
call tma_bulk_store(tmpa, c(j), elem_count)
end subroutine

! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r4
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
! CHECK: nvvm.cp.async.bulk.wait_group 0

attributes(global) subroutine test_tma_bulk_store_r8(c, n)
integer, value :: n
real(8), device :: c(n)
real(8), shared :: tmpa(1024)
integer(4) :: j, elem_count
call tma_bulk_store(tmpa, c(j), elem_count)
end subroutine

! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r8
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
! CHECK: nvvm.cp.async.bulk.wait_group 0