Skip to content

Commit 4bd6560

Browse files
clementvalaadeshps-mcw
authored andcommitted
[flang][cuda] Add support for cluster_dim_blocks in cooperative_groups (llvm#169417)
1 parent 13af52c commit 4bd6560

File tree

4 files changed

+79
-0
lines changed

4 files changed

+79
-0
lines changed

flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ struct CUDAIntrinsicLibrary : IntrinsicLibrary {
4747
void genBarrierInit(llvm::ArrayRef<fir::ExtendedValue>);
4848
mlir::Value genBarrierTryWait(mlir::Type, llvm::ArrayRef<mlir::Value>);
4949
mlir::Value genBarrierTryWaitSleep(mlir::Type, llvm::ArrayRef<mlir::Value>);
50+
mlir::Value genClusterDimBlocks(mlir::Type, llvm::ArrayRef<mlir::Value>);
5051
void genFenceProxyAsync(llvm::ArrayRef<fir::ExtendedValue>);
5152
template <const char *fctName, int extent>
5253
fir::ExtendedValue genLDXXFunc(mlir::Type,

flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,11 @@ static constexpr IntrinsicHandler cudaHandlers[]{
368368
&CI::genNVVMTime<mlir::NVVM::Clock64Op>),
369369
{},
370370
/*isElemental=*/false},
371+
{"cluster_dim_blocks",
372+
static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
373+
&CI::genClusterDimBlocks),
374+
{},
375+
/*isElemental=*/false},
371376
{"fence_proxy_async",
372377
static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
373378
&CI::genFenceProxyAsync),
@@ -985,6 +990,38 @@ CUDAIntrinsicLibrary::genBarrierTryWaitSleep(mlir::Type resultType,
985990
.getResult(0);
986991
}
987992

993+
// CLUSTER_DIM_BLOCKS
994+
mlir::Value
995+
CUDAIntrinsicLibrary::genClusterDimBlocks(mlir::Type resultType,
996+
llvm::ArrayRef<mlir::Value> args) {
997+
assert(args.size() == 0);
998+
auto recTy = mlir::cast<fir::RecordType>(resultType);
999+
assert(recTy && "RecordType expepected");
1000+
mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
1001+
1002+
auto insertDim = [&](mlir::Value dim, unsigned fieldPos) {
1003+
auto fieldName = recTy.getTypeList()[fieldPos].first;
1004+
mlir::Type fieldTy = recTy.getTypeList()[fieldPos].second;
1005+
mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
1006+
mlir::Value fieldIndex = fir::FieldIndexOp::create(
1007+
builder, loc, fieldIndexType, fieldName, recTy,
1008+
/*typeParams=*/mlir::ValueRange{});
1009+
mlir::Value coord = fir::CoordinateOp::create(
1010+
builder, loc, builder.getRefType(fieldTy), res, fieldIndex);
1011+
fir::StoreOp::create(builder, loc, dim, coord);
1012+
};
1013+
1014+
mlir::Type i32Ty = builder.getI32Type();
1015+
mlir::Value x = mlir::NVVM::ClusterDimBlocksXOp::create(builder, loc, i32Ty);
1016+
insertDim(x, 0);
1017+
mlir::Value y = mlir::NVVM::ClusterDimBlocksYOp::create(builder, loc, i32Ty);
1018+
insertDim(y, 1);
1019+
mlir::Value z = mlir::NVVM::ClusterDimBlocksZOp::create(builder, loc, i32Ty);
1020+
insertDim(z, 2);
1021+
1022+
return res;
1023+
}
1024+
9881025
// FENCE_PROXY_ASYNC
9891026
void CUDAIntrinsicLibrary::genFenceProxyAsync(
9901027
llvm::ArrayRef<fir::ExtendedValue> args) {

flang/module/cooperative_groups.f90

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ module cooperative_groups
3838
integer(4) :: rank
3939
end type thread_group
4040

41+
interface
42+
attributes(device) function cluster_dim_blocks()
43+
import
44+
type(dim3) :: cluster_dim_blocks
45+
end function
46+
end interface
47+
4148
interface
4249
attributes(device) function this_cluster()
4350
import
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
3+
attributes(global) subroutine test_this_cluster()
4+
use cooperative_groups
5+
type(cluster_group) :: cluster
6+
7+
cluster = this_cluster()
8+
end subroutine
9+
10+
! CHECK-LABEL: func.func @_QPtest_this_cluster() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
11+
! CHECK: %{{.*}} = fir.alloca !fir.type<_QMcooperative_groupsTcluster_group
12+
! CHECK: %[[RES:.*]] = fir.alloca !fir.type<_QMcooperative_groupsTcluster_group{_QMcooperative_groupsTcluster_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
13+
! CHECK: %[[RANK:.*]] = nvvm.read.ptx.sreg.cluster.ctarank : i32
14+
! CHECK: %[[RANK_1:.*]] = arith.addi %[[RANK]], %c1{{.*}} : i32
15+
! CHECK: %[[RANK_COORD:.*]] = fir.coordinate_of %[[RES]], rank : (!fir.ref<!fir.type<_QMcooperative_groupsTcluster_group{_QMcooperative_groupsTcluster_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
16+
! CHECK: fir.store %[[RANK_1]] to %[[RANK_COORD]] : !fir.ref<i32>
17+
18+
attributes(global) subroutine test_cluster_dim_blocks()
19+
use cooperative_groups
20+
type(dim3) :: clusterDim
21+
22+
clusterDim = cluster_dim_blocks()
23+
end subroutine
24+
25+
! CHECK-LABEL: func.func @_QPtest_cluster_dim_blocks() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
26+
! CHECK: %[[X:.*]] = nvvm.read.ptx.sreg.cluster.nctaid.x : i32
27+
! CHECK: %[[COORD_X:.*]] = fir.coordinate_of %{{.*}}, x : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
28+
! CHECK: fir.store %[[X]] to %[[COORD_X]] : !fir.ref<i32>
29+
! CHECK: %[[Y:.*]] = nvvm.read.ptx.sreg.cluster.nctaid.y : i32
30+
! CHECK: %[[COORD_Y:.*]] = fir.coordinate_of %{{.*}}, y : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
31+
! CHECK: fir.store %[[Y]] to %[[COORD_Y]] : !fir.ref<i32>
32+
! CHECK: %[[Z:.*]] = nvvm.read.ptx.sreg.cluster.nctaid.z : i32
33+
! CHECK: %[[COORD_Z:.*]] = fir.coordinate_of %{{.*}}, z : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
34+
! CHECK: fir.store %[[Z]] to %[[COORD_Z]] : !fir.ref<i32>

0 commit comments

Comments
 (0)