Skip to content

Commit 0ce03c2

Browse files
authored
[flang][cuda] Add interface and lowering for atomicadd_r4x2 and atomicadd_r4x4 (llvm#166308)
1 parent 67ce4ab commit 0ce03c2

File tree

4 files changed

+83
-33
lines changed

4 files changed

+83
-33
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ struct IntrinsicLibrary {
190190
mlir::Value genAtomicAdd(mlir::Type, llvm::ArrayRef<mlir::Value>);
191191
fir::ExtendedValue genAtomicAddR2(mlir::Type,
192192
llvm::ArrayRef<fir::ExtendedValue>);
193+
template <int extent>
193194
fir::ExtendedValue genAtomicAddVector(mlir::Type,
194195
llvm::ArrayRef<fir::ExtendedValue>);
195196
mlir::Value genAtomicAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -290,12 +290,12 @@ static constexpr IntrinsicHandler handlers[]{
290290
{"atan2pi", &I::genAtanpi},
291291
{"atand", &I::genAtand},
292292
{"atanpi", &I::genAtanpi},
293-
{"atomicadd_r2x2",
294-
&I::genAtomicAddVector,
293+
{"atomicadd_r4x2",
294+
&I::genAtomicAddVector<2>,
295295
{{{"a", asAddr}, {"v", asAddr}}},
296296
false},
297-
{"atomicadd_r4x2",
298-
&I::genAtomicAddVector,
297+
{"atomicadd_r4x4",
298+
&I::genAtomicAddVector<4>,
299299
{{{"a", asAddr}, {"v", asAddr}}},
300300
false},
301301
{"atomicaddd", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
@@ -306,6 +306,14 @@ static constexpr IntrinsicHandler handlers[]{
306306
&I::genAtomicAddR2,
307307
{{{"a", asAddr}, {"v", asAddr}}},
308308
false},
309+
{"atomicaddvector_r2x2",
310+
&I::genAtomicAddVector<2>,
311+
{{{"a", asAddr}, {"v", asAddr}}},
312+
false},
313+
{"atomicaddvector_r4x2",
314+
&I::genAtomicAddVector<2>,
315+
{{{"a", asAddr}, {"v", asAddr}}},
316+
false},
309317
{"atomicandi", &I::genAtomicAnd, {{{"a", asAddr}, {"v", asValue}}}, false},
310318
{"atomiccasd",
311319
&I::genAtomicCas,
@@ -3176,44 +3184,51 @@ IntrinsicLibrary::genAtomicAddR2(mlir::Type resultType,
31763184
mlir::ArrayRef<int64_t>{0});
31773185
}
31783186

3187+
template <int extent>
31793188
fir::ExtendedValue
31803189
IntrinsicLibrary::genAtomicAddVector(mlir::Type resultType,
31813190
llvm::ArrayRef<fir::ExtendedValue> args) {
31823191
assert(args.size() == 2);
31833192
mlir::Value res = fir::AllocaOp::create(
3184-
builder, loc, fir::SequenceType::get({2}, resultType));
3193+
builder, loc, fir::SequenceType::get({extent}, resultType));
31853194
mlir::Value a = fir::getBase(args[0]);
31863195
if (mlir::isa<fir::BaseBoxType>(a.getType())) {
31873196
a = fir::BoxAddrOp::create(builder, loc, a);
31883197
}
3189-
auto vecTy = mlir::VectorType::get({2}, resultType);
3198+
auto vecTy = mlir::VectorType::get({extent}, resultType);
31903199
auto refTy = fir::ReferenceType::get(resultType);
31913200
mlir::Type i32Ty = builder.getI32Type();
31923201
mlir::Type idxTy = builder.getIndexType();
3193-
mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
3194-
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
3195-
mlir::Value v1Coord = fir::CoordinateOp::create(builder, loc, refTy,
3196-
fir::getBase(args[1]), zero);
3197-
mlir::Value v2Coord = fir::CoordinateOp::create(builder, loc, refTy,
3198-
fir::getBase(args[1]), one);
3199-
mlir::Value v1 = fir::LoadOp::create(builder, loc, v1Coord);
3200-
mlir::Value v2 = fir::LoadOp::create(builder, loc, v2Coord);
3202+
3203+
// Extract the values from the array.
3204+
llvm::SmallVector<mlir::Value> values;
3205+
for (unsigned i = 0; i < extent; ++i) {
3206+
mlir::Value pos = builder.createIntegerConstant(loc, idxTy, i);
3207+
mlir::Value coord = fir::CoordinateOp::create(builder, loc, refTy,
3208+
fir::getBase(args[1]), pos);
3209+
mlir::Value value = fir::LoadOp::create(builder, loc, coord);
3210+
values.push_back(value);
3211+
}
3212+
// Pack extracted values into a vector to call the atomic add.
32013213
mlir::Value undef = mlir::LLVM::UndefOp::create(builder, loc, vecTy);
3202-
mlir::Value vec1 = mlir::LLVM::InsertElementOp::create(
3203-
builder, loc, undef, v1, builder.createIntegerConstant(loc, i32Ty, 0));
3204-
mlir::Value vec2 = mlir::LLVM::InsertElementOp::create(
3205-
builder, loc, vec1, v2, builder.createIntegerConstant(loc, i32Ty, 1));
3214+
for (unsigned i = 0; i < extent; ++i) {
3215+
mlir::Value insert = mlir::LLVM::InsertElementOp::create(
3216+
builder, loc, undef, values[i],
3217+
builder.createIntegerConstant(loc, i32Ty, i));
3218+
undef = insert;
3219+
}
3220+
// Atomic operation with a vector of values.
32063221
mlir::Value add =
3207-
genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::fadd, a, vec2);
3208-
mlir::Value r1 = mlir::LLVM::ExtractElementOp::create(
3209-
builder, loc, add, builder.createIntegerConstant(loc, i32Ty, 0));
3210-
mlir::Value r2 = mlir::LLVM::ExtractElementOp::create(
3211-
builder, loc, add, builder.createIntegerConstant(loc, i32Ty, 1));
3212-
mlir::Value c1 = fir::CoordinateOp::create(builder, loc, refTy, res, zero);
3213-
mlir::Value c2 = fir::CoordinateOp::create(builder, loc, refTy, res, one);
3214-
fir::StoreOp::create(builder, loc, r1, c1);
3215-
fir::StoreOp::create(builder, loc, r2, c2);
3216-
mlir::Value ext = builder.createIntegerConstant(loc, idxTy, 2);
3222+
genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::fadd, a, undef);
3223+
// Store results in the result array.
3224+
for (unsigned i = 0; i < extent; ++i) {
3225+
mlir::Value r = mlir::LLVM::ExtractElementOp::create(
3226+
builder, loc, add, builder.createIntegerConstant(loc, i32Ty, i));
3227+
mlir::Value c = fir::CoordinateOp::create(
3228+
builder, loc, refTy, res, builder.createIntegerConstant(loc, idxTy, i));
3229+
fir::StoreOp::create(builder, loc, r, c);
3230+
}
3231+
mlir::Value ext = builder.createIntegerConstant(loc, idxTy, extent);
32173232
return fir::ArrayBoxValue(res, {ext});
32183233
}
32193234

flang/module/cudadevice.f90

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1179,13 +1179,22 @@ attributes(device) pure integer(4) function atomicaddr2(address, val)
11791179
end interface
11801180

11811181
interface atomicaddvector
1182-
attributes(device) pure function atomicadd_r2x2(address, val) result(z)
1182+
attributes(device) pure function atomicaddvector_r2x2(address, val) result(z)
11831183
!dir$ ignore_tkr (rd) address, (d) val
11841184
real(2), dimension(2), intent(inout) :: address
11851185
real(2), dimension(2), intent(in) :: val
11861186
real(2), dimension(2) :: z
11871187
end function
11881188

1189+
attributes(device) pure function atomicaddvector_r4x2(address, val) result(z)
1190+
!dir$ ignore_tkr (rd) address, (d) val
1191+
real(4), dimension(2), intent(inout) :: address
1192+
real(4), dimension(2), intent(in) :: val
1193+
real(4), dimension(2) :: z
1194+
end function
1195+
end interface
1196+
1197+
interface atomicaddreal4x2
11891198
attributes(device) pure function atomicadd_r4x2(address, val) result(z)
11901199
!dir$ ignore_tkr (rd) address, (d) val
11911200
real(4), dimension(2), intent(inout) :: address
@@ -1194,6 +1203,15 @@ attributes(device) pure function atomicadd_r4x2(address, val) result(z)
11941203
end function
11951204
end interface
11961205

1206+
interface atomicaddreal4x4
1207+
attributes(device) pure function atomicadd_r4x4(address, val) result(z)
1208+
!dir$ ignore_tkr (rd) address, (d) val
1209+
real(4), dimension(4), intent(inout) :: address
1210+
real(4), dimension(4), intent(in) :: val
1211+
real(4), dimension(4) :: z
1212+
end function
1213+
end interface
1214+
11971215
interface atomicsub
11981216
attributes(device) pure integer function atomicsubi(address, val)
11991217
!dir$ ignore_tkr (d) address, (d) val

flang/test/Lower/CUDA/cuda-atomicadd.cuf

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,34 @@
22

33
! Test CUDA Fortran atomicadd functions available in the cudadevice module
44

5-
attributes(global) subroutine atomicaddvector_r2()
5+
attributes(global) subroutine test_atomicaddvector_r2()
66
real(2), device :: a(2), tmp1(2), tmp2(2)
77
tmp1 = atomicAddVector(a, tmp2)
88
end subroutine
99

10-
! CHECK-LABEL: func.func @_QPatomicaddvector_r2() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
10+
! CHECK-LABEL: func.func @_QPtest_atomicaddvector_r2() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
1111
! CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, vector<2xf16>
1212

13-
attributes(global) subroutine atomicaddvector_r4()
13+
attributes(global) subroutine test_atomicaddvector_r4()
1414
real(4), device :: a(2), tmp1(2), tmp2(2)
1515
tmp1 = atomicAddVector(a, tmp2)
1616
end subroutine
1717

18-
! CHECK-LABEL: func.func @_QPatomicaddvector_r4() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
18+
! CHECK-LABEL: func.func @_QPtest_atomicaddvector_r4() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
1919
! CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, vector<2xf32>
20+
21+
attributes(global) subroutine test_atomicadd_r2x4()
22+
real(4), device :: a(2), tmp1(2), tmp2(2)
23+
tmp1 = atomicaddreal4x2(a, tmp2)
24+
end subroutine
25+
26+
! CHECK-LABEL: func.func @_QPtest_atomicadd_r2x4() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
27+
! CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, vector<2xf32>
28+
29+
attributes(global) subroutine test_atomicadd_r4x4()
30+
real(4), device :: a(4), tmp1(4), tmp2(4)
31+
tmp1 = atomicaddreal4x4(a, tmp2)
32+
end subroutine
33+
34+
! CHECK-LABEL: func.func @_QPtest_atomicadd_r4x4() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
35+
! CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, vector<4xf32>

0 commit comments

Comments
 (0)