Skip to content

Commit 757e0bb

Browse files
committed
[flang][cuda] Add interfaces and lowering for tma_bulk_store
1 parent d0e0d7f commit 757e0bb

File tree

4 files changed

+275
-0
lines changed

4 files changed

+275
-0
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,13 @@ struct IntrinsicLibrary {
462462
void genTMABulkCommitGroup(llvm::ArrayRef<fir::ExtendedValue>);
463463
void genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue>);
464464
void genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue>);
465+
void genTMABulkStoreI4(llvm::ArrayRef<fir::ExtendedValue>);
466+
void genTMABulkStoreI8(llvm::ArrayRef<fir::ExtendedValue>);
467+
void genTMABulkStoreR2(llvm::ArrayRef<fir::ExtendedValue>);
468+
void genTMABulkStoreR4(llvm::ArrayRef<fir::ExtendedValue>);
469+
void genTMABulkStoreR8(llvm::ArrayRef<fir::ExtendedValue>);
470+
void genTMABulkStoreC4(llvm::ArrayRef<fir::ExtendedValue>);
471+
void genTMABulkStoreC8(llvm::ArrayRef<fir::ExtendedValue>);
465472
void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
466473
mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);
467474
fir::ExtendedValue genTransfer(mlir::Type,

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,6 +1049,34 @@ static constexpr IntrinsicHandler handlers[]{
10491049
&I::genTMABulkS2G,
10501050
{{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},
10511051
/*isElemental=*/false},
1052+
{"tma_bulk_store_c4",
1053+
&I::genTMABulkStoreC4,
1054+
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
1055+
/*isElemental=*/false},
1056+
{"tma_bulk_store_c8",
1057+
&I::genTMABulkStoreC8,
1058+
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
1059+
/*isElemental=*/false},
1060+
{"tma_bulk_store_i4",
1061+
&I::genTMABulkStoreI4,
1062+
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
1063+
/*isElemental=*/false},
1064+
{"tma_bulk_store_i8",
1065+
&I::genTMABulkStoreI8,
1066+
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
1067+
/*isElemental=*/false},
1068+
{"tma_bulk_store_r2",
1069+
&I::genTMABulkStoreR2,
1070+
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
1071+
/*isElemental=*/false},
1072+
{"tma_bulk_store_r4",
1073+
&I::genTMABulkStoreR4,
1074+
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
1075+
/*isElemental=*/false},
1076+
{"tma_bulk_store_r8",
1077+
&I::genTMABulkStoreR8,
1078+
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
1079+
/*isElemental=*/false},
10521080
{"tma_bulk_wait_group",
10531081
&I::genTMABulkWaitGroup,
10541082
{{}},
@@ -9289,6 +9317,92 @@ void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {
92899317
builder, loc, dst, src, fir::getBase(args[2]), {}, {});
92909318
}
92919319

9320+
static void genTMABulkStore(fir::FirOpBuilder &builder, mlir::Location loc,
9321+
mlir::Value src, mlir::Value dst, mlir::Value count,
9322+
mlir::Value eleSize) {
9323+
mlir::Value size = mlir::arith::MulIOp::create(builder, loc, eleSize, count);
9324+
src = convertPtrToNVVMSpace(builder, loc, src,
9325+
mlir::NVVM::NVVMMemorySpace::Shared);
9326+
dst = convertPtrToNVVMSpace(builder, loc, dst,
9327+
mlir::NVVM::NVVMMemorySpace::Global);
9328+
mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(builder, loc, dst, src,
9329+
size, {}, {});
9330+
mlir::NVVM::InlinePtxOp::create(builder, loc, mlir::TypeRange{}, {}, {},
9331+
"cp.async.bulk.commit_group", {});
9332+
mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc,
9333+
builder.getI32IntegerAttr(0), {});
9334+
}
9335+
9336+
// TMA_BULK_STORE_C4 (CUDA)
9337+
void IntrinsicLibrary::genTMABulkStoreC4(
9338+
llvm::ArrayRef<fir::ExtendedValue> args) {
9339+
assert(args.size() == 3);
9340+
mlir::Value eleSize =
9341+
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
9342+
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9343+
fir::getBase(args[2]), eleSize);
9344+
}
9345+
9346+
// TMA_BULK_STORE_C8 (CUDA)
9347+
void IntrinsicLibrary::genTMABulkStoreC8(
9348+
llvm::ArrayRef<fir::ExtendedValue> args) {
9349+
assert(args.size() == 3);
9350+
mlir::Value eleSize =
9351+
builder.createIntegerConstant(loc, builder.getI32Type(), 16);
9352+
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9353+
fir::getBase(args[2]), eleSize);
9354+
}
9355+
9356+
// TMA_BULK_STORE_I4 (CUDA)
9357+
void IntrinsicLibrary::genTMABulkStoreI4(
9358+
llvm::ArrayRef<fir::ExtendedValue> args) {
9359+
assert(args.size() == 3);
9360+
mlir::Value eleSize =
9361+
builder.createIntegerConstant(loc, builder.getI32Type(), 4);
9362+
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9363+
fir::getBase(args[2]), eleSize);
9364+
}
9365+
9366+
// TMA_BULK_STORE_I8 (CUDA)
9367+
void IntrinsicLibrary::genTMABulkStoreI8(
9368+
llvm::ArrayRef<fir::ExtendedValue> args) {
9369+
assert(args.size() == 3);
9370+
mlir::Value eleSize =
9371+
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
9372+
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9373+
fir::getBase(args[2]), eleSize);
9374+
}
9375+
9376+
// TMA_BULK_STORE_R2 (CUDA)
9377+
void IntrinsicLibrary::genTMABulkStoreR2(
9378+
llvm::ArrayRef<fir::ExtendedValue> args) {
9379+
assert(args.size() == 3);
9380+
mlir::Value eleSize =
9381+
builder.createIntegerConstant(loc, builder.getI32Type(), 2);
9382+
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9383+
fir::getBase(args[2]), eleSize);
9384+
}
9385+
9386+
// TMA_BULK_STORE_R4 (CUDA)
9387+
void IntrinsicLibrary::genTMABulkStoreR4(
9388+
llvm::ArrayRef<fir::ExtendedValue> args) {
9389+
assert(args.size() == 3);
9390+
mlir::Value eleSize =
9391+
builder.createIntegerConstant(loc, builder.getI32Type(), 4);
9392+
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9393+
fir::getBase(args[2]), eleSize);
9394+
}
9395+
9396+
// TMA_BULK_STORE_R8 (CUDA)
9397+
void IntrinsicLibrary::genTMABulkStoreR8(
9398+
llvm::ArrayRef<fir::ExtendedValue> args) {
9399+
assert(args.size() == 3);
9400+
mlir::Value eleSize =
9401+
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
9402+
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9403+
fir::getBase(args[2]), eleSize);
9404+
}
9405+
92929406
// TMA_BULK_WAIT_GROUP (CUDA)
92939407
void IntrinsicLibrary::genTMABulkWaitGroup(
92949408
llvm::ArrayRef<fir::ExtendedValue> args) {

flang/module/cudadevice.f90

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2058,6 +2058,13 @@ attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes)
20582058
end subroutine
20592059
end interface
20602060

2061+
! --------------------
2062+
! Bulk Store functions
2063+
! --------------------
2064+
2065+
! Generic store, count is in bytes
2066+
! --------------------------------
2067+
20612068
interface
20622069
attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
20632070
!dir$ ignore_tkr src, dst
@@ -2067,6 +2074,61 @@ attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
20672074
end subroutine
20682075
end interface
20692076

2077+
2078+
! Load specific types, count is in elements
2079+
! -----------------------------------------
2080+
2081+
interface tma_bulk_store
2082+
attributes(device) subroutine tma_bulk_store_c4(src, dst, nelems)
2083+
!dir$ ignore_tkr (r) src, (r) dst
2084+
complex(4), shared :: src(*)
2085+
complex(4), device :: dst(*)
2086+
integer(4), value :: nelems
2087+
end subroutine
2088+
2089+
attributes(device) subroutine tma_bulk_store_c8(src, dst, nelems)
2090+
!dir$ ignore_tkr (r) src, (r) dst
2091+
complex(8), shared :: src(*)
2092+
complex(8), device :: dst(*)
2093+
integer(4), value :: nelems
2094+
end subroutine
2095+
2096+
attributes(device) subroutine tma_bulk_store_i4(src, dst, nelems)
2097+
!dir$ ignore_tkr (r) src, (r) dst
2098+
integer(4), shared :: src(*)
2099+
integer(4), device :: dst(*)
2100+
integer(4), value :: nelems
2101+
end subroutine
2102+
2103+
attributes(device) subroutine tma_bulk_store_i8(src, dst, nelems)
2104+
!dir$ ignore_tkr (r) src, (r) dst
2105+
integer(8), shared :: src(*)
2106+
integer(8), device :: dst(*)
2107+
integer(4), value :: nelems
2108+
end subroutine
2109+
2110+
attributes(device) subroutine tma_bulk_store_r2(src, dst, nelems)
2111+
!dir$ ignore_tkr (r) src, (r) dst
2112+
real(2), shared :: src(*)
2113+
real(2), device :: dst(*)
2114+
integer(4), value :: nelems
2115+
end subroutine
2116+
2117+
attributes(device) subroutine tma_bulk_store_r4(src, dst, nelems)
2118+
!dir$ ignore_tkr (r) src, (r) dst
2119+
real(4), shared :: src(*)
2120+
real(4), device :: dst(*)
2121+
integer(4), value :: nelems
2122+
end subroutine
2123+
2124+
attributes(device) subroutine tma_bulk_store_r8(src, dst, nelems)
2125+
!dir$ ignore_tkr (r) src, (r) dst
2126+
real(8), shared :: src(*)
2127+
real(8), device :: dst(*)
2128+
integer(4), value :: nelems
2129+
end subroutine
2130+
end interface
2131+
20702132
contains
20712133

20722134
attributes(device) subroutine syncthreads()

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,3 +514,95 @@ end subroutine
514514

515515
! CHECK-LABEL: func.func @_QPtest_barrier_try_wait_sleep()
516516
! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %0, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr, i64, i32) -> i32
517+
518+
attributes(global) subroutine test_tma_bulk_store_c4(c, n)
519+
integer, value :: n
520+
complex(4), device :: c(n)
521+
complex(4), shared :: tmpa(1024)
522+
integer(4) :: j, elem_count
523+
call tma_bulk_store(tmpa, c(j), elem_count)
524+
end subroutine
525+
526+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c4
527+
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
528+
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
529+
! CHECK: nvvm.cp.async.bulk.wait_group 0
530+
531+
attributes(global) subroutine test_tma_bulk_store_c8(c, n)
532+
integer, value :: n
533+
complex(8), device :: c(n)
534+
complex(8), shared :: tmpa(1024)
535+
integer(4) :: j, elem_count
536+
call tma_bulk_store(tmpa, c(j), elem_count)
537+
end subroutine
538+
539+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c8
540+
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
541+
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
542+
! CHECK: nvvm.cp.async.bulk.wait_group 0
543+
544+
attributes(global) subroutine test_tma_bulk_store_i4(c, n)
545+
integer, value :: n
546+
integer(4), device :: c(n)
547+
integer(4), shared :: tmpa(1024)
548+
integer(4) :: j, elem_count
549+
call tma_bulk_store(tmpa, c(j), elem_count)
550+
end subroutine
551+
552+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i4
553+
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
554+
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
555+
! CHECK: nvvm.cp.async.bulk.wait_group 0
556+
557+
attributes(global) subroutine test_tma_bulk_store_i8(c, n)
558+
integer, value :: n
559+
integer(8), device :: c(n)
560+
integer(8), shared :: tmpa(1024)
561+
integer(4) :: j, elem_count
562+
call tma_bulk_store(tmpa, c(j), elem_count)
563+
end subroutine
564+
565+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i8
566+
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
567+
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
568+
! CHECK: nvvm.cp.async.bulk.wait_group 0
569+
570+
571+
attributes(global) subroutine test_tma_bulk_store_r2(c, n)
572+
integer, value :: n
573+
real(2), device :: c(n)
574+
real(2), shared :: tmpa(1024)
575+
integer(4) :: j, elem_count
576+
call tma_bulk_store(tmpa, c(j), elem_count)
577+
end subroutine
578+
579+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r2
580+
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
581+
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
582+
! CHECK: nvvm.cp.async.bulk.wait_group 0
583+
584+
attributes(global) subroutine test_tma_bulk_store_r4(c, n)
585+
integer, value :: n
586+
real(4), device :: c(n)
587+
real(4), shared :: tmpa(1024)
588+
integer(4) :: j, elem_count
589+
call tma_bulk_store(tmpa, c(j), elem_count)
590+
end subroutine
591+
592+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r4
593+
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
594+
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
595+
! CHECK: nvvm.cp.async.bulk.wait_group 0
596+
597+
attributes(global) subroutine test_tma_bulk_store_r8(c, n)
598+
integer, value :: n
599+
real(8), device :: c(n)
600+
real(8), shared :: tmpa(1024)
601+
integer(4) :: j, elem_count
602+
call tma_bulk_store(tmpa, c(j), elem_count)
603+
end subroutine
604+
605+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r8
606+
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
607+
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
608+
! CHECK: nvvm.cp.async.bulk.wait_group 0

0 commit comments

Comments
 (0)