Skip to content

Commit 18aeaee

Browse files
clementval authored and aadeshps-mcw committed
[flang][cuda] Add support for cluster_block_index in cooperative groups (llvm#169427)
1 parent 5df68ca commit 18aeaee

File tree

4 files changed

+73
-17
lines changed

4 files changed

+73
-17
lines changed

flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ struct CUDAIntrinsicLibrary : IntrinsicLibrary {
4747
void genBarrierInit(llvm::ArrayRef<fir::ExtendedValue>);
4848
mlir::Value genBarrierTryWait(mlir::Type, llvm::ArrayRef<mlir::Value>);
4949
mlir::Value genBarrierTryWaitSleep(mlir::Type, llvm::ArrayRef<mlir::Value>);
50+
mlir::Value genClusterBlockIndex(mlir::Type, llvm::ArrayRef<mlir::Value>);
5051
mlir::Value genClusterDimBlocks(mlir::Type, llvm::ArrayRef<mlir::Value>);
5152
void genFenceProxyAsync(llvm::ArrayRef<fir::ExtendedValue>);
5253
template <const char *fctName, int extent>

flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,11 @@ static constexpr IntrinsicHandler cudaHandlers[]{
368368
&CI::genNVVMTime<mlir::NVVM::Clock64Op>),
369369
{},
370370
/*isElemental=*/false},
371+
{"cluster_block_index",
372+
static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
373+
&CI::genClusterBlockIndex),
374+
{},
375+
/*isElemental=*/false},
371376
{"cluster_dim_blocks",
372377
static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
373378
&CI::genClusterDimBlocks),
@@ -990,6 +995,42 @@ CUDAIntrinsicLibrary::genBarrierTryWaitSleep(mlir::Type resultType,
990995
.getResult(0);
991996
}
992997

998+
/// Store \p dim into one component of the derived-type object \p base.
///
/// The component is selected by \p fieldPos, its position in the type list of
/// \p recTy. A FieldIndexOp naming the component is created, a CoordinateOp on
/// \p base yields a reference to it, and a StoreOp writes \p dim through that
/// reference.
static void insertValueAtPos(fir::FirOpBuilder &builder, mlir::Location loc,
                             fir::RecordType recTy, mlir::Value base,
                             mlir::Value dim, unsigned fieldPos) {
  // Each type-list entry is a (component name, component type) pair.
  const auto &component = recTy.getTypeList()[fieldPos];
  mlir::Type componentTy = component.second;
  mlir::Value fieldIndex = fir::FieldIndexOp::create(
      builder, loc, fir::FieldType::get(base.getContext()), component.first,
      recTy, /*typeParams=*/mlir::ValueRange{});
  mlir::Value componentAddr = fir::CoordinateOp::create(
      builder, loc, builder.getRefType(componentTy), base, fieldIndex);
  fir::StoreOp::create(builder, loc, dim, componentAddr);
}
1011+
1012+
// CLUSTER_BLOCK_INDEX
1013+
mlir::Value
1014+
CUDAIntrinsicLibrary::genClusterBlockIndex(mlir::Type resultType,
1015+
llvm::ArrayRef<mlir::Value> args) {
1016+
assert(args.size() == 0);
1017+
auto recTy = mlir::cast<fir::RecordType>(resultType);
1018+
assert(recTy && "RecordType expepected");
1019+
mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
1020+
mlir::Type i32Ty = builder.getI32Type();
1021+
mlir::Value x = mlir::NVVM::BlockInClusterIdXOp::create(builder, loc, i32Ty);
1022+
mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
1023+
x = mlir::arith::AddIOp::create(builder, loc, x, one);
1024+
insertValueAtPos(builder, loc, recTy, res, x, 0);
1025+
mlir::Value y = mlir::NVVM::BlockInClusterIdYOp::create(builder, loc, i32Ty);
1026+
y = mlir::arith::AddIOp::create(builder, loc, y, one);
1027+
insertValueAtPos(builder, loc, recTy, res, y, 1);
1028+
mlir::Value z = mlir::NVVM::BlockInClusterIdZOp::create(builder, loc, i32Ty);
1029+
z = mlir::arith::AddIOp::create(builder, loc, z, one);
1030+
insertValueAtPos(builder, loc, recTy, res, z, 2);
1031+
return res;
1032+
}
1033+
9931034
// CLUSTER_DIM_BLOCKS
9941035
mlir::Value
9951036
CUDAIntrinsicLibrary::genClusterDimBlocks(mlir::Type resultType,
@@ -998,27 +1039,13 @@ CUDAIntrinsicLibrary::genClusterDimBlocks(mlir::Type resultType,
9981039
auto recTy = mlir::cast<fir::RecordType>(resultType);
9991040
assert(recTy && "RecordType expepected");
10001041
mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
1001-
1002-
auto insertDim = [&](mlir::Value dim, unsigned fieldPos) {
1003-
auto fieldName = recTy.getTypeList()[fieldPos].first;
1004-
mlir::Type fieldTy = recTy.getTypeList()[fieldPos].second;
1005-
mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
1006-
mlir::Value fieldIndex = fir::FieldIndexOp::create(
1007-
builder, loc, fieldIndexType, fieldName, recTy,
1008-
/*typeParams=*/mlir::ValueRange{});
1009-
mlir::Value coord = fir::CoordinateOp::create(
1010-
builder, loc, builder.getRefType(fieldTy), res, fieldIndex);
1011-
fir::StoreOp::create(builder, loc, dim, coord);
1012-
};
1013-
10141042
mlir::Type i32Ty = builder.getI32Type();
10151043
mlir::Value x = mlir::NVVM::ClusterDimBlocksXOp::create(builder, loc, i32Ty);
1016-
insertDim(x, 0);
1044+
insertValueAtPos(builder, loc, recTy, res, x, 0);
10171045
mlir::Value y = mlir::NVVM::ClusterDimBlocksYOp::create(builder, loc, i32Ty);
1018-
insertDim(y, 1);
1046+
insertValueAtPos(builder, loc, recTy, res, y, 1);
10191047
mlir::Value z = mlir::NVVM::ClusterDimBlocksZOp::create(builder, loc, i32Ty);
1020-
insertDim(z, 2);
1021-
1048+
insertValueAtPos(builder, loc, recTy, res, z, 2);
10221049
return res;
10231050
}
10241051

flang/module/cooperative_groups.f90

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ module cooperative_groups
3838
integer(4) :: rank
3939
end type thread_group
4040

41+
interface
42+
attributes(device) function cluster_block_index()
43+
import
44+
type(dim3) :: cluster_block_index
45+
end function
46+
end interface
47+
4148
interface
4249
attributes(device) function cluster_dim_blocks()
4350
import

flang/test/Lower/CUDA/cuda-cluster.cuf

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,24 @@ end subroutine
3232
! CHECK: %[[Z:.*]] = nvvm.read.ptx.sreg.cluster.nctaid.z : i32
3333
! CHECK: %[[COORD_Z:.*]] = fir.coordinate_of %{{.*}}, z : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
3434
! CHECK: fir.store %[[Z]] to %[[COORD_Z]] : !fir.ref<i32>
35+
36+
attributes(global) subroutine test_cluster_block_index()
37+
use cooperative_groups
38+
type(dim3) :: blockIndex
39+
40+
blockIndex = cluster_block_index()
41+
end subroutine
42+
43+
! CHECK-LABEL: func.func @_QPtest_cluster_block_index() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
44+
! CHECK: %[[X:.*]] = nvvm.read.ptx.sreg.cluster.ctaid.x : i32
45+
! CHECK: %[[X1:.*]] = arith.addi %[[X]], %c1{{.*}} : i32
46+
! CHECK: %[[COORD_X:.*]] = fir.coordinate_of %{{.*}}, x : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
47+
! CHECK: fir.store %[[X1]] to %[[COORD_X]] : !fir.ref<i32>
48+
! CHECK: %[[Y:.*]] = nvvm.read.ptx.sreg.cluster.ctaid.y : i32
49+
! CHECK: %[[Y1:.*]] = arith.addi %[[Y]], %c1{{.*}} : i32
50+
! CHECK: %[[COORD_Y:.*]] = fir.coordinate_of %{{.*}}, y : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
51+
! CHECK: fir.store %[[Y1]] to %[[COORD_Y]] : !fir.ref<i32>
52+
! CHECK: %[[Z:.*]] = nvvm.read.ptx.sreg.cluster.ctaid.z : i32
53+
! CHECK: %[[Z1:.*]] = arith.addi %[[Z]], %c1{{.*}} : i32
54+
! CHECK: %[[COORD_Z:.*]] = fir.coordinate_of %{{.*}}, z : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
55+
! CHECK: fir.store %[[Z1]] to %[[COORD_Z]] : !fir.ref<i32>

0 commit comments

Comments
 (0)