Skip to content

Commit b2f504f

Browse files
authored
[flang][cuda] Inline this_warp() calls (#146134)
1 parent 68239b7 commit b2f504f

File tree

4 files changed

+75
-0
lines changed

4 files changed

+75
-0
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ struct IntrinsicLibrary {
443443
fir::ExtendedValue genTranspose(mlir::Type,
444444
llvm::ArrayRef<fir::ExtendedValue>);
445445
mlir::Value genThisGrid(mlir::Type, llvm::ArrayRef<mlir::Value>);
446+
mlir::Value genThisWarp(mlir::Type, llvm::ArrayRef<mlir::Value>);
446447
void genThreadFence(llvm::ArrayRef<fir::ExtendedValue>);
447448
void genThreadFenceBlock(llvm::ArrayRef<fir::ExtendedValue>);
448449
void genThreadFenceSystem(llvm::ArrayRef<fir::ExtendedValue>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,7 @@ static constexpr IntrinsicHandler handlers[]{
933933
/*isElemental=*/false},
934934
{"tand", &I::genTand},
935935
{"this_grid", &I::genThisGrid, {}, /*isElemental=*/false},
936+
{"this_warp", &I::genThisWarp, {}, /*isElemental=*/false},
936937
{"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
937938
{"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
938939
{"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false},
@@ -8194,6 +8195,45 @@ mlir::Value IntrinsicLibrary::genThisGrid(mlir::Type resultType,
81948195
return res;
81958196
}
81968197

8198+
// THIS_WARP
8199+
mlir::Value IntrinsicLibrary::genThisWarp(mlir::Type resultType,
8200+
llvm::ArrayRef<mlir::Value> args) {
8201+
assert(args.size() == 0);
8202+
auto recTy = mlir::cast<fir::RecordType>(resultType);
8203+
assert(recTy && "RecordType expepected");
8204+
mlir::Value res = builder.create<fir::AllocaOp>(loc, resultType);
8205+
mlir::Type i32Ty = builder.getI32Type();
8206+
8207+
// coalesced_group%size = 32
8208+
mlir::Value size = builder.createIntegerConstant(loc, i32Ty, 32);
8209+
auto sizeFieldName = recTy.getTypeList()[1].first;
8210+
mlir::Type sizeFieldTy = recTy.getTypeList()[1].second;
8211+
mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
8212+
mlir::Value sizeFieldIndex = builder.create<fir::FieldIndexOp>(
8213+
loc, fieldIndexType, sizeFieldName, recTy,
8214+
/*typeParams=*/mlir::ValueRange{});
8215+
mlir::Value sizeCoord = builder.create<fir::CoordinateOp>(
8216+
loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex);
8217+
builder.create<fir::StoreOp>(loc, size, sizeCoord);
8218+
8219+
// coalesced_group%rank = threadIdx.x & 31 + 1
8220+
mlir::Value threadIdX = builder.create<mlir::NVVM::ThreadIdXOp>(loc, i32Ty);
8221+
mlir::Value mask = builder.createIntegerConstant(loc, i32Ty, 31);
8222+
mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
8223+
mlir::Value masked =
8224+
builder.create<mlir::arith::AndIOp>(loc, threadIdX, mask);
8225+
mlir::Value rank = builder.create<mlir::arith::AddIOp>(loc, masked, one);
8226+
auto rankFieldName = recTy.getTypeList()[2].first;
8227+
mlir::Type rankFieldTy = recTy.getTypeList()[2].second;
8228+
mlir::Value rankFieldIndex = builder.create<fir::FieldIndexOp>(
8229+
loc, fieldIndexType, rankFieldName, recTy,
8230+
/*typeParams=*/mlir::ValueRange{});
8231+
mlir::Value rankCoord = builder.create<fir::CoordinateOp>(
8232+
loc, builder.getRefType(rankFieldTy), res, rankFieldIndex);
8233+
builder.create<fir::StoreOp>(loc, rank, rankCoord);
8234+
return res;
8235+
}
8236+
81978237
// TRAILZ
81988238
mlir::Value IntrinsicLibrary::genTrailz(mlir::Type resultType,
81998239
llvm::ArrayRef<mlir::Value> args) {

flang/module/cooperative_groups.f90

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,24 @@ module cooperative_groups
2020
integer(4) :: rank
2121
end type grid_group
2222

23+
type :: coalesced_group
24+
type(c_devptr), private :: handle
25+
integer(4) :: size
26+
integer(4) :: rank
27+
end type coalesced_group
28+
2329
interface
2430
attributes(device) function this_grid()
2531
import
2632
type(grid_group) :: this_grid
2733
end function
2834
end interface
2935

36+
interface this_warp
37+
attributes(device) function this_warp()
38+
import
39+
type(coalesced_group) :: this_warp
40+
end function
41+
end interface
42+
3043
end module

flang/test/Lower/CUDA/cuda-cooperative.cuf

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,24 @@ end subroutine
5050
! CHECK: fir.store %[[SIZE]] to %[[COORD_SIZE]] : !fir.ref<i32>
5151
! CHECK: %[[COORD_RANK:.*]] = fir.coordinate_of %[[RES]], rank : (!fir.ref<!fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
5252
! CHECK: fir.store %[[RANK]] to %[[COORD_RANK]] : !fir.ref<i32>
53+
54+
attributes(grid_global) subroutine w1()
55+
use cooperative_groups
56+
type(coalesced_group) :: gg
57+
gg = this_warp()
58+
end subroutine
59+
60+
! CHECK: %[[WARPSIZE:.*]] = fir.alloca i32 {bindc_name = "__builtin_warpsize", uniq_name = "_QM__fortran_builtinsEC__builtin_warpsize"}
61+
! CHECK: %[[WARPSIZE_DECL:.*]]:2 = hlfir.declare %[[WARPSIZE]] {uniq_name = "_QM__fortran_builtinsEC__builtin_warpsize"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
62+
! CHECK: %[[COALESCED_GROUP:.*]] = fir.alloca !fir.type<_QMcooperative_groupsTcoalesced_group{_QMcooperative_groupsTcoalesced_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
63+
! CHECK: %[[C32:.*]] = arith.constant 32 : i32
64+
! CHECK: %[[SIZE_COORD:.*]] = fir.coordinate_of %[[COALESCED_GROUP]], size : (!fir.ref<!fir.type<_QMcooperative_groupsTcoalesced_group{_QMcooperative_groupsTcoalesced_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
65+
! CHECK: fir.store %[[C32]] to %[[SIZE_COORD]] : !fir.ref<i32>
66+
67+
! CHECK: %[[THREAD_ID:.*]] = nvvm.read.ptx.sreg.tid.x : i32
68+
! CHECK: %[[C31:.*]] = arith.constant 31 : i32
69+
! CHECK: %[[C1:.*]] = arith.constant 1 : i32
70+
! CHECK: %[[AND:.*]] = arith.andi %[[THREAD_ID]], %[[C31]] : i32
71+
! CHECK: %[[RANK:.*]] = arith.addi %[[AND]], %[[C1]] : i32
72+
! CHECK: %[[RANK_COORD:.*]] = fir.coordinate_of %{{.*}}, rank : (!fir.ref<!fir.type<_QMcooperative_groupsTcoalesced_group{_QMcooperative_groupsTcoalesced_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
73+
! CHECK: fir.store %[[RANK]] to %[[RANK_COORD]] : !fir.ref<i32>

0 commit comments

Comments
 (0)