Skip to content

Commit ab5ae9a

Browse files
authored
[flang][cuda] Implement this_cluster for cooperative groups (#169414)
Implement `this_cluster` in the same way as `this_group`, by lowering it directly as an intrinsic function. Use NVVM operations to obtain the rank and size information and populate the corresponding components of the derived type.
1 parent a508249 commit ab5ae9a

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ struct CUDAIntrinsicLibrary : IntrinsicLibrary {
6060
mlir::Value genSyncThreadsCount(mlir::Type, llvm::ArrayRef<mlir::Value>);
6161
mlir::Value genSyncThreadsOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
6262
void genSyncWarp(llvm::ArrayRef<fir::ExtendedValue>);
63+
mlir::Value genThisCluster(mlir::Type, llvm::ArrayRef<mlir::Value>);
6364
mlir::Value genThisGrid(mlir::Type, llvm::ArrayRef<mlir::Value>);
6465
mlir::Value genThisThreadBlock(mlir::Type, llvm::ArrayRef<mlir::Value>);
6566
mlir::Value genThisWarp(mlir::Type, llvm::ArrayRef<mlir::Value>);

flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,10 @@ static constexpr IntrinsicHandler cudaHandlers[]{
457457
static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(&CI::genSyncWarp),
458458
{},
459459
/*isElemental=*/false},
460+
{"this_cluster",
461+
static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genThisCluster),
462+
{},
463+
/*isElemental=*/false},
460464
{"this_grid",
461465
static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genThisGrid),
462466
{},
@@ -1122,6 +1126,44 @@ void CUDAIntrinsicLibrary::genSyncWarp(
11221126
mlir::NVVM::SyncWarpOp::create(builder, loc, fir::getBase(args[0]));
11231127
}
11241128

1129+
// THIS_CLUSTER
1130+
mlir::Value
1131+
CUDAIntrinsicLibrary::genThisCluster(mlir::Type resultType,
1132+
llvm::ArrayRef<mlir::Value> args) {
1133+
assert(args.size() == 0);
1134+
auto recTy = mlir::cast<fir::RecordType>(resultType);
1135+
assert(recTy && "RecordType expepected");
1136+
mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
1137+
mlir::Type i32Ty = builder.getI32Type();
1138+
1139+
// SIZE
1140+
mlir::Value size = mlir::NVVM::ClusterDim::create(builder, loc, i32Ty);
1141+
auto sizeFieldName = recTy.getTypeList()[1].first;
1142+
mlir::Type sizeFieldTy = recTy.getTypeList()[1].second;
1143+
mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
1144+
mlir::Value sizeFieldIndex = fir::FieldIndexOp::create(
1145+
builder, loc, fieldIndexType, sizeFieldName, recTy,
1146+
/*typeParams=*/mlir::ValueRange{});
1147+
mlir::Value sizeCoord = fir::CoordinateOp::create(
1148+
builder, loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex);
1149+
fir::StoreOp::create(builder, loc, size, sizeCoord);
1150+
1151+
// RANK
1152+
mlir::Value rank = mlir::NVVM::ClusterId::create(builder, loc, i32Ty);
1153+
mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
1154+
rank = mlir::arith::AddIOp::create(builder, loc, rank, one);
1155+
auto rankFieldName = recTy.getTypeList()[2].first;
1156+
mlir::Type rankFieldTy = recTy.getTypeList()[2].second;
1157+
mlir::Value rankFieldIndex = fir::FieldIndexOp::create(
1158+
builder, loc, fieldIndexType, rankFieldName, recTy,
1159+
/*typeParams=*/mlir::ValueRange{});
1160+
mlir::Value rankCoord = fir::CoordinateOp::create(
1161+
builder, loc, builder.getRefType(rankFieldTy), res, rankFieldIndex);
1162+
fir::StoreOp::create(builder, loc, rank, rankCoord);
1163+
1164+
return res;
1165+
}
1166+
11251167
// THIS_GRID
11261168
mlir::Value
11271169
CUDAIntrinsicLibrary::genThisGrid(mlir::Type resultType,

flang/module/cooperative_groups.f90

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ module cooperative_groups
1414

1515
implicit none
1616

17+
! Derived type that this_cluster() returns.
! The component order is significant: the compiler lowering of this_cluster
! addresses `size` and `rank` by position (second and third components).
type :: cluster_group
18+
type(c_devptr), private :: handle
19+
integer(4) :: size
20+
integer(4) :: rank
21+
end type cluster_group
22+
1723
type :: grid_group
1824
type(c_devptr), private :: handle
1925
integer(4) :: size
@@ -32,6 +38,13 @@ module cooperative_groups
3238
integer(4) :: rank
3339
end type thread_group
3440

41+
! Explicit interface for the device function this_cluster(), which takes no
! arguments and returns a cluster_group value.
interface
42+
attributes(device) function this_cluster()
43+
import
44+
type(cluster_group) :: this_cluster
45+
end function
46+
end interface
47+
3548
interface
3649
attributes(device) function this_grid()
3750
import

0 commit comments

Comments
 (0)