Skip to content

Commit 9d040f9

Browse files
clementvalgithub-actions[bot]
authored andcommitted
Automerge: [flang][cuda] Add support for f16 atomicadd (#166229)
2 parents 5e097b1 + ac21fde commit 9d040f9

File tree

4 files changed

+53
-1
lines changed

4 files changed

+53
-1
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ struct IntrinsicLibrary {
188188
fir::ExtendedValue genAny(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
189189
mlir::Value genAtanpi(mlir::Type, llvm::ArrayRef<mlir::Value>);
190190
mlir::Value genAtomicAdd(mlir::Type, llvm::ArrayRef<mlir::Value>);
191+
fir::ExtendedValue genAtomicAddR2(mlir::Type,
192+
llvm::ArrayRef<fir::ExtendedValue>);
191193
mlir::Value genAtomicAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
192194
fir::ExtendedValue genAtomicCas(mlir::Type,
193195
llvm::ArrayRef<fir::ExtendedValue>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,10 @@ static constexpr IntrinsicHandler handlers[]{
294294
{"atomicaddf", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
295295
{"atomicaddi", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
296296
{"atomicaddl", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
297+
{"atomicaddr2",
298+
&I::genAtomicAddR2,
299+
{{{"a", asAddr}, {"v", asAddr}}},
300+
false},
297301
{"atomicandi", &I::genAtomicAnd, {{{"a", asAddr}, {"v", asValue}}}, false},
298302
{"atomiccasd",
299303
&I::genAtomicCas,
@@ -3119,14 +3123,51 @@ static mlir::Value genAtomBinOp(fir::FirOpBuilder &builder, mlir::Location &loc,
31193123
mlir::Value IntrinsicLibrary::genAtomicAdd(mlir::Type resultType,
31203124
llvm::ArrayRef<mlir::Value> args) {
31213125
assert(args.size() == 2);
3122-
31233126
mlir::LLVM::AtomicBinOp binOp =
31243127
mlir::isa<mlir::IntegerType>(args[1].getType())
31253128
? mlir::LLVM::AtomicBinOp::add
31263129
: mlir::LLVM::AtomicBinOp::fadd;
31273130
return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
31283131
}
31293132

3133+
fir::ExtendedValue
3134+
IntrinsicLibrary::genAtomicAddR2(mlir::Type resultType,
3135+
llvm::ArrayRef<fir::ExtendedValue> args) {
3136+
assert(args.size() == 2);
3137+
3138+
mlir::Value a = fir::getBase(args[0]);
3139+
3140+
if (mlir::isa<fir::BaseBoxType>(a.getType())) {
3141+
a = fir::BoxAddrOp::create(builder, loc, a);
3142+
}
3143+
3144+
auto loc = builder.getUnknownLoc();
3145+
auto f16Ty = builder.getF16Type();
3146+
auto i32Ty = builder.getI32Type();
3147+
auto vecF16Ty = mlir::VectorType::get({2}, f16Ty);
3148+
mlir::Type idxTy = builder.getIndexType();
3149+
auto f16RefTy = fir::ReferenceType::get(f16Ty);
3150+
auto zero = builder.createIntegerConstant(loc, idxTy, 0);
3151+
auto one = builder.createIntegerConstant(loc, idxTy, 1);
3152+
auto v1Coord = fir::CoordinateOp::create(builder, loc, f16RefTy,
3153+
fir::getBase(args[1]), zero);
3154+
auto v2Coord = fir::CoordinateOp::create(builder, loc, f16RefTy,
3155+
fir::getBase(args[1]), one);
3156+
auto v1 = fir::LoadOp::create(builder, loc, v1Coord);
3157+
auto v2 = fir::LoadOp::create(builder, loc, v2Coord);
3158+
mlir::Value undef = mlir::LLVM::UndefOp::create(builder, loc, vecF16Ty);
3159+
mlir::Value vec1 = mlir::LLVM::InsertElementOp::create(
3160+
builder, loc, undef, v1, builder.createIntegerConstant(loc, i32Ty, 0));
3161+
mlir::Value vec2 = mlir::LLVM::InsertElementOp::create(
3162+
builder, loc, vec1, v2, builder.createIntegerConstant(loc, i32Ty, 1));
3163+
auto res = genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::fadd, a, vec2);
3164+
auto i32VecTy = mlir::VectorType::get({1}, i32Ty);
3165+
mlir::Value vecI32 =
3166+
mlir::vector::BitCastOp::create(builder, loc, i32VecTy, res);
3167+
return mlir::vector::ExtractOp::create(builder, loc, vecI32,
3168+
mlir::ArrayRef<int64_t>{0});
3169+
}
3170+
31303171
mlir::Value IntrinsicLibrary::genAtomicSub(mlir::Type resultType,
31313172
llvm::ArrayRef<mlir::Value> args) {
31323173
assert(args.size() == 2);

flang/module/cudadevice.f90

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,6 +1171,11 @@ attributes(device) pure integer(8) function atomicaddl(address, val)
11711171
integer(8), intent(inout) :: address
11721172
integer(8), value :: val
11731173
end function
1174+
attributes(device) pure integer(4) function atomicaddr2(address, val)
1175+
!dir$ ignore_tkr (rd) address, (d) val
1176+
real(2), dimension(2), intent(inout) :: address
1177+
real(2), dimension(2), intent(in) :: val
1178+
end function
11741179
end interface
11751180

11761181
interface atomicsub

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ attributes(global) subroutine devsub()
1414
integer :: smalltime
1515
integer(4) :: res, offset
1616
integer(8) :: resl
17+
real(2) :: r2a(2)
18+
real(2) :: tmp2(2)
1719

1820
integer :: tid
1921
tid = threadIdx%x
@@ -34,6 +36,7 @@ attributes(global) subroutine devsub()
3436
al = atomicadd(al, 1_8)
3537
af = atomicadd(af, 1.0_4)
3638
ad = atomicadd(ad, 1.0_8)
39+
ai = atomicadd(r2a, tmp2)
3740

3841
ai = atomicsub(ai, 1_4)
3942
al = atomicsub(al, 1_8)
@@ -128,6 +131,7 @@ end
128131
! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i64
129132
! CHECK: %{{.*}} = llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, f32
130133
! CHECK: %{{.*}} = llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, f64
134+
! CHECK: %{{.*}} = llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, vector<2xf16>
131135

132136
! CHECK: %{{.*}} = llvm.atomicrmw sub %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
133137
! CHECK: %{{.*}} = llvm.atomicrmw sub %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i64

0 commit comments

Comments
 (0)