Skip to content

Commit 7f939af

Browse files
clementvalaokblast
authored andcommitted
1 parent 94b1e2d commit 7f939af

File tree

4 files changed

+289
-9
lines changed

4 files changed

+289
-9
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,13 @@ struct IntrinsicLibrary {
469469
void genTMABulkLoadR4(llvm::ArrayRef<fir::ExtendedValue>);
470470
void genTMABulkLoadR8(llvm::ArrayRef<fir::ExtendedValue>);
471471
void genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue>);
472+
void genTMABulkStoreI4(llvm::ArrayRef<fir::ExtendedValue>);
473+
void genTMABulkStoreI8(llvm::ArrayRef<fir::ExtendedValue>);
474+
void genTMABulkStoreR2(llvm::ArrayRef<fir::ExtendedValue>);
475+
void genTMABulkStoreR4(llvm::ArrayRef<fir::ExtendedValue>);
476+
void genTMABulkStoreR8(llvm::ArrayRef<fir::ExtendedValue>);
477+
void genTMABulkStoreC4(llvm::ArrayRef<fir::ExtendedValue>);
478+
void genTMABulkStoreC8(llvm::ArrayRef<fir::ExtendedValue>);
472479
void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
473480
mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);
474481
fir::ExtendedValue genTransfer(mlir::Type,

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,34 @@ static constexpr IntrinsicHandler handlers[]{
10981098
&I::genTMABulkS2G,
10991099
{{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},
11001100
/*isElemental=*/false},
1101+
{"tma_bulk_store_c4",
1102+
&I::genTMABulkStoreC4,
1103+
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
1104+
/*isElemental=*/false},
1105+
{"tma_bulk_store_c8",
1106+
&I::genTMABulkStoreC8,
1107+
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
1108+
/*isElemental=*/false},
1109+
{"tma_bulk_store_i4",
1110+
&I::genTMABulkStoreI4,
1111+
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
1112+
/*isElemental=*/false},
1113+
{"tma_bulk_store_i8",
1114+
&I::genTMABulkStoreI8,
1115+
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
1116+
/*isElemental=*/false},
1117+
{"tma_bulk_store_r2",
1118+
&I::genTMABulkStoreR2,
1119+
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
1120+
/*isElemental=*/false},
1121+
{"tma_bulk_store_r4",
1122+
&I::genTMABulkStoreR4,
1123+
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
1124+
/*isElemental=*/false},
1125+
{"tma_bulk_store_r8",
1126+
&I::genTMABulkStoreR8,
1127+
{{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
1128+
/*isElemental=*/false},
11011129
{"tma_bulk_wait_group",
11021130
&I::genTMABulkWaitGroup,
11031131
{{}},
@@ -9430,6 +9458,92 @@ void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {
94309458
builder.getI32IntegerAttr(0), {});
94319459
}
94329460

9461+
static void genTMABulkStore(fir::FirOpBuilder &builder, mlir::Location loc,
9462+
mlir::Value src, mlir::Value dst, mlir::Value count,
9463+
mlir::Value eleSize) {
9464+
mlir::Value size = mlir::arith::MulIOp::create(builder, loc, eleSize, count);
9465+
src = convertPtrToNVVMSpace(builder, loc, src,
9466+
mlir::NVVM::NVVMMemorySpace::Shared);
9467+
dst = convertPtrToNVVMSpace(builder, loc, dst,
9468+
mlir::NVVM::NVVMMemorySpace::Global);
9469+
mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(builder, loc, dst, src,
9470+
size, {}, {});
9471+
mlir::NVVM::InlinePtxOp::create(builder, loc, mlir::TypeRange{}, {}, {},
9472+
"cp.async.bulk.commit_group", {});
9473+
mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc,
9474+
builder.getI32IntegerAttr(0), {});
9475+
}
9476+
9477+
// TMA_BULK_STORE_C4 (CUDA)
9478+
void IntrinsicLibrary::genTMABulkStoreC4(
9479+
llvm::ArrayRef<fir::ExtendedValue> args) {
9480+
assert(args.size() == 3);
9481+
mlir::Value eleSize =
9482+
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
9483+
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9484+
fir::getBase(args[2]), eleSize);
9485+
}
9486+
9487+
// TMA_BULK_STORE_C8 (CUDA)
9488+
void IntrinsicLibrary::genTMABulkStoreC8(
9489+
llvm::ArrayRef<fir::ExtendedValue> args) {
9490+
assert(args.size() == 3);
9491+
mlir::Value eleSize =
9492+
builder.createIntegerConstant(loc, builder.getI32Type(), 16);
9493+
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9494+
fir::getBase(args[2]), eleSize);
9495+
}
9496+
9497+
// TMA_BULK_STORE_I4 (CUDA)
9498+
void IntrinsicLibrary::genTMABulkStoreI4(
9499+
llvm::ArrayRef<fir::ExtendedValue> args) {
9500+
assert(args.size() == 3);
9501+
mlir::Value eleSize =
9502+
builder.createIntegerConstant(loc, builder.getI32Type(), 4);
9503+
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9504+
fir::getBase(args[2]), eleSize);
9505+
}
9506+
9507+
// TMA_BULK_STORE_I8 (CUDA)
9508+
void IntrinsicLibrary::genTMABulkStoreI8(
9509+
llvm::ArrayRef<fir::ExtendedValue> args) {
9510+
assert(args.size() == 3);
9511+
mlir::Value eleSize =
9512+
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
9513+
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9514+
fir::getBase(args[2]), eleSize);
9515+
}
9516+
9517+
// TMA_BULK_STORE_R2 (CUDA)
9518+
void IntrinsicLibrary::genTMABulkStoreR2(
9519+
llvm::ArrayRef<fir::ExtendedValue> args) {
9520+
assert(args.size() == 3);
9521+
mlir::Value eleSize =
9522+
builder.createIntegerConstant(loc, builder.getI32Type(), 2);
9523+
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9524+
fir::getBase(args[2]), eleSize);
9525+
}
9526+
9527+
// TMA_BULK_STORE_R4 (CUDA)
9528+
void IntrinsicLibrary::genTMABulkStoreR4(
9529+
llvm::ArrayRef<fir::ExtendedValue> args) {
9530+
assert(args.size() == 3);
9531+
mlir::Value eleSize =
9532+
builder.createIntegerConstant(loc, builder.getI32Type(), 4);
9533+
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9534+
fir::getBase(args[2]), eleSize);
9535+
}
9536+
9537+
// TMA_BULK_STORE_R8 (CUDA)
9538+
void IntrinsicLibrary::genTMABulkStoreR8(
9539+
llvm::ArrayRef<fir::ExtendedValue> args) {
9540+
assert(args.size() == 3);
9541+
mlir::Value eleSize =
9542+
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
9543+
genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9544+
fir::getBase(args[2]), eleSize);
9545+
}
9546+
94339547
// TMA_BULK_WAIT_GROUP (CUDA)
94349548
void IntrinsicLibrary::genTMABulkWaitGroup(
94359549
llvm::ArrayRef<fir::ExtendedValue> args) {

flang/module/cudadevice.f90

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,7 +2047,13 @@ attributes(device) subroutine tma_bulk_wait_group()
20472047
end subroutine
20482048
end interface
20492049

2050+
! --------------------
2051+
! Bulk load functions
2052+
! --------------------
2053+
20502054
! Generic load, count is in bytes
2055+
! -------------------------------
2056+
20512057
interface
20522058
attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes)
20532059
!dir$ ignore_tkr src, dst
@@ -2058,17 +2064,9 @@ attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes)
20582064
end subroutine
20592065
end interface
20602066

2061-
interface
2062-
attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
2063-
!dir$ ignore_tkr src, dst
2064-
integer(4), shared :: src(*)
2065-
integer(4), device :: dst(*)
2066-
integer(4), value :: nbytes
2067-
end subroutine
2068-
end interface
2069-
20702067
! Load specific types, count is in elements
20712068
! -----------------------------------------
2069+
20722070
interface tma_bulk_load
20732071
attributes(device) subroutine tma_bulk_ldc4(barrier, src, dst, nelems)
20742072
!dir$ ignore_tkr (r) src, (r) dst
@@ -2127,6 +2125,75 @@ attributes(device) subroutine tma_bulk_ldr8(barrier, src, dst, nelems)
21272125
end subroutine
21282126
end interface
21292127

2128+
! --------------------
2129+
! Bulk Store functions
2130+
! --------------------
2131+
2132+
! Generic store, count is in bytes
2133+
! --------------------------------
2134+
2135+
interface
2136+
attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
2137+
!dir$ ignore_tkr src, dst
2138+
integer(4), shared :: src(*)
2139+
integer(4), device :: dst(*)
2140+
integer(4), value :: nbytes
2141+
end subroutine
2142+
end interface
2143+
2144+
! Load specific types, count is in elements
2145+
! -----------------------------------------
2146+
2147+
interface tma_bulk_store
2148+
attributes(device) subroutine tma_bulk_store_c4(src, dst, nelems)
2149+
!dir$ ignore_tkr (r) src, (r) dst
2150+
complex(4), shared :: src(*)
2151+
complex(4), device :: dst(*)
2152+
integer(4), value :: nelems
2153+
end subroutine
2154+
2155+
attributes(device) subroutine tma_bulk_store_c8(src, dst, nelems)
2156+
!dir$ ignore_tkr (r) src, (r) dst
2157+
complex(8), shared :: src(*)
2158+
complex(8), device :: dst(*)
2159+
integer(4), value :: nelems
2160+
end subroutine
2161+
2162+
attributes(device) subroutine tma_bulk_store_i4(src, dst, nelems)
2163+
!dir$ ignore_tkr (r) src, (r) dst
2164+
integer(4), shared :: src(*)
2165+
integer(4), device :: dst(*)
2166+
integer(4), value :: nelems
2167+
end subroutine
2168+
2169+
attributes(device) subroutine tma_bulk_store_i8(src, dst, nelems)
2170+
!dir$ ignore_tkr (r) src, (r) dst
2171+
integer(8), shared :: src(*)
2172+
integer(8), device :: dst(*)
2173+
integer(4), value :: nelems
2174+
end subroutine
2175+
2176+
attributes(device) subroutine tma_bulk_store_r2(src, dst, nelems)
2177+
!dir$ ignore_tkr (r) src, (r) dst
2178+
real(2), shared :: src(*)
2179+
real(2), device :: dst(*)
2180+
integer(4), value :: nelems
2181+
end subroutine
2182+
2183+
attributes(device) subroutine tma_bulk_store_r4(src, dst, nelems)
2184+
!dir$ ignore_tkr (r) src, (r) dst
2185+
real(4), shared :: src(*)
2186+
real(4), device :: dst(*)
2187+
integer(4), value :: nelems
2188+
end subroutine
2189+
2190+
attributes(device) subroutine tma_bulk_store_r8(src, dst, nelems)
2191+
!dir$ ignore_tkr (r) src, (r) dst
2192+
real(8), shared :: src(*)
2193+
real(8), device :: dst(*)
2194+
integer(4), value :: nelems
2195+
end subroutine
2196+
end interface
21302197

21312198
contains
21322199

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,3 +649,95 @@ end subroutine
649649
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
650650
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf64>>, !fir.ref<f64>, i32, !llvm.ptr)
651651
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
652+
653+
attributes(global) subroutine test_tma_bulk_store_c4(c, n)
654+
integer, value :: n
655+
complex(4), device :: c(n)
656+
complex(4), shared :: tmpa(1024)
657+
integer(4) :: j, elem_count
658+
call tma_bulk_store(tmpa, c(j), elem_count)
659+
end subroutine
660+
661+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c4
662+
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
663+
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
664+
! CHECK: nvvm.cp.async.bulk.wait_group 0
665+
666+
attributes(global) subroutine test_tma_bulk_store_c8(c, n)
667+
integer, value :: n
668+
complex(8), device :: c(n)
669+
complex(8), shared :: tmpa(1024)
670+
integer(4) :: j, elem_count
671+
call tma_bulk_store(tmpa, c(j), elem_count)
672+
end subroutine
673+
674+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c8
675+
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
676+
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
677+
! CHECK: nvvm.cp.async.bulk.wait_group 0
678+
679+
attributes(global) subroutine test_tma_bulk_store_i4(c, n)
680+
integer, value :: n
681+
integer(4), device :: c(n)
682+
integer(4), shared :: tmpa(1024)
683+
integer(4) :: j, elem_count
684+
call tma_bulk_store(tmpa, c(j), elem_count)
685+
end subroutine
686+
687+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i4
688+
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
689+
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
690+
! CHECK: nvvm.cp.async.bulk.wait_group 0
691+
692+
attributes(global) subroutine test_tma_bulk_store_i8(c, n)
693+
integer, value :: n
694+
integer(8), device :: c(n)
695+
integer(8), shared :: tmpa(1024)
696+
integer(4) :: j, elem_count
697+
call tma_bulk_store(tmpa, c(j), elem_count)
698+
end subroutine
699+
700+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i8
701+
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
702+
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
703+
! CHECK: nvvm.cp.async.bulk.wait_group 0
704+
705+
706+
attributes(global) subroutine test_tma_bulk_store_r2(c, n)
707+
integer, value :: n
708+
real(2), device :: c(n)
709+
real(2), shared :: tmpa(1024)
710+
integer(4) :: j, elem_count
711+
call tma_bulk_store(tmpa, c(j), elem_count)
712+
end subroutine
713+
714+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r2
715+
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
716+
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
717+
! CHECK: nvvm.cp.async.bulk.wait_group 0
718+
719+
attributes(global) subroutine test_tma_bulk_store_r4(c, n)
720+
integer, value :: n
721+
real(4), device :: c(n)
722+
real(4), shared :: tmpa(1024)
723+
integer(4) :: j, elem_count
724+
call tma_bulk_store(tmpa, c(j), elem_count)
725+
end subroutine
726+
727+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r4
728+
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
729+
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
730+
! CHECK: nvvm.cp.async.bulk.wait_group 0
731+
732+
attributes(global) subroutine test_tma_bulk_store_r8(c, n)
733+
integer, value :: n
734+
real(8), device :: c(n)
735+
real(8), shared :: tmpa(1024)
736+
integer(4) :: j, elem_count
737+
call tma_bulk_store(tmpa, c(j), elem_count)
738+
end subroutine
739+
740+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r8
741+
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
742+
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
743+
! CHECK: nvvm.cp.async.bulk.wait_group 0

0 commit comments

Comments
 (0)