From d9493814270d08caec643cde50e4740910ed9ecf Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Wed, 25 Jun 2025 12:03:50 -0700 Subject: [PATCH 1/4] [flang][cuda] Inline this_grid computation --- .../flang/Optimizer/Builder/IntrinsicCall.h | 1 + flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 85 +++++++++++++++++++ flang/module/cooperative_groups.f90 | 30 +++++++ flang/test/Lower/CUDA/cuda-cooperative.cuf | 53 ++++++++++++ flang/tools/f18/CMakeLists.txt | 14 +-- 5 files changed, 178 insertions(+), 5 deletions(-) create mode 100644 flang/module/cooperative_groups.f90 create mode 100644 flang/test/Lower/CUDA/cuda-cooperative.cuf diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h index 17052113859e1..3cf7a4920ed7d 100644 --- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h +++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h @@ -442,6 +442,7 @@ struct IntrinsicLibrary { llvm::ArrayRef); fir::ExtendedValue genTranspose(mlir::Type, llvm::ArrayRef); + mlir::Value genThisGrid(mlir::Type, llvm::ArrayRef); void genThreadFence(llvm::ArrayRef); void genThreadFenceBlock(llvm::ArrayRef); void genThreadFenceSystem(llvm::ArrayRef); diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 178b6770d6b53..a00dc9f5d30a2 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -932,6 +932,7 @@ static constexpr IntrinsicHandler handlers[]{ {{{"count", asAddr}, {"count_rate", asAddr}, {"count_max", asAddr}}}, /*isElemental=*/false}, {"tand", &I::genTand}, + {"this_grid", &I::genThisGrid, {}, /*isElemental=*/false}, {"threadfence", &I::genThreadFence, {}, /*isElemental=*/false}, {"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false}, {"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false}, @@ -8109,6 +8110,90 @@ mlir::Value 
IntrinsicLibrary::genTand(mlir::Type resultType, return getRuntimeCallGenerator("tan", ftype)(builder, loc, {arg}); } +// THIS_GRID +mlir::Value IntrinsicLibrary::genThisGrid(mlir::Type resultType, + llvm::ArrayRef args) { + assert(args.size() == 0); + auto recTy = mlir::cast(resultType); + assert(recTy && "RecordType expepected"); + mlir::Value res = builder.create(loc, resultType); + mlir::Type i32Ty = builder.getI32Type(); + + mlir::Value threadIdX = builder.create(loc, i32Ty); + mlir::Value threadIdY = builder.create(loc, i32Ty); + mlir::Value threadIdZ = builder.create(loc, i32Ty); + + mlir::Value blockIdX = builder.create(loc, i32Ty); + mlir::Value blockIdY = builder.create(loc, i32Ty); + mlir::Value blockIdZ = builder.create(loc, i32Ty); + + mlir::Value blockDimX = builder.create(loc, i32Ty); + mlir::Value blockDimY = builder.create(loc, i32Ty); + mlir::Value blockDimZ = builder.create(loc, i32Ty); + mlir::Value gridDimX = builder.create(loc, i32Ty); + mlir::Value gridDimY = builder.create(loc, i32Ty); + mlir::Value gridDimZ = builder.create(loc, i32Ty); + + // this_grid.size = ((blockDim.z * gridDim.z) * (blockDim.y * gridDim.y)) * + // (blockDim.x * gridDim.x); + mlir::Value resZ = + builder.create(loc, blockDimZ, gridDimZ); + mlir::Value resY = + builder.create(loc, blockDimY, gridDimY); + mlir::Value resX = + builder.create(loc, blockDimX, gridDimX); + mlir::Value resZY = builder.create(loc, resZ, resY); + mlir::Value size = builder.create(loc, resZY, resX); + + // tmp = ((blockIdx.z * gridDim.y * gridDim.x) + (blockIdx.y * gridDim.x)) + + // blockIdx.x; + // this_grid.rank = tmp * ((blockDim.x * blockDim.y) * blockDim.z) + + // ((threadIdx.z * blockDim.y) * blockDim.x) + + // (threadIdx.y * blockDim.x) + threadIdx.x + 1; + mlir::Value r1 = builder.create(loc, blockIdZ, gridDimY); + mlir::Value r2 = builder.create(loc, r1, gridDimX); + mlir::Value r3 = builder.create(loc, blockIdY, gridDimX); + mlir::Value r2r3 = builder.create(loc, r2, r3); + 
mlir::Value tmp = builder.create(loc, r2r3, blockIdX); + + mlir::Value bXbY = + builder.create(loc, blockDimX, blockDimY); + mlir::Value bXbYbZ = + builder.create(loc, bXbY, blockDimZ); + mlir::Value tZbY = + builder.create(loc, threadIdZ, blockDimY); + mlir::Value tZbYbX = + builder.create(loc, tZbY, blockDimX); + mlir::Value tYbX = + builder.create(loc, threadIdY, blockDimX); + mlir::Value rank = builder.create(loc, tmp, bXbYbZ); + rank = builder.create(loc, rank, tZbYbX); + rank = builder.create(loc, rank, tYbX); + rank = builder.create(loc, rank, threadIdX); + mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1); + rank = builder.create(loc, rank, one); + + auto sizeFieldName = recTy.getTypeList()[1].first; + mlir::Type sizeFieldTy = recTy.getTypeList()[1].second; + mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext()); + mlir::Value sizeFieldIndex = builder.create( + loc, fieldIndexType, sizeFieldName, recTy, + /*typeParams=*/mlir::ValueRange{}); + mlir::Value sizeCoord = builder.create( + loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex); + builder.create(loc, size, sizeCoord); + + auto rankFieldName = recTy.getTypeList()[2].first; + mlir::Type rankFieldTy = recTy.getTypeList()[2].second; + mlir::Value rankFieldIndex = builder.create( + loc, fieldIndexType, rankFieldName, recTy, + /*typeParams=*/mlir::ValueRange{}); + mlir::Value rankCoord = builder.create( + loc, builder.getRefType(rankFieldTy), res, rankFieldIndex); + builder.create(loc, rank, rankCoord); + return res; +} + // TRAILZ mlir::Value IntrinsicLibrary::genTrailz(mlir::Type resultType, llvm::ArrayRef args) { diff --git a/flang/module/cooperative_groups.f90 b/flang/module/cooperative_groups.f90 new file mode 100644 index 0000000000000..e7d19f1c65b1a --- /dev/null +++ b/flang/module/cooperative_groups.f90 @@ -0,0 +1,30 @@ +!===-- module/cudedevice.f90 -----------------------------------------------===! +! +! 
Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +! See https://llvm.org/LICENSE.txt for license information. +! SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +! +!===------------------------------------------------------------------------===! + +! CUDA Fortran cooperative groups + +module cooperative_groups + +use, intrinsic :: __fortran_builtins, only: c_devptr => __builtin_c_devptr + +implicit none + +type :: grid_group + type(c_devptr), private :: handle + integer(4) :: size + integer(4) :: rank +end type grid_group + +interface + attributes(device) function this_grid() + import + type(grid_group) :: this_grid + end function +end interface + +end module diff --git a/flang/test/Lower/CUDA/cuda-cooperative.cuf b/flang/test/Lower/CUDA/cuda-cooperative.cuf new file mode 100644 index 0000000000000..d3deb8f318664 --- /dev/null +++ b/flang/test/Lower/CUDA/cuda-cooperative.cuf @@ -0,0 +1,53 @@ +! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s + +! Test CUDA Fortran procedures available in cooperative_groups module. + +attributes(grid_global) subroutine g1() + use cooperative_groups + type(grid_group) :: gg + gg = this_grid() +end subroutine + +! CHECK: %{{.*}} = fir.alloca !fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}> +! CHECK: %[[RES:.*]] = fir.alloca !fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}> +! CHECK: %[[THREAD_ID_X:.*]] = nvvm.read.ptx.sreg.tid.x : i32 +! CHECK: %[[THREAD_ID_Y:.*]] = nvvm.read.ptx.sreg.tid.y : i32 +! CHECK: %[[THREAD_ID_Z:.*]] = nvvm.read.ptx.sreg.tid.z : i32 +! CHECK: %[[BLOCK_ID_X:.*]] = nvvm.read.ptx.sreg.ctaid.x : i32 +! 
CHECK: %[[BLOCK_ID_Y:.*]] = nvvm.read.ptx.sreg.ctaid.y : i32 +! CHECK: %[[BLOCK_ID_Z:.*]] = nvvm.read.ptx.sreg.ctaid.z : i32 +! CHECK: %[[BLOCK_DIM_X:.*]] = nvvm.read.ptx.sreg.ntid.x : i32 +! CHECK: %[[BLOCK_DIM_Y:.*]] = nvvm.read.ptx.sreg.ntid.y : i32 +! CHECK: %[[BLOCK_DIM_Z:.*]] = nvvm.read.ptx.sreg.ntid.z : i32 +! CHECK: %[[GRID_DIM_X:.*]] = nvvm.read.ptx.sreg.nctaid.x : i32 +! CHECK: %[[GRID_DIM_Y:.*]] = nvvm.read.ptx.sreg.nctaid.y : i32 +! CHECK: %[[GRID_DIM_Z:.*]] = nvvm.read.ptx.sreg.nctaid.z : i32 + +! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_DIM_Z]], %[[GRID_DIM_Z]] : i32 +! CHECK: %[[R2:.*]] = arith.muli %[[BLOCK_DIM_Y]], %[[GRID_DIM_Y]] : i32 +! CHECK: %[[R3:.*]] = arith.muli %[[BLOCK_DIM_X]], %[[GRID_DIM_X]] : i32 +! CHECK: %[[R4:.*]] = arith.muli %[[R1]], %[[R2]] : i32 +! CHECK: %[[SIZE:.*]] = arith.muli %[[R4]], %[[R3]] : i32 + +! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_ID_Z]], %[[GRID_DIM_Y]] : i32 +! CHECK: %[[R2:.*]] = arith.muli %[[R1]], %[[GRID_DIM_X]] : i32 +! CHECK: %[[R3:.*]] = arith.muli %[[BLOCK_ID_Y]], %[[GRID_DIM_X]] : i32 +! CHECK: %[[R4:.*]] = arith.addi %[[R2]], %[[R3]] : i32 +! CHECK: %[[TMP:.*]] = arith.addi %[[R4]], %[[BLOCK_ID_X]] : i32 + +! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_DIM_X]], %[[BLOCK_DIM_Y]] : i32 +! CHECK: %[[R2:.*]] = arith.muli %[[R1]], %[[BLOCK_DIM_Z]] : i32 +! CHECK: %[[R3:.*]] = arith.muli %[[THREAD_ID_Z]], %[[BLOCK_DIM_Y]] : i32 +! CHECK: %[[R4:.*]] = arith.muli %[[R3]], %[[BLOCK_DIM_X]] : i32 +! CHECK: %[[R5:.*]] = arith.muli %[[THREAD_ID_Y]], %[[BLOCK_DIM_X]] : i32 +! CHECK: %[[RES0:.*]] = arith.muli %[[TMP]], %[[R2]] : i32 +! CHECK: %[[RES1:.*]] = arith.addi %[[RES0]], %[[R4]] : i32 +! CHECK: %[[RES2:.*]] = arith.addi %[[RES1]], %[[R5]] : i32 +! CHECK: %[[RES3:.*]] = arith.addi %[[RES2]], %[[THREAD_ID_X]] : i32 +! CHECK: %[[ONE:.*]] = arith.constant 1 : i32 +! CHECK: %[[RANK:.*]] = arith.addi %[[RES3]], %[[ONE]] : i32 +! 
CHECK: %[[COORD_SIZE:.*]] = fir.coordinate_of %[[RES]], size : (!fir.ref}>,size:i32,rank:i32}>>) -> !fir.ref +! CHECK: fir.store %[[SIZE]] to %[[COORD_SIZE]] : !fir.ref +! CHECK: %[[COORD_RANK:.*]] = fir.coordinate_of %[[RES]], rank : (!fir.ref}>,size:i32,rank:i32}>>) -> !fir.ref +! CHECK: fir.store %[[RANK]] to %[[COORD_RANK]] : !fir.ref + diff --git a/flang/tools/f18/CMakeLists.txt b/flang/tools/f18/CMakeLists.txt index fb5510d7163d1..95fe9ad87ac45 100644 --- a/flang/tools/f18/CMakeLists.txt +++ b/flang/tools/f18/CMakeLists.txt @@ -15,6 +15,7 @@ set(MODULES "mma" "__cuda_builtins" "__cuda_device" + "cooperative_groups" "cudadevice" "ieee_arithmetic" "ieee_exceptions" @@ -60,12 +61,15 @@ if (NOT CMAKE_CROSSCOMPILING) elseif(${filename} STREQUAL "__ppc_intrinsics" OR ${filename} STREQUAL "mma") set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__ppc_types.mod) - elseif(${filename} STREQUAL "__cuda_device") + elseif(${filename} STREQUAL "__cuda_device" OR + ${filename} STREQUAL "cudadevice" OR + ${filename} STREQUAL "cooperative_groups") set(opts -fc1 -xcuda) - set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod) - elseif(${filename} STREQUAL "cudadevice") - set(opts -fc1 -xcuda) - set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_device.mod) + if(${filename} STREQUAL "__cuda_device") + set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod) + elseif(${filename} STREQUAL "cudadevice") + set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_device.mod) + endif() else() set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__fortran_builtins.mod) if(${filename} STREQUAL "iso_fortran_env") From 6b2a2425a9f3b89e0f327a71a5d46e698cccbff8 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Wed, 25 Jun 2025 15:20:38 -0700 Subject: [PATCH 2/4] Fix header --- flang/module/cooperative_groups.f90 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flang/module/cooperative_groups.f90 b/flang/module/cooperative_groups.f90 index e7d19f1c65b1a..935e41fc56c1a 100644 --- 
a/flang/module/cooperative_groups.f90 +++ b/flang/module/cooperative_groups.f90 @@ -1,4 +1,4 @@ -!===-- module/cudedevice.f90 -----------------------------------------------===! +!===-- module/cooperative_groups.f90 ---------------------------------------===! ! ! Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. ! See https://llvm.org/LICENSE.txt for license information. From 0b1e0dd1101e7001c3acc7c4cdb716303f416d83 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Wed, 25 Jun 2025 15:35:20 -0700 Subject: [PATCH 3/4] Add missing dep --- flang/tools/f18/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flang/tools/f18/CMakeLists.txt b/flang/tools/f18/CMakeLists.txt index 95fe9ad87ac45..546b6acaaf91d 100644 --- a/flang/tools/f18/CMakeLists.txt +++ b/flang/tools/f18/CMakeLists.txt @@ -69,6 +69,8 @@ if (NOT CMAKE_CROSSCOMPILING) set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod) elseif(${filename} STREQUAL "cudadevice") set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_device.mod) + elseif(${filename} STREQUAL "cooperative_groups") + set(depends ${FLANG_INTRINSIC_MODULES_DIR}/cudadevice.mod) endif() else() set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__fortran_builtins.mod) From 3eb70a2548f558bccd12e30cbdbc1225f609e1c6 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Wed, 25 Jun 2025 15:55:59 -0700 Subject: [PATCH 4/4] Remove line --- flang/test/Lower/CUDA/cuda-cooperative.cuf | 1 - 1 file changed, 1 deletion(-) diff --git a/flang/test/Lower/CUDA/cuda-cooperative.cuf b/flang/test/Lower/CUDA/cuda-cooperative.cuf index d3deb8f318664..54523b18b20db 100644 --- a/flang/test/Lower/CUDA/cuda-cooperative.cuf +++ b/flang/test/Lower/CUDA/cuda-cooperative.cuf @@ -50,4 +50,3 @@ end subroutine ! CHECK: fir.store %[[SIZE]] to %[[COORD_SIZE]] : !fir.ref ! CHECK: %[[COORD_RANK:.*]] = fir.coordinate_of %[[RES]], rank : (!fir.ref}>,size:i32,rank:i32}>>) -> !fir.ref ! 
CHECK: fir.store %[[RANK]] to %[[COORD_RANK]] : !fir.ref -