@@ -368,6 +368,11 @@ static constexpr IntrinsicHandler cudaHandlers[]{
368368 &CI::genNVVMTime<mlir::NVVM::Clock64Op>),
369369 {},
370370 /* isElemental=*/ false },
371+ {" cluster_block_index" ,
372+ static_cast <CUDAIntrinsicLibrary::ElementalGenerator>(
373+ &CI::genClusterBlockIndex),
374+ {},
375+ /* isElemental=*/ false },
371376 {" cluster_dim_blocks" ,
372377 static_cast <CUDAIntrinsicLibrary::ElementalGenerator>(
373378 &CI::genClusterDimBlocks),
@@ -990,6 +995,42 @@ CUDAIntrinsicLibrary::genBarrierTryWaitSleep(mlir::Type resultType,
990995 .getResult (0 );
991996}
992997
998+ static void insertValueAtPos (fir::FirOpBuilder &builder, mlir::Location loc,
999+ fir::RecordType recTy, mlir::Value base,
1000+ mlir::Value dim, unsigned fieldPos) {
1001+ auto fieldName = recTy.getTypeList ()[fieldPos].first ;
1002+ mlir::Type fieldTy = recTy.getTypeList ()[fieldPos].second ;
1003+ mlir::Type fieldIndexType = fir::FieldType::get (base.getContext ());
1004+ mlir::Value fieldIndex =
1005+ fir::FieldIndexOp::create (builder, loc, fieldIndexType, fieldName, recTy,
1006+ /* typeParams=*/ mlir::ValueRange{});
1007+ mlir::Value coord = fir::CoordinateOp::create (
1008+ builder, loc, builder.getRefType (fieldTy), base, fieldIndex);
1009+ fir::StoreOp::create (builder, loc, dim, coord);
1010+ }
1011+
1012+ // CLUSTER_BLOCK_INDEX
1013+ mlir::Value
1014+ CUDAIntrinsicLibrary::genClusterBlockIndex (mlir::Type resultType,
1015+ llvm::ArrayRef<mlir::Value> args) {
1016+ assert (args.size () == 0 );
1017+ auto recTy = mlir::cast<fir::RecordType>(resultType);
1018+ assert (recTy && " RecordType expepected" );
1019+ mlir::Value res = fir::AllocaOp::create (builder, loc, resultType);
1020+ mlir::Type i32Ty = builder.getI32Type ();
1021+ mlir::Value x = mlir::NVVM::BlockInClusterIdXOp::create (builder, loc, i32Ty);
1022+ mlir::Value one = builder.createIntegerConstant (loc, i32Ty, 1 );
1023+ x = mlir::arith::AddIOp::create (builder, loc, x, one);
1024+ insertValueAtPos (builder, loc, recTy, res, x, 0 );
1025+ mlir::Value y = mlir::NVVM::BlockInClusterIdYOp::create (builder, loc, i32Ty);
1026+ y = mlir::arith::AddIOp::create (builder, loc, y, one);
1027+ insertValueAtPos (builder, loc, recTy, res, y, 1 );
1028+ mlir::Value z = mlir::NVVM::BlockInClusterIdZOp::create (builder, loc, i32Ty);
1029+ z = mlir::arith::AddIOp::create (builder, loc, z, one);
1030+ insertValueAtPos (builder, loc, recTy, res, z, 2 );
1031+ return res;
1032+ }
1033+
9931034// CLUSTER_DIM_BLOCKS
9941035mlir::Value
9951036CUDAIntrinsicLibrary::genClusterDimBlocks (mlir::Type resultType,
@@ -998,27 +1039,13 @@ CUDAIntrinsicLibrary::genClusterDimBlocks(mlir::Type resultType,
9981039 auto recTy = mlir::cast<fir::RecordType>(resultType);
9991040 assert (recTy && " RecordType expepected" );
10001041 mlir::Value res = fir::AllocaOp::create (builder, loc, resultType);
1001-
1002- auto insertDim = [&](mlir::Value dim, unsigned fieldPos) {
1003- auto fieldName = recTy.getTypeList ()[fieldPos].first ;
1004- mlir::Type fieldTy = recTy.getTypeList ()[fieldPos].second ;
1005- mlir::Type fieldIndexType = fir::FieldType::get (resultType.getContext ());
1006- mlir::Value fieldIndex = fir::FieldIndexOp::create (
1007- builder, loc, fieldIndexType, fieldName, recTy,
1008- /* typeParams=*/ mlir::ValueRange{});
1009- mlir::Value coord = fir::CoordinateOp::create (
1010- builder, loc, builder.getRefType (fieldTy), res, fieldIndex);
1011- fir::StoreOp::create (builder, loc, dim, coord);
1012- };
1013-
10141042 mlir::Type i32Ty = builder.getI32Type ();
10151043 mlir::Value x = mlir::NVVM::ClusterDimBlocksXOp::create (builder, loc, i32Ty);
1016- insertDim ( x, 0 );
1044+ insertValueAtPos (builder, loc, recTy, res, x, 0 );
10171045 mlir::Value y = mlir::NVVM::ClusterDimBlocksYOp::create (builder, loc, i32Ty);
1018- insertDim ( y, 1 );
1046+ insertValueAtPos (builder, loc, recTy, res, y, 1 );
10191047 mlir::Value z = mlir::NVVM::ClusterDimBlocksZOp::create (builder, loc, i32Ty);
1020- insertDim (z, 2 );
1021-
1048+ insertValueAtPos (builder, loc, recTy, res, z, 2 );
10221049 return res;
10231050}
10241051
0 commit comments