diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h index f5ff6626da654..3407dd01dd504 100644 --- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h +++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h @@ -469,6 +469,13 @@ struct IntrinsicLibrary { void genTMABulkLoadR4(llvm::ArrayRef); void genTMABulkLoadR8(llvm::ArrayRef); void genTMABulkS2G(llvm::ArrayRef); + void genTMABulkStoreI4(llvm::ArrayRef); + void genTMABulkStoreI8(llvm::ArrayRef); + void genTMABulkStoreR2(llvm::ArrayRef); + void genTMABulkStoreR4(llvm::ArrayRef); + void genTMABulkStoreR8(llvm::ArrayRef); + void genTMABulkStoreC4(llvm::ArrayRef); + void genTMABulkStoreC8(llvm::ArrayRef); void genTMABulkWaitGroup(llvm::ArrayRef); mlir::Value genTrailz(mlir::Type, llvm::ArrayRef); fir::ExtendedValue genTransfer(mlir::Type, diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 65317599ecd35..53fe9c0d2f6f0 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1098,6 +1098,34 @@ static constexpr IntrinsicHandler handlers[]{ &I::genTMABulkS2G, {{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}}, /*isElemental=*/false}, + {"tma_bulk_store_c4", + &I::genTMABulkStoreC4, + {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_store_c8", + &I::genTMABulkStoreC8, + {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_store_i4", + &I::genTMABulkStoreI4, + {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_store_i8", + &I::genTMABulkStoreI8, + {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_store_r2", + &I::genTMABulkStoreR2, + {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_store_r4", + &I::genTMABulkStoreR4, + {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, + {"tma_bulk_store_r8", + &I::genTMABulkStoreR8, + {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, + /*isElemental=*/false}, {"tma_bulk_wait_group", &I::genTMABulkWaitGroup, {{}}, @@ -9430,6 +9458,92 @@ void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef args) { builder.getI32IntegerAttr(0), {}); } +static void genTMABulkStore(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value src, mlir::Value dst, mlir::Value count, + mlir::Value eleSize) { + mlir::Value size = mlir::arith::MulIOp::create(builder, loc, eleSize, count); + src = convertPtrToNVVMSpace(builder, loc, src, + mlir::NVVM::NVVMMemorySpace::Shared); + dst = convertPtrToNVVMSpace(builder, loc, dst, + mlir::NVVM::NVVMMemorySpace::Global); + mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(builder, loc, dst, src, + size, {}, {}); + mlir::NVVM::InlinePtxOp::create(builder, loc, mlir::TypeRange{}, {}, {}, + "cp.async.bulk.commit_group", {}); + mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc, + builder.getI32IntegerAttr(0), {}); +} + +// TMA_BULK_STORE_C4 (CUDA) +void IntrinsicLibrary::genTMABulkStoreC4( + llvm::ArrayRef args) { + assert(args.size() == 3); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 8); + genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_C8 (CUDA) +void IntrinsicLibrary::genTMABulkStoreC8( + llvm::ArrayRef args) { + assert(args.size() == 3); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 16); + genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_I4 (CUDA) +void IntrinsicLibrary::genTMABulkStoreI4( + llvm::ArrayRef args) { + assert(args.size() == 3); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 4); + genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_I8 (CUDA) +void IntrinsicLibrary::genTMABulkStoreI8( + llvm::ArrayRef args) { + assert(args.size() == 3); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 8); + genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_R2 (CUDA) +void IntrinsicLibrary::genTMABulkStoreR2( + llvm::ArrayRef args) { + assert(args.size() == 3); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 2); + genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_R4 (CUDA) +void IntrinsicLibrary::genTMABulkStoreR4( + llvm::ArrayRef args) { + assert(args.size() == 3); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 4); + genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_R8 (CUDA) +void IntrinsicLibrary::genTMABulkStoreR8( + llvm::ArrayRef args) { + assert(args.size() == 3); + mlir::Value eleSize = + builder.createIntegerConstant(loc, builder.getI32Type(), 8); + genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), + fir::getBase(args[2]), eleSize); +} + // TMA_BULK_WAIT_GROUP (CUDA) void IntrinsicLibrary::genTMABulkWaitGroup( llvm::ArrayRef args) { diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90 index e6decbb96c271..59af58ddcd32e 100644 --- a/flang/module/cudadevice.f90 +++ b/flang/module/cudadevice.f90 @@ -2047,7 +2047,13 @@ attributes(device) subroutine tma_bulk_wait_group() end subroutine end interface + ! -------------------- + ! Bulk load functions + ! -------------------- + ! Generic load, count is in bytes + ! ------------------------------- + interface attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes) !dir$ ignore_tkr src, dst @@ -2058,17 +2064,9 @@ attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes) end subroutine end interface - interface - attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes) - !dir$ ignore_tkr src, dst - integer(4), shared :: src(*) - integer(4), device :: dst(*) - integer(4), value :: nbytes - end subroutine - end interface - ! Load specific types, count is in elements ! ----------------------------------------- + interface tma_bulk_load attributes(device) subroutine tma_bulk_ldc4(barrier, src, dst, nelems) !dir$ ignore_tkr (r) src, (r) dst @@ -2127,6 +2125,75 @@ attributes(device) subroutine tma_bulk_ldr8(barrier, src, dst, nelems) end subroutine end interface + ! -------------------- + ! Bulk Store functions + ! -------------------- + + ! Generic store, count is in bytes + ! -------------------------------- + + interface + attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes) + !dir$ ignore_tkr src, dst + integer(4), shared :: src(*) + integer(4), device :: dst(*) + integer(4), value :: nbytes + end subroutine + end interface + + ! Load specific types, count is in elements + ! ----------------------------------------- + + interface tma_bulk_store + attributes(device) subroutine tma_bulk_store_c4(src, dst, nelems) + !dir$ ignore_tkr (r) src, (r) dst + complex(4), shared :: src(*) + complex(4), device :: dst(*) + integer(4), value :: nelems + end subroutine + + attributes(device) subroutine tma_bulk_store_c8(src, dst, nelems) + !dir$ ignore_tkr (r) src, (r) dst + complex(8), shared :: src(*) + complex(8), device :: dst(*) + integer(4), value :: nelems + end subroutine + + attributes(device) subroutine tma_bulk_store_i4(src, dst, nelems) + !dir$ ignore_tkr (r) src, (r) dst + integer(4), shared :: src(*) + integer(4), device :: dst(*) + integer(4), value :: nelems + end subroutine + + attributes(device) subroutine tma_bulk_store_i8(src, dst, nelems) + !dir$ ignore_tkr (r) src, (r) dst + integer(8), shared :: src(*) + integer(8), device :: dst(*) + integer(4), value :: nelems + end subroutine + + attributes(device) subroutine tma_bulk_store_r2(src, dst, nelems) + !dir$ ignore_tkr (r) src, (r) dst + real(2), shared :: src(*) + real(2), device :: dst(*) + integer(4), value :: nelems + end subroutine + + attributes(device) subroutine tma_bulk_store_r4(src, dst, nelems) + !dir$ ignore_tkr (r) src, (r) dst + real(4), shared :: src(*) + real(4), device :: dst(*) + integer(4), value :: nelems + end subroutine + + attributes(device) subroutine tma_bulk_store_r8(src, dst, nelems) + !dir$ ignore_tkr (r) src, (r) dst + real(8), shared :: src(*) + real(8), device :: dst(*) + integer(4), value :: nelems + end subroutine + end interface contains diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf index 71d3d1ef2e2e9..8f355217899b3 100644 --- a/flang/test/Lower/CUDA/cuda-device-proc.cuf +++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf @@ -649,3 +649,95 @@ end subroutine ! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref) -> !llvm.ptr ! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref>, !fir.ref, i32, !llvm.ptr) ! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32) + +attributes(global) subroutine test_tma_bulk_store_c4(c, n) + integer, value :: n + complex(4), device :: c(n) + complex(4), shared :: tmpa(1024) + integer(4) :: j, elem_count + call tma_bulk_store(tmpa, c(j), elem_count) +end subroutine + +! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c4 +! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3> +! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group" +! CHECK: nvvm.cp.async.bulk.wait_group 0 + +attributes(global) subroutine test_tma_bulk_store_c8(c, n) + integer, value :: n + complex(8), device :: c(n) + complex(8), shared :: tmpa(1024) + integer(4) :: j, elem_count + call tma_bulk_store(tmpa, c(j), elem_count) +end subroutine + +! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c8 +! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3> +! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group" +! CHECK: nvvm.cp.async.bulk.wait_group 0 + +attributes(global) subroutine test_tma_bulk_store_i4(c, n) + integer, value :: n + integer(4), device :: c(n) + integer(4), shared :: tmpa(1024) + integer(4) :: j, elem_count + call tma_bulk_store(tmpa, c(j), elem_count) +end subroutine + +! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i4 +! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3> +! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group" +! CHECK: nvvm.cp.async.bulk.wait_group 0 + +attributes(global) subroutine test_tma_bulk_store_i8(c, n) + integer, value :: n + integer(8), device :: c(n) + integer(8), shared :: tmpa(1024) + integer(4) :: j, elem_count + call tma_bulk_store(tmpa, c(j), elem_count) +end subroutine + +! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i8 +! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3> +! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group" +! CHECK: nvvm.cp.async.bulk.wait_group 0 + + +attributes(global) subroutine test_tma_bulk_store_r2(c, n) + integer, value :: n + real(2), device :: c(n) + real(2), shared :: tmpa(1024) + integer(4) :: j, elem_count + call tma_bulk_store(tmpa, c(j), elem_count) +end subroutine + +! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r2 +! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3> +! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group" +! CHECK: nvvm.cp.async.bulk.wait_group 0 + +attributes(global) subroutine test_tma_bulk_store_r4(c, n) + integer, value :: n + real(4), device :: c(n) + real(4), shared :: tmpa(1024) + integer(4) :: j, elem_count + call tma_bulk_store(tmpa, c(j), elem_count) +end subroutine + +! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r4 +! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3> +! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group" +! CHECK: nvvm.cp.async.bulk.wait_group 0 + +attributes(global) subroutine test_tma_bulk_store_r8(c, n) + integer, value :: n + real(8), device :: c(n) + real(8), shared :: tmpa(1024) + integer(4) :: j, elem_count + call tma_bulk_store(tmpa, c(j), elem_count) +end subroutine + +! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r8 +! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3> +! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group" +! CHECK: nvvm.cp.async.bulk.wait_group 0