Skip to content

Commit 94b1e2d

Browse files
clementvalaokblast
authored and committed
1 parent 3818838 commit 94b1e2d

File tree

4 files changed

+337
-0
lines changed

4 files changed

+337
-0
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,13 @@ struct IntrinsicLibrary {
461461
mlir::Value genTime(mlir::Type, llvm::ArrayRef<mlir::Value>);
462462
void genTMABulkCommitGroup(llvm::ArrayRef<fir::ExtendedValue>);
463463
void genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue>);
464+
// Typed TMA bulk loads: one lowering entry point per element type/kind
// (C4, C8, I4, I8, R2, R4, R8).
void genTMABulkLoadC4(llvm::ArrayRef<fir::ExtendedValue>);
465+
void genTMABulkLoadC8(llvm::ArrayRef<fir::ExtendedValue>);
466+
void genTMABulkLoadI4(llvm::ArrayRef<fir::ExtendedValue>);
467+
void genTMABulkLoadI8(llvm::ArrayRef<fir::ExtendedValue>);
468+
void genTMABulkLoadR2(llvm::ArrayRef<fir::ExtendedValue>);
469+
void genTMABulkLoadR4(llvm::ArrayRef<fir::ExtendedValue>);
470+
void genTMABulkLoadR8(llvm::ArrayRef<fir::ExtendedValue>);
464471
void genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue>);
465472
void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
466473
mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,55 @@ static constexpr IntrinsicHandler handlers[]{
10451045
{"dst", asAddr},
10461046
{"nbytes", asValue}}},
10471047
/*isElemental=*/false},
1048+
{"tma_bulk_ldc4",
1049+
&I::genTMABulkLoadC4,
1050+
{{{"barrier", asAddr},
1051+
{"src", asAddr},
1052+
{"dst", asAddr},
1053+
{"nelems", asValue}}},
1054+
/*isElemental=*/false},
1055+
{"tma_bulk_ldc8",
1056+
&I::genTMABulkLoadC8,
1057+
{{{"barrier", asAddr},
1058+
{"src", asAddr},
1059+
{"dst", asAddr},
1060+
{"nelems", asValue}}},
1061+
/*isElemental=*/false},
1062+
{"tma_bulk_ldi4",
1063+
&I::genTMABulkLoadI4,
1064+
{{{"barrier", asAddr},
1065+
{"src", asAddr},
1066+
{"dst", asAddr},
1067+
{"nelems", asValue}}},
1068+
/*isElemental=*/false},
1069+
{"tma_bulk_ldi8",
1070+
&I::genTMABulkLoadI8,
1071+
{{{"barrier", asAddr},
1072+
{"src", asAddr},
1073+
{"dst", asAddr},
1074+
{"nelems", asValue}}},
1075+
/*isElemental=*/false},
1076+
{"tma_bulk_ldr2",
1077+
&I::genTMABulkLoadR2,
1078+
{{{"barrier", asAddr},
1079+
{"src", asAddr},
1080+
{"dst", asAddr},
1081+
{"nelems", asValue}}},
1082+
/*isElemental=*/false},
1083+
{"tma_bulk_ldr4",
1084+
&I::genTMABulkLoadR4,
1085+
{{{"barrier", asAddr},
1086+
{"src", asAddr},
1087+
{"dst", asAddr},
1088+
{"nelems", asValue}}},
1089+
/*isElemental=*/false},
1090+
{"tma_bulk_ldr8",
1091+
&I::genTMABulkLoadR8,
1092+
{{{"barrier", asAddr},
1093+
{"src", asAddr},
1094+
{"dst", asAddr},
1095+
{"nelems", asValue}}},
1096+
/*isElemental=*/false},
10481097
{"tma_bulk_s2g",
10491098
&I::genTMABulkS2G,
10501099
{{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},
@@ -9278,6 +9327,93 @@ void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) {
92789327
builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {});
92799328
}
92809329

9330+
// Common lowering used by the typed genTMABulkLoad* members.
// Multiplies nelem by eleSize to obtain the transfer size, converts the
// barrier reference to an LLVM pointer, and then emits two inline PTX ops:
//   1. cp.async.bulk ... complete_tx::bytes with operands
//      (dst, src, size, barrier);
//   2. mbarrier.expect_tx on the same barrier with the same size.
// The ops are created in exactly this order.
// NOTE(review): the callers pass an i32 element size, so the product is
// computed in 32 bits - confirm that large element counts cannot overflow.
// NOTE(review): dst and src are passed to the PTX op as-is (no pointer
// conversion, unlike barrier) - confirm this is intended.
static void genTMABulkLoad(fir::FirOpBuilder &builder, mlir::Location loc,
9331+
mlir::Value barrier, mlir::Value src,
9332+
mlir::Value dst, mlir::Value nelem,
9333+
mlir::Value eleSize) {
9334+
mlir::Value size = mlir::arith::MulIOp::create(builder, loc, nelem, eleSize);
9335+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
9336+
barrier = builder.createConvert(loc, llvmPtrTy, barrier);
9337+
mlir::NVVM::InlinePtxOp::create(
9338+
builder, loc, mlir::TypeRange{}, {dst, src, size, barrier}, {},
9339+
"cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], "
9340+
"[%1], %2, [%3];",
9341+
{});
9342+
mlir::NVVM::InlinePtxOp::create(
9343+
builder, loc, mlir::TypeRange{}, {barrier, size}, {},
9344+
"mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;", {});
9345+
}
9346+
9347+
// TMA_BULK_LDC4 (CUDA)
// Lowers tma_bulk_ldc4(barrier, src, dst, nelems) for complex(4) data:
// forwards the four arguments with an i32 element size of 8 bytes.
9348+
void IntrinsicLibrary::genTMABulkLoadC4(
9349+
llvm::ArrayRef<fir::ExtendedValue> args) {
9350+
assert(args.size() == 4);
9351+
mlir::Value eleSize =
9352+
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
9353+
genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9354+
fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
9355+
}
9356+
9357+
// TMA_BULK_LDC8 (CUDA)
// Lowers tma_bulk_ldc8(barrier, src, dst, nelems) for complex(8) data:
// forwards the four arguments with an i32 element size of 16 bytes.
9358+
void IntrinsicLibrary::genTMABulkLoadC8(
9359+
llvm::ArrayRef<fir::ExtendedValue> args) {
9360+
assert(args.size() == 4);
9361+
mlir::Value eleSize =
9362+
builder.createIntegerConstant(loc, builder.getI32Type(), 16);
9363+
genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9364+
fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
9365+
}
9366+
9367+
// TMA_BULK_LDI4 (CUDA)
// Lowers tma_bulk_ldi4(barrier, src, dst, nelems) for integer(4) data:
// forwards the four arguments with an i32 element size of 4 bytes.
9368+
void IntrinsicLibrary::genTMABulkLoadI4(
9369+
llvm::ArrayRef<fir::ExtendedValue> args) {
9370+
assert(args.size() == 4);
9371+
mlir::Value eleSize =
9372+
builder.createIntegerConstant(loc, builder.getI32Type(), 4);
9373+
genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9374+
fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
9375+
}
9376+
9377+
// TMA_BULK_LDI8 (CUDA)
// Lowers tma_bulk_ldi8(barrier, src, dst, nelems) for integer(8) data:
// forwards the four arguments with an i32 element size of 8 bytes.
9378+
void IntrinsicLibrary::genTMABulkLoadI8(
9379+
llvm::ArrayRef<fir::ExtendedValue> args) {
9380+
assert(args.size() == 4);
9381+
mlir::Value eleSize =
9382+
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
9383+
genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9384+
fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
9385+
}
9386+
9387+
// TMA_BULK_LDR2 (CUDA)
// Lowers tma_bulk_ldr2(barrier, src, dst, nelems) for real(2) data:
// forwards the four arguments with an i32 element size of 2 bytes.
9388+
void IntrinsicLibrary::genTMABulkLoadR2(
9389+
llvm::ArrayRef<fir::ExtendedValue> args) {
9390+
assert(args.size() == 4);
9391+
mlir::Value eleSize =
9392+
builder.createIntegerConstant(loc, builder.getI32Type(), 2);
9393+
genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9394+
fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
9395+
}
9396+
9397+
// TMA_BULK_LDR4 (CUDA)
// Lowers tma_bulk_ldr4(barrier, src, dst, nelems) for real(4) data:
// forwards the four arguments with an i32 element size of 4 bytes.
9398+
void IntrinsicLibrary::genTMABulkLoadR4(
9399+
llvm::ArrayRef<fir::ExtendedValue> args) {
9400+
assert(args.size() == 4);
9401+
mlir::Value eleSize =
9402+
builder.createIntegerConstant(loc, builder.getI32Type(), 4);
9403+
genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9404+
fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
9405+
}
9406+
9407+
// TMA_BULK_LDR8 (CUDA)
// Lowers tma_bulk_ldr8(barrier, src, dst, nelems) for real(8) data:
// forwards the four arguments with an i32 element size of 8 bytes.
9408+
void IntrinsicLibrary::genTMABulkLoadR8(
9409+
llvm::ArrayRef<fir::ExtendedValue> args) {
9410+
assert(args.size() == 4);
9411+
mlir::Value eleSize =
9412+
builder.createIntegerConstant(loc, builder.getI32Type(), 8);
9413+
genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
9414+
fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
9415+
}
9416+
92819417
// TMA_BULK_S2G (CUDA)
92829418
void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {
92839419
assert(args.size() == 3);

flang/module/cudadevice.f90

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2067,6 +2067,67 @@ attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
20672067
end subroutine
20682068
end interface
20692069

2070+
! Type-specific bulk loads from a device array (src) into a shared array (dst).
! The count (nelems) is given in elements, not bytes. Every specific takes
! (barrier, src, dst, nelems); they differ only in the type/kind of src and dst.
2071+
! -----------------------------------------
2072+
interface tma_bulk_load
2073+
attributes(device) subroutine tma_bulk_ldc4(barrier, src, dst, nelems)
2074+
!dir$ ignore_tkr (r) src, (r) dst
2075+
integer(8), shared :: barrier
2076+
complex(4), device :: src(*)
2077+
complex(4), shared :: dst(*)
2078+
integer(4), value :: nelems
2079+
end subroutine
2080+
2081+
attributes(device) subroutine tma_bulk_ldc8(barrier, src, dst, nelems)
2082+
!dir$ ignore_tkr (r) src, (r) dst
2083+
integer(8), shared :: barrier
2084+
complex(8), device :: src(*)
2085+
complex(8), shared :: dst(*)
2086+
integer(4), value :: nelems
2087+
end subroutine
2088+
2089+
attributes(device) subroutine tma_bulk_ldi4(barrier, src, dst, nelems)
2090+
!dir$ ignore_tkr (r) src, (r) dst
2091+
integer(8), shared :: barrier
2092+
integer(4), device :: src(*)
2093+
integer(4), shared :: dst(*)
2094+
integer(4), value :: nelems
2095+
end subroutine
2096+
2097+
attributes(device) subroutine tma_bulk_ldi8(barrier, src, dst, nelems)
2098+
!dir$ ignore_tkr (r) src, (r) dst
2099+
integer(8), shared :: barrier
2100+
integer(8), device :: src(*)
2101+
integer(8), shared :: dst(*)
2102+
integer(4), value :: nelems
2103+
end subroutine
2104+
2105+
attributes(device) subroutine tma_bulk_ldr2(barrier, src, dst, nelems)
2106+
!dir$ ignore_tkr (r) src, (r) dst
2107+
integer(8), shared :: barrier
2108+
real(2), device :: src(*)
2109+
real(2), shared :: dst(*)
2110+
integer(4), value :: nelems
2111+
end subroutine
2112+
2113+
attributes(device) subroutine tma_bulk_ldr4(barrier, src, dst, nelems)
2114+
!dir$ ignore_tkr (r) src, (r) dst
2115+
integer(8), shared :: barrier
2116+
real(4), device :: src(*)
2117+
real(4), shared :: dst(*)
2118+
integer(4), value :: nelems
2119+
end subroutine
2120+
2121+
attributes(device) subroutine tma_bulk_ldr8(barrier, src, dst, nelems)
2122+
!dir$ ignore_tkr (r) src, (r) dst
2123+
integer(8), shared :: barrier
2124+
real(8), device :: src(*)
2125+
real(8), shared :: dst(*)
2126+
integer(4), value :: nelems
2127+
end subroutine
2128+
end interface
2129+
2130+
20702131
contains
20712132

20722133
attributes(device) subroutine syncthreads()

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,3 +516,136 @@ end subroutine
516516

517517
! CHECK-LABEL: func.func @_QPtest_barrier_try_wait_sleep()
518518
! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %0, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr, i64, i32) -> i32
519+
520+
! Calls the generic tma_bulk_load with complex(4) arrays (device source,
! shared destination); j and elem_count are left unset, only lowering matters.
attributes(global) subroutine test_tma_bulk_load_c4(a, n)
521+
integer(8), shared :: barrier1
522+
integer, value :: n
523+
complex(4), device :: r8(n)
524+
complex(4), shared :: tmp(1024)
525+
integer(4) :: j, elem_count
526+
call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
527+
end subroutine
528+
529+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_c4
530+
! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_c4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
531+
! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_c4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
532+
! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
533+
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
534+
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
535+
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
536+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xcomplex<f32>>>, !fir.ref<complex<f32>>, i32, !llvm.ptr)
537+
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
538+
539+
! Calls the generic tma_bulk_load with complex(8) arrays (device source,
! shared destination); j and elem_count are left unset, only lowering matters.
attributes(global) subroutine test_tma_bulk_load_c8(a, n)
540+
integer(8), shared :: barrier1
541+
integer, value :: n
542+
complex(8), device :: r8(n)
543+
complex(8), shared :: tmp(1024)
544+
integer(4) :: j, elem_count
545+
call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
546+
end subroutine
547+
548+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_c8
549+
! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_c8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
550+
! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_c8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
551+
! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
552+
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 16 : i32
553+
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
554+
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
555+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xcomplex<f64>>>, !fir.ref<complex<f64>>, i32, !llvm.ptr)
556+
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
557+
558+
! Calls the generic tma_bulk_load with integer(4) arrays (device source,
! shared destination); j and elem_count are left unset, only lowering matters.
attributes(global) subroutine test_tma_bulk_load_i4(a, n)
559+
integer(8), shared :: barrier1
560+
integer, value :: n
561+
integer(4), device :: r8(n)
562+
integer(4), shared :: tmp(1024)
563+
integer(4) :: j, elem_count
564+
call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
565+
end subroutine
566+
567+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_i4
568+
! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_i4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
569+
! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_i4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
570+
! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
571+
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
572+
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
573+
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
574+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xi32>>, !fir.ref<i32>, i32, !llvm.ptr)
575+
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
576+
577+
! Calls the generic tma_bulk_load with integer(8) arrays (device source,
! shared destination); j and elem_count are left unset, only lowering matters.
attributes(global) subroutine test_tma_bulk_load_i8(a, n)
578+
integer(8), shared :: barrier1
579+
integer, value :: n
580+
integer(8), device :: r8(n)
581+
integer(8), shared :: tmp(1024)
582+
integer(4) :: j, elem_count
583+
call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
584+
end subroutine
585+
586+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_i8
587+
! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_i8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
588+
! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_i8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
589+
! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
590+
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
591+
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
592+
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
593+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xi64>>, !fir.ref<i64>, i32, !llvm.ptr)
594+
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
595+
596+
! Calls the generic tma_bulk_load with real(2) arrays (device source,
! shared destination); j and elem_count are left unset, only lowering matters.
attributes(global) subroutine test_tma_bulk_load_r2(a, n)
597+
integer(8), shared :: barrier1
598+
integer, value :: n
599+
real(2), device :: r8(n)
600+
real(2), shared :: tmp(1024)
601+
integer(4) :: j, elem_count
602+
call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
603+
end subroutine
604+
605+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r2
606+
! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r2Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
607+
! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r2Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
608+
! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
609+
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 2 : i32
610+
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
611+
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
612+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf16>>, !fir.ref<f16>, i32, !llvm.ptr)
613+
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
614+
615+
! Calls the generic tma_bulk_load with real(4) arrays (device source,
! shared destination); j and elem_count are left unset, only lowering matters.
attributes(global) subroutine test_tma_bulk_load_r4(a, n)
616+
integer(8), shared :: barrier1
617+
integer, value :: n
618+
real(4), device :: r8(n)
619+
real(4), shared :: tmp(1024)
620+
integer(4) :: j, elem_count
621+
call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
622+
end subroutine
623+
624+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r4
625+
! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
626+
! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
627+
! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
628+
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
629+
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
630+
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
631+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf32>>, !fir.ref<f32>, i32, !llvm.ptr)
632+
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
633+
634+
! Calls the generic tma_bulk_load with real(8) arrays (device source,
! shared destination); j and elem_count are left unset, only lowering matters.
attributes(global) subroutine test_tma_bulk_load_r8(a, n)
635+
integer(8), shared :: barrier1
636+
integer, value :: n
637+
real(8), device :: r8(n)
638+
real(8), shared :: tmp(1024)
639+
integer(4) :: j, elem_count
640+
call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
641+
end subroutine
642+
643+
! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r8
644+
! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
645+
! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
646+
! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
647+
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
648+
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
649+
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
650+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf64>>, !fir.ref<f64>, i32, !llvm.ptr)
651+
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)

0 commit comments

Comments
 (0)