Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ struct CUDAIntrinsicLibrary : IntrinsicLibrary {
void genBarrierInit(llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genBarrierTryWait(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genBarrierTryWaitSleep(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genClusterBlockIndex(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genClusterDimBlocks(mlir::Type, llvm::ArrayRef<mlir::Value>);
void genFenceProxyAsync(llvm::ArrayRef<fir::ExtendedValue>);
template <const char *fctName, int extent>
Expand Down
61 changes: 44 additions & 17 deletions flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,11 @@ static constexpr IntrinsicHandler cudaHandlers[]{
&CI::genNVVMTime<mlir::NVVM::Clock64Op>),
{},
/*isElemental=*/false},
{"cluster_block_index",
static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
&CI::genClusterBlockIndex),
{},
/*isElemental=*/false},
{"cluster_dim_blocks",
static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
&CI::genClusterDimBlocks),
Expand Down Expand Up @@ -990,6 +995,42 @@ CUDAIntrinsicLibrary::genBarrierTryWaitSleep(mlir::Type resultType,
.getResult(0);
}

static void insertValueAtPos(fir::FirOpBuilder &builder, mlir::Location loc,
fir::RecordType recTy, mlir::Value base,
mlir::Value dim, unsigned fieldPos) {
auto fieldName = recTy.getTypeList()[fieldPos].first;
mlir::Type fieldTy = recTy.getTypeList()[fieldPos].second;
mlir::Type fieldIndexType = fir::FieldType::get(base.getContext());
mlir::Value fieldIndex =
fir::FieldIndexOp::create(builder, loc, fieldIndexType, fieldName, recTy,
/*typeParams=*/mlir::ValueRange{});
mlir::Value coord = fir::CoordinateOp::create(
builder, loc, builder.getRefType(fieldTy), base, fieldIndex);
fir::StoreOp::create(builder, loc, dim, coord);
}

// CLUSTER_BLOCK_INDEX
mlir::Value
CUDAIntrinsicLibrary::genClusterBlockIndex(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
assert(args.size() == 0);
auto recTy = mlir::cast<fir::RecordType>(resultType);
assert(recTy && "RecordType expepected");
mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
mlir::Type i32Ty = builder.getI32Type();
mlir::Value x = mlir::NVVM::BlockInClusterIdXOp::create(builder, loc, i32Ty);
mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
x = mlir::arith::AddIOp::create(builder, loc, x, one);
insertValueAtPos(builder, loc, recTy, res, x, 0);
mlir::Value y = mlir::NVVM::BlockInClusterIdYOp::create(builder, loc, i32Ty);
y = mlir::arith::AddIOp::create(builder, loc, y, one);
insertValueAtPos(builder, loc, recTy, res, y, 1);
mlir::Value z = mlir::NVVM::BlockInClusterIdZOp::create(builder, loc, i32Ty);
z = mlir::arith::AddIOp::create(builder, loc, z, one);
insertValueAtPos(builder, loc, recTy, res, z, 2);
return res;
}

// CLUSTER_DIM_BLOCKS
mlir::Value
CUDAIntrinsicLibrary::genClusterDimBlocks(mlir::Type resultType,
Expand All @@ -998,27 +1039,13 @@ CUDAIntrinsicLibrary::genClusterDimBlocks(mlir::Type resultType,
auto recTy = mlir::cast<fir::RecordType>(resultType);
assert(recTy && "RecordType expepected");
mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);

auto insertDim = [&](mlir::Value dim, unsigned fieldPos) {
auto fieldName = recTy.getTypeList()[fieldPos].first;
mlir::Type fieldTy = recTy.getTypeList()[fieldPos].second;
mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
mlir::Value fieldIndex = fir::FieldIndexOp::create(
builder, loc, fieldIndexType, fieldName, recTy,
/*typeParams=*/mlir::ValueRange{});
mlir::Value coord = fir::CoordinateOp::create(
builder, loc, builder.getRefType(fieldTy), res, fieldIndex);
fir::StoreOp::create(builder, loc, dim, coord);
};

mlir::Type i32Ty = builder.getI32Type();
mlir::Value x = mlir::NVVM::ClusterDimBlocksXOp::create(builder, loc, i32Ty);
insertDim(x, 0);
insertValueAtPos(builder, loc, recTy, res, x, 0);
mlir::Value y = mlir::NVVM::ClusterDimBlocksYOp::create(builder, loc, i32Ty);
insertDim(y, 1);
insertValueAtPos(builder, loc, recTy, res, y, 1);
mlir::Value z = mlir::NVVM::ClusterDimBlocksZOp::create(builder, loc, i32Ty);
insertDim(z, 2);

insertValueAtPos(builder, loc, recTy, res, z, 2);
return res;
}

Expand Down
7 changes: 7 additions & 0 deletions flang/module/cooperative_groups.f90
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ module cooperative_groups
integer(4) :: rank
end type thread_group

interface
attributes(device) function cluster_block_index()
import
type(dim3) :: cluster_block_index
end function
end interface

interface
attributes(device) function cluster_dim_blocks()
import
Expand Down
21 changes: 21 additions & 0 deletions flang/test/Lower/CUDA/cuda-cluster.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,24 @@ end subroutine
! CHECK: %[[Z:.*]] = nvvm.read.ptx.sreg.cluster.nctaid.z : i32
! CHECK: %[[COORD_Z:.*]] = fir.coordinate_of %{{.*}}, z : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
! CHECK: fir.store %[[Z]] to %[[COORD_Z]] : !fir.ref<i32>

attributes(global) subroutine test_cluster_block_index()
use cooperative_groups
type(dim3) :: blockIndex

blockIndex = cluster_block_index()
end subroutine

! CHECK-LABEL: func.func @_QPtest_cluster_block_index() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
! CHECK: %[[X:.*]] = nvvm.read.ptx.sreg.cluster.ctaid.x : i32
! CHECK: %[[X1:.*]] = arith.addi %[[X]], %c1{{.*}} : i32
! CHECK: %[[COORD_X:.*]] = fir.coordinate_of %{{.*}}, x : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
! CHECK: fir.store %[[X1]] to %[[COORD_X]] : !fir.ref<i32>
! CHECK: %[[Y:.*]] = nvvm.read.ptx.sreg.cluster.ctaid.y : i32
! CHECK: %[[Y1:.*]] = arith.addi %[[Y]], %c1{{.*}} : i32
! CHECK: %[[COORD_Y:.*]] = fir.coordinate_of %{{.*}}, y : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
! CHECK: fir.store %[[Y1]] to %[[COORD_Y]] : !fir.ref<i32>
! CHECK: %[[Z:.*]] = nvvm.read.ptx.sreg.cluster.ctaid.z : i32
! CHECK: %[[Z1:.*]] = arith.addi %[[Z]], %c1{{.*}} : i32
! CHECK: %[[COORD_Z:.*]] = fir.coordinate_of %{{.*}}, z : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
! CHECK: fir.store %[[Z1]] to %[[COORD_Z]] : !fir.ref<i32>