Skip to content

Commit 17502dc

Browse files
committed
1 parent d0e0d7f commit 17502dc

File tree

4 files changed

+337
-0
lines changed

4 files changed

+337
-0
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,13 @@ struct IntrinsicLibrary {
461461
mlir::Value genTime(mlir::Type, llvm::ArrayRef<mlir::Value>);
462462
void genTMABulkCommitGroup(llvm::ArrayRef<fir::ExtendedValue>);
463463
void genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue>);
464+
void genTMABulkLoadC4(llvm::ArrayRef<fir::ExtendedValue>);
465+
void genTMABulkLoadC8(llvm::ArrayRef<fir::ExtendedValue>);
466+
void genTMABulkLoadI4(llvm::ArrayRef<fir::ExtendedValue>);
467+
void genTMABulkLoadI8(llvm::ArrayRef<fir::ExtendedValue>);
468+
void genTMABulkLoadR2(llvm::ArrayRef<fir::ExtendedValue>);
469+
void genTMABulkLoadR4(llvm::ArrayRef<fir::ExtendedValue>);
470+
void genTMABulkLoadR8(llvm::ArrayRef<fir::ExtendedValue>);
464471
void genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue>);
465472
void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
466473
mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,55 @@ static constexpr IntrinsicHandler handlers[]{
10451045
{"dst", asAddr},
10461046
{"nbytes", asValue}}},
10471047
/*isElemental=*/false},
1048+
{"tma_bulk_ldc4",
1049+
&I::genTMABulkLoadC4,
1050+
{{{"barrier", asAddr},
1051+
{"src", asAddr},
1052+
{"dst", asAddr},
1053+
{"nelems", asValue}}},
1054+
/*isElemental=*/false},
1055+
{"tma_bulk_ldc8",
1056+
&I::genTMABulkLoadC8,
1057+
{{{"barrier", asAddr},
1058+
{"src", asAddr},
1059+
{"dst", asAddr},
1060+
{"nelems", asValue}}},
1061+
/*isElemental=*/false},
1062+
{"tma_bulk_ldi4",
1063+
&I::genTMABulkLoadI4,
1064+
{{{"barrier", asAddr},
1065+
{"src", asAddr},
1066+
{"dst", asAddr},
1067+
{"nelems", asValue}}},
1068+
/*isElemental=*/false},
1069+
{"tma_bulk_ldi8",
1070+
&I::genTMABulkLoadI8,
1071+
{{{"barrier", asAddr},
1072+
{"src", asAddr},
1073+
{"dst", asAddr},
1074+
{"nelems", asValue}}},
1075+
/*isElemental=*/false},
1076+
{"tma_bulk_ldr2",
1077+
&I::genTMABulkLoadR2,
1078+
{{{"barrier", asAddr},
1079+
{"src", asAddr},
1080+
{"dst", asAddr},
1081+
{"nelems", asValue}}},
1082+
/*isElemental=*/false},
1083+
{"tma_bulk_ldr4",
1084+
&I::genTMABulkLoadR4,
1085+
{{{"barrier", asAddr},
1086+
{"src", asAddr},
1087+
{"dst", asAddr},
1088+
{"nelems", asValue}}},
1089+
/*isElemental=*/false},
1090+
{"tma_bulk_ldr8",
1091+
&I::genTMABulkLoadR8,
1092+
{{{"barrier", asAddr},
1093+
{"src", asAddr},
1094+
{"dst", asAddr},
1095+
{"nelems", asValue}}},
1096+
/*isElemental=*/false},
10481097
{"tma_bulk_s2g",
10491098
&I::genTMABulkS2G,
10501099
{{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},
@@ -9278,6 +9327,93 @@ void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) {
92789327
builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {});
92799328
}
92809329

9330+
static void genTMABulkLoad(fir::FirOpBuilder &builder, mlir::Location loc,
9331+
mlir::Value barrier, mlir::Value src,
9332+
mlir::Value dst, mlir::Value nelem,
9333+
mlir::Value eleSize) {
9334+
mlir::Value size = mlir::arith::MulIOp::create(builder, loc, nelem, eleSize);
9335+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
9336+
barrier = builder.createConvert(loc, llvmPtrTy, barrier);
9337+
mlir::NVVM::InlinePtxOp::create(
9338+
builder, loc, mlir::TypeRange{}, {dst, src, size, barrier}, {},
9339+
"cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], "
9340+
"[%1], %2, [%3];",
9341+
{});
9342+
mlir::NVVM::InlinePtxOp::create(
9343+
builder, loc, mlir::TypeRange{}, {barrier, size}, {},
9344+
"mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;", {});
9345+
}
9346+
9347+
// TMA_BULK_LOADC4
9348+
void IntrinsicLibrary::genTMABulkLoadC4(
9349+
llvm::ArrayRef<fir::ExtendedValue> args) {
9350+
assert(args.size() == 4);
9351+
mlir::Value eleSize =
9352+
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
9353+
genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9354+
fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
9355+
}
9356+
9357+
// TMA_BULK_LOADC8
9358+
void IntrinsicLibrary::genTMABulkLoadC8(
9359+
llvm::ArrayRef<fir::ExtendedValue> args) {
9360+
assert(args.size() == 4);
9361+
mlir::Value eleSize =
9362+
builder.createIntegerConstant(loc, builder.getI32Type(), 16);
9363+
genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9364+
fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
9365+
}
9366+
9367+
// TMA_BULK_LOADI4
9368+
void IntrinsicLibrary::genTMABulkLoadI4(
9369+
llvm::ArrayRef<fir::ExtendedValue> args) {
9370+
assert(args.size() == 4);
9371+
mlir::Value eleSize =
9372+
builder.createIntegerConstant(loc, builder.getI32Type(), 4);
9373+
genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9374+
fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
9375+
}
9376+
9377+
// TMA_BULK_LOADI8
9378+
void IntrinsicLibrary::genTMABulkLoadI8(
9379+
llvm::ArrayRef<fir::ExtendedValue> args) {
9380+
assert(args.size() == 4);
9381+
mlir::Value eleSize =
9382+
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
9383+
genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9384+
fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
9385+
}
9386+
9387+
// TMA_BULK_LOADR2
9388+
void IntrinsicLibrary::genTMABulkLoadR2(
9389+
llvm::ArrayRef<fir::ExtendedValue> args) {
9390+
assert(args.size() == 4);
9391+
mlir::Value eleSize =
9392+
builder.createIntegerConstant(loc, builder.getI32Type(), 2);
9393+
genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9394+
fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
9395+
}
9396+
9397+
// TMA_BULK_LOADR4
9398+
void IntrinsicLibrary::genTMABulkLoadR4(
9399+
llvm::ArrayRef<fir::ExtendedValue> args) {
9400+
assert(args.size() == 4);
9401+
mlir::Value eleSize =
9402+
builder.createIntegerConstant(loc, builder.getI32Type(), 4);
9403+
genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9404+
fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
9405+
}
9406+
9407+
// TMA_BULK_LOADR8
9408+
void IntrinsicLibrary::genTMABulkLoadR8(
9409+
llvm::ArrayRef<fir::ExtendedValue> args) {
9410+
assert(args.size() == 4);
9411+
mlir::Value eleSize =
9412+
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
9413+
genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9414+
fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
9415+
}
9416+
92819417
// TMA_BULK_S2G (CUDA)
92829418
void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {
92839419
assert(args.size() == 3);

flang/module/cudadevice.f90

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2067,6 +2067,67 @@ attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
20672067
end subroutine
20682068
end interface
20692069

2070+
! Load specific types, count is in elements
2071+
! -----------------------------------------
2072+
interface tma_bulk_load
2073+
attributes(device) subroutine tma_bulk_ldi4(barrier, src, dst, nelems)
2074+
!dir$ ignore_tkr (r) src, (r) dst
2075+
integer(8), shared :: barrier
2076+
integer(4), device :: src(*)
2077+
integer(4), shared :: dst(*)
2078+
integer(4), value :: nelems
2079+
end subroutine
2080+
2081+
attributes(device) subroutine tma_bulk_ldi8(barrier, src, dst, nelems)
2082+
!dir$ ignore_tkr (r) src, (r) dst
2083+
integer(8), shared :: barrier
2084+
integer(8), device :: src(*)
2085+
integer(8), shared :: dst(*)
2086+
integer(4), value :: nelems
2087+
end subroutine
2088+
2089+
attributes(device) subroutine tma_bulk_ldr2(barrier, src, dst, nelems)
2090+
!dir$ ignore_tkr (r) src, (r) dst
2091+
integer(8), shared :: barrier
2092+
real(2), device :: src(*)
2093+
real(2), shared :: dst(*)
2094+
integer(4), value :: nelems
2095+
end subroutine
2096+
2097+
attributes(device) subroutine tma_bulk_ldr4(barrier, src, dst, nelems)
2098+
!dir$ ignore_tkr (r) src, (r) dst
2099+
integer(8), shared :: barrier
2100+
real(4), device :: src(*)
2101+
real(4), shared :: dst(*)
2102+
integer(4), value :: nelems
2103+
end subroutine
2104+
2105+
attributes(device) subroutine tma_bulk_ldr8(barrier, src, dst, nelems)
2106+
!dir$ ignore_tkr (r) src, (r) dst
2107+
integer(8), shared :: barrier
2108+
real(8), device :: src(*)
2109+
real(8), shared :: dst(*)
2110+
integer(4), value :: nelems
2111+
end subroutine
2112+
2113+
attributes(device) subroutine tma_bulk_ldc4(barrier, src, dst, nelems)
2114+
!dir$ ignore_tkr (r) src, (r) dst
2115+
integer(8), shared :: barrier
2116+
complex(4), device :: src(*)
2117+
complex(4), shared :: dst(*)
2118+
integer(4), value :: nelems
2119+
end subroutine
2120+
2121+
attributes(device) subroutine tma_bulk_ldc8(barrier, src, dst, nelems)
2122+
!dir$ ignore_tkr (r) src, (r) dst
2123+
integer(8), shared :: barrier
2124+
complex(8), device :: src(*)
2125+
complex(8), shared :: dst(*)
2126+
integer(4), value :: nelems
2127+
end subroutine
2128+
end interface
2129+
2130+
20702131
contains
20712132

20722133
attributes(device) subroutine syncthreads()

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,3 +514,136 @@ end subroutine
514514

515515
! CHECK-LABEL: func.func @_QPtest_barrier_try_wait_sleep()
516516
! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %0, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr, i64, i32) -> i32
517+
518+
attributes(global) subroutine test_tma_bulk_load_c4(a, n)
519+
integer(8), shared :: barrier1
520+
integer, value :: n
521+
complex(4), device :: r8(n)
522+
complex(4), shared :: tmp(1024)
523+
integer(4) :: j, elem_count
524+
call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
525+
end subroutine
526+
527+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_c4
528+
! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_c4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
529+
! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_c4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
530+
! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
531+
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
532+
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
533+
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
534+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xcomplex<f32>>>, !fir.ref<complex<f32>>, i32, !llvm.ptr)
535+
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
536+
537+
attributes(global) subroutine test_tma_bulk_load_c8(a, n)
538+
integer(8), shared :: barrier1
539+
integer, value :: n
540+
complex(8), device :: r8(n)
541+
complex(8), shared :: tmp(1024)
542+
integer(4) :: j, elem_count
543+
call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
544+
end subroutine
545+
546+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_c8
547+
! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_c8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
548+
! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_c8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
549+
! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
550+
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 16 : i32
551+
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
552+
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
553+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xcomplex<f64>>>, !fir.ref<complex<f64>>, i32, !llvm.ptr)
554+
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
555+
556+
attributes(global) subroutine test_tma_bulk_load_i4(a, n)
557+
integer(8), shared :: barrier1
558+
integer, value :: n
559+
integer(4), device :: r8(n)
560+
integer(4), shared :: tmp(1024)
561+
integer(4) :: j, elem_count
562+
call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
563+
end subroutine
564+
565+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_i4
566+
! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_i4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
567+
! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_i4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
568+
! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
569+
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
570+
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
571+
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
572+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xi32>>, !fir.ref<i32>, i32, !llvm.ptr)
573+
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
574+
575+
attributes(global) subroutine test_tma_bulk_load_i8(a, n)
576+
integer(8), shared :: barrier1
577+
integer, value :: n
578+
integer(8), device :: r8(n)
579+
integer(8), shared :: tmp(1024)
580+
integer(4) :: j, elem_count
581+
call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
582+
end subroutine
583+
584+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_i8
585+
! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_i8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
586+
! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_i8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
587+
! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
588+
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
589+
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
590+
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
591+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xi64>>, !fir.ref<i64>, i32, !llvm.ptr)
592+
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
593+
594+
attributes(global) subroutine test_tma_bulk_load_r2(a, n)
595+
integer(8), shared :: barrier1
596+
integer, value :: n
597+
real(2), device :: r8(n)
598+
real(2), shared :: tmp(1024)
599+
integer(4) :: j, elem_count
600+
call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
601+
end subroutine
602+
603+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r2
604+
! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r2Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
605+
! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r2Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
606+
! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
607+
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 2 : i32
608+
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
609+
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
610+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf16>>, !fir.ref<f16>, i32, !llvm.ptr)
611+
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
612+
613+
attributes(global) subroutine test_tma_bulk_load_r4(a, n)
614+
integer(8), shared :: barrier1
615+
integer, value :: n
616+
real(4), device :: r8(n)
617+
real(4), shared :: tmp(1024)
618+
integer(4) :: j, elem_count
619+
call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
620+
end subroutine
621+
622+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r4
623+
! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
624+
! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
625+
! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
626+
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
627+
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
628+
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
629+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf32>>, !fir.ref<f32>, i32, !llvm.ptr)
630+
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
631+
632+
attributes(global) subroutine test_tma_bulk_load_r8(a, n)
633+
integer(8), shared :: barrier1
634+
integer, value :: n
635+
real(8), device :: r8(n)
636+
real(8), shared :: tmp(1024)
637+
integer(4) :: j, elem_count
638+
call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
639+
end subroutine
640+
641+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r8
642+
! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
643+
! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
644+
! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
645+
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
646+
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
647+
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
648+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf64>>, !fir.ref<f64>, i32, !llvm.ptr)
649+
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)

0 commit comments

Comments
 (0)