-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[flang][cuda] Inline this_grid call for cooperative groups #145796
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesFull diff: https://github.com/llvm/llvm-project/pull/145796.diff 5 Files Affected:
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 17052113859e1..3cf7a4920ed7d 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -442,6 +442,7 @@ struct IntrinsicLibrary {
llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genTranspose(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
+ mlir::Value genThisGrid(mlir::Type, llvm::ArrayRef<mlir::Value>);
void genThreadFence(llvm::ArrayRef<fir::ExtendedValue>);
void genThreadFenceBlock(llvm::ArrayRef<fir::ExtendedValue>);
void genThreadFenceSystem(llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 178b6770d6b53..a00dc9f5d30a2 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -932,6 +932,7 @@ static constexpr IntrinsicHandler handlers[]{
{{{"count", asAddr}, {"count_rate", asAddr}, {"count_max", asAddr}}},
/*isElemental=*/false},
{"tand", &I::genTand},
+ {"this_grid", &I::genThisGrid, {}, /*isElemental=*/false},
{"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
{"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
{"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false},
@@ -8109,6 +8110,90 @@ mlir::Value IntrinsicLibrary::genTand(mlir::Type resultType,
return getRuntimeCallGenerator("tan", ftype)(builder, loc, {arg});
}
+// THIS_GRID
+mlir::Value IntrinsicLibrary::genThisGrid(mlir::Type resultType,
+ llvm::ArrayRef<mlir::Value> args) {
+ assert(args.size() == 0);
+ auto recTy = mlir::cast<fir::RecordType>(resultType);
+ assert(recTy && "RecordType expepected");
+ mlir::Value res = builder.create<fir::AllocaOp>(loc, resultType);
+ mlir::Type i32Ty = builder.getI32Type();
+
+ mlir::Value threadIdX = builder.create<mlir::NVVM::ThreadIdXOp>(loc, i32Ty);
+ mlir::Value threadIdY = builder.create<mlir::NVVM::ThreadIdYOp>(loc, i32Ty);
+ mlir::Value threadIdZ = builder.create<mlir::NVVM::ThreadIdZOp>(loc, i32Ty);
+
+ mlir::Value blockIdX = builder.create<mlir::NVVM::BlockIdXOp>(loc, i32Ty);
+ mlir::Value blockIdY = builder.create<mlir::NVVM::BlockIdYOp>(loc, i32Ty);
+ mlir::Value blockIdZ = builder.create<mlir::NVVM::BlockIdZOp>(loc, i32Ty);
+
+ mlir::Value blockDimX = builder.create<mlir::NVVM::BlockDimXOp>(loc, i32Ty);
+ mlir::Value blockDimY = builder.create<mlir::NVVM::BlockDimYOp>(loc, i32Ty);
+ mlir::Value blockDimZ = builder.create<mlir::NVVM::BlockDimZOp>(loc, i32Ty);
+ mlir::Value gridDimX = builder.create<mlir::NVVM::GridDimXOp>(loc, i32Ty);
+ mlir::Value gridDimY = builder.create<mlir::NVVM::GridDimYOp>(loc, i32Ty);
+ mlir::Value gridDimZ = builder.create<mlir::NVVM::GridDimZOp>(loc, i32Ty);
+
+ // this_grid.size = ((blockDim.z * gridDim.z) * (blockDim.y * gridDim.y)) *
+ // (blockDim.x * gridDim.x);
+ mlir::Value resZ =
+ builder.create<mlir::arith::MulIOp>(loc, blockDimZ, gridDimZ);
+ mlir::Value resY =
+ builder.create<mlir::arith::MulIOp>(loc, blockDimY, gridDimY);
+ mlir::Value resX =
+ builder.create<mlir::arith::MulIOp>(loc, blockDimX, gridDimX);
+ mlir::Value resZY = builder.create<mlir::arith::MulIOp>(loc, resZ, resY);
+ mlir::Value size = builder.create<mlir::arith::MulIOp>(loc, resZY, resX);
+
+ // tmp = ((blockIdx.z * gridDim.y * gridDim.x) + (blockIdx.y * gridDim.x)) +
+ // blockIdx.x;
+ // this_group.rank = tmp * ((blockDim.x * blockDim.y) * blockDim.z) +
+ // ((threadIdx.z * blockDim.y) * blockDim.x) +
+ // (threadIdx.y * blockDim.x) + threadIdx.x + 1;
+ mlir::Value r1 = builder.create<mlir::arith::MulIOp>(loc, blockIdZ, gridDimY);
+ mlir::Value r2 = builder.create<mlir::arith::MulIOp>(loc, r1, gridDimX);
+ mlir::Value r3 = builder.create<mlir::arith::MulIOp>(loc, blockIdY, gridDimX);
+ mlir::Value r2r3 = builder.create<mlir::arith::AddIOp>(loc, r2, r3);
+ mlir::Value tmp = builder.create<mlir::arith::AddIOp>(loc, r2r3, blockIdX);
+
+ mlir::Value bXbY =
+ builder.create<mlir::arith::MulIOp>(loc, blockDimX, blockDimY);
+ mlir::Value bXbYbZ =
+ builder.create<mlir::arith::MulIOp>(loc, bXbY, blockDimZ);
+ mlir::Value tZbY =
+ builder.create<mlir::arith::MulIOp>(loc, threadIdZ, blockDimY);
+ mlir::Value tZbYbX =
+ builder.create<mlir::arith::MulIOp>(loc, tZbY, blockDimX);
+ mlir::Value tYbX =
+ builder.create<mlir::arith::MulIOp>(loc, threadIdY, blockDimX);
+ mlir::Value rank = builder.create<mlir::arith::MulIOp>(loc, tmp, bXbYbZ);
+ rank = builder.create<mlir::arith::AddIOp>(loc, rank, tZbYbX);
+ rank = builder.create<mlir::arith::AddIOp>(loc, rank, tYbX);
+ rank = builder.create<mlir::arith::AddIOp>(loc, rank, threadIdX);
+ mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
+ rank = builder.create<mlir::arith::AddIOp>(loc, rank, one);
+
+ auto sizeFieldName = recTy.getTypeList()[1].first;
+ mlir::Type sizeFieldTy = recTy.getTypeList()[1].second;
+ mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
+ mlir::Value sizeFieldIndex = builder.create<fir::FieldIndexOp>(
+ loc, fieldIndexType, sizeFieldName, recTy,
+ /*typeParams=*/mlir::ValueRange{});
+ mlir::Value sizeCoord = builder.create<fir::CoordinateOp>(
+ loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex);
+ builder.create<fir::StoreOp>(loc, size, sizeCoord);
+
+ auto rankFieldName = recTy.getTypeList()[2].first;
+ mlir::Type rankFieldTy = recTy.getTypeList()[2].second;
+ mlir::Value rankFieldIndex = builder.create<fir::FieldIndexOp>(
+ loc, fieldIndexType, rankFieldName, recTy,
+ /*typeParams=*/mlir::ValueRange{});
+ mlir::Value rankCoord = builder.create<fir::CoordinateOp>(
+ loc, builder.getRefType(rankFieldTy), res, rankFieldIndex);
+ builder.create<fir::StoreOp>(loc, rank, rankCoord);
+ return res;
+}
+
// TRAILZ
mlir::Value IntrinsicLibrary::genTrailz(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
diff --git a/flang/module/cooperative_groups.f90 b/flang/module/cooperative_groups.f90
new file mode 100644
index 0000000000000..e7d19f1c65b1a
--- /dev/null
+++ b/flang/module/cooperative_groups.f90
@@ -0,0 +1,30 @@
+!===-- module/cudedevice.f90 -----------------------------------------------===!
+!
+! Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+! See https://llvm.org/LICENSE.txt for license information.
+! SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+!
+!===------------------------------------------------------------------------===!
+
+! CUDA Fortran cooperative groups
+
+module cooperative_groups
+
+use, intrinsic :: __fortran_builtins, only: c_devptr => __builtin_c_devptr
+
+implicit none
+
+type :: grid_group
+ type(c_devptr), private :: handle
+ integer(4) :: size
+ integer(4) :: rank
+end type grid_group
+
+interface
+ attributes(device) function this_grid()
+ import
+ type(grid_group) :: this_grid
+ end function
+end interface
+
+end module
diff --git a/flang/test/Lower/CUDA/cuda-cooperative.cuf b/flang/test/Lower/CUDA/cuda-cooperative.cuf
new file mode 100644
index 0000000000000..d3deb8f318664
--- /dev/null
+++ b/flang/test/Lower/CUDA/cuda-cooperative.cuf
@@ -0,0 +1,53 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+! Test CUDA Fortran procedures available in cooperative_groups module.
+
+attributes(grid_global) subroutine g1()
+ use cooperative_groups
+ type(grid_group) :: gg
+ gg = this_grid()
+end subroutine
+
+! CHECK: %{{.*}} = fir.alloca !fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
+! CHECK: %[[RES:.*]] = fir.alloca !fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
+! CHECK: %[[THREAD_ID_X:.*]] = nvvm.read.ptx.sreg.tid.x : i32
+! CHECK: %[[THREAD_ID_Y:.*]] = nvvm.read.ptx.sreg.tid.y : i32
+! CHECK: %[[THREAD_ID_Z:.*]] = nvvm.read.ptx.sreg.tid.z : i32
+! CHECK: %[[BLOCK_ID_X:.*]] = nvvm.read.ptx.sreg.ctaid.x : i32
+! CHECK: %[[BLOCK_ID_Y:.*]] = nvvm.read.ptx.sreg.ctaid.y : i32
+! CHECK: %[[BLOCK_ID_Z:.*]] = nvvm.read.ptx.sreg.ctaid.z : i32
+! CHECK: %[[BLOCK_DIM_X:.*]] = nvvm.read.ptx.sreg.ntid.x : i32
+! CHECK: %[[BLOCK_DIM_Y:.*]] = nvvm.read.ptx.sreg.ntid.y : i32
+! CHECK: %[[BLOCK_DIM_Z:.*]] = nvvm.read.ptx.sreg.ntid.z : i32
+! CHECK: %[[GRID_DIM_X:.*]] = nvvm.read.ptx.sreg.nctaid.x : i32
+! CHECK: %[[GRID_DIM_Y:.*]] = nvvm.read.ptx.sreg.nctaid.y : i32
+! CHECK: %[[GRID_DIM_Z:.*]] = nvvm.read.ptx.sreg.nctaid.z : i32
+
+! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_DIM_Z]], %[[GRID_DIM_Z]] : i32
+! CHECK: %[[R2:.*]] = arith.muli %[[BLOCK_DIM_Y]], %[[GRID_DIM_Y]] : i32
+! CHECK: %[[R3:.*]] = arith.muli %[[BLOCK_DIM_X]], %[[GRID_DIM_X]] : i32
+! CHECK: %[[R4:.*]] = arith.muli %[[R1]], %[[R2]] : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[R4]], %[[R3]] : i32
+
+! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_ID_Z]], %[[GRID_DIM_Y]] : i32
+! CHECK: %[[R2:.*]] = arith.muli %[[R1]], %[[GRID_DIM_X]] : i32
+! CHECK: %[[R3:.*]] = arith.muli %[[BLOCK_ID_Y]], %[[GRID_DIM_X]] : i32
+! CHECK: %[[R4:.*]] = arith.addi %[[R2]], %[[R3]] : i32
+! CHECK: %[[TMP:.*]] = arith.addi %[[R4]], %[[BLOCK_ID_X]] : i32
+
+! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_DIM_X]], %[[BLOCK_DIM_Y]] : i32
+! CHECK: %[[R2:.*]] = arith.muli %[[R1]], %[[BLOCK_DIM_Z]] : i32
+! CHECK: %[[R3:.*]] = arith.muli %[[THREAD_ID_Z]], %[[BLOCK_DIM_Y]] : i32
+! CHECK: %[[R4:.*]] = arith.muli %[[R3]], %[[BLOCK_DIM_X]] : i32
+! CHECK: %[[R5:.*]] = arith.muli %[[THREAD_ID_Y]], %[[BLOCK_DIM_X]] : i32
+! CHECK: %[[RES0:.*]] = arith.muli %[[TMP]], %[[R2]] : i32
+! CHECK: %[[RES1:.*]] = arith.addi %[[RES0]], %[[R4]] : i32
+! CHECK: %[[RES2:.*]] = arith.addi %[[RES1]], %[[R5]] : i32
+! CHECK: %[[RES3:.*]] = arith.addi %[[RES2]], %[[THREAD_ID_X]] : i32
+! CHECK: %[[ONE:.*]] = arith.constant 1 : i32
+! CHECK: %[[RANK:.*]] = arith.addi %[[RES3]], %[[ONE]] : i32
+! CHECK: %[[COORD_SIZE:.*]] = fir.coordinate_of %[[RES]], size : (!fir.ref<!fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
+! CHECK: fir.store %[[SIZE]] to %[[COORD_SIZE]] : !fir.ref<i32>
+! CHECK: %[[COORD_RANK:.*]] = fir.coordinate_of %[[RES]], rank : (!fir.ref<!fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
+! CHECK: fir.store %[[RANK]] to %[[COORD_RANK]] : !fir.ref<i32>
+
diff --git a/flang/tools/f18/CMakeLists.txt b/flang/tools/f18/CMakeLists.txt
index fb5510d7163d1..95fe9ad87ac45 100644
--- a/flang/tools/f18/CMakeLists.txt
+++ b/flang/tools/f18/CMakeLists.txt
@@ -15,6 +15,7 @@ set(MODULES
"mma"
"__cuda_builtins"
"__cuda_device"
+ "cooperative_groups"
"cudadevice"
"ieee_arithmetic"
"ieee_exceptions"
@@ -60,12 +61,15 @@ if (NOT CMAKE_CROSSCOMPILING)
elseif(${filename} STREQUAL "__ppc_intrinsics" OR
${filename} STREQUAL "mma")
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__ppc_types.mod)
- elseif(${filename} STREQUAL "__cuda_device")
+ elseif(${filename} STREQUAL "__cuda_device" OR
+ ${filename} STREQUAL "cudadevice" OR
+ ${filename} STREQUAL "cooperative_groups")
set(opts -fc1 -xcuda)
- set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod)
- elseif(${filename} STREQUAL "cudadevice")
- set(opts -fc1 -xcuda)
- set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_device.mod)
+ if(${filename} STREQUAL "__cuda_device")
+ set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod)
+ elseif(${filename} STREQUAL "cudadevice")
+ set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_device.mod)
+ endif()
else()
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__fortran_builtins.mod)
if(${filename} STREQUAL "iso_fortran_env")
|
|
LGTM |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/130/builds/13947 Here is the relevant piece of the build log for the reference |
No description provided.