Skip to content

Commit 9695a91

Browse files
committed
[flang][cuda] Use libdevice for atomicAdd with 4xf32
1 parent 074d17e commit 9695a91

File tree

3 files changed

+54
-2
lines changed

3 files changed

+54
-2
lines changed

flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ struct CUDAIntrinsicLibrary : IntrinsicLibrary {
2929
template <int extent>
3030
fir::ExtendedValue genAtomicAddVector(mlir::Type,
3131
llvm::ArrayRef<fir::ExtendedValue>);
32+
fir::ExtendedValue genAtomicAddVector4x4(mlir::Type,
33+
llvm::ArrayRef<fir::ExtendedValue>);
3234
mlir::Value genAtomicAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
3335
fir::ExtendedValue genAtomicCas(mlir::Type,
3436
llvm::ArrayRef<fir::ExtendedValue>);

flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ static constexpr IntrinsicHandler cudaHandlers[]{
195195
false},
196196
{"atomicadd_r4x4",
197197
static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
198-
&CI::genAtomicAddVector<4>),
198+
&CI::genAtomicAddVector4x4),
199199
{{{"a", asAddr}, {"v", asAddr}}},
200200
false},
201201
{"atomicaddd",
@@ -758,6 +758,56 @@ fir::ExtendedValue CUDAIntrinsicLibrary::genAtomicAddVector(
758758
return fir::ArrayBoxValue(res, {ext});
759759
}
760760

761+
// ATOMICADDVECTOR4x4
762+
fir::ExtendedValue CUDAIntrinsicLibrary::genAtomicAddVector4x4(
763+
mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args) {
764+
assert(args.size() == 2);
765+
mlir::Value a = fir::getBase(args[0]);
766+
if (mlir::isa<fir::BaseBoxType>(a.getType()))
767+
a = fir::BoxAddrOp::create(builder, loc, a);
768+
769+
const unsigned extent = 4;
770+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
771+
mlir::Value ptr = builder.createConvert(loc, llvmPtrTy, a);
772+
mlir::Type f32Ty = builder.getF32Type();
773+
mlir::Type idxTy = builder.getIndexType();
774+
mlir::Type refTy = fir::ReferenceType::get(f32Ty);
775+
llvm::SmallVector<mlir::Value> values;
776+
for (unsigned i = 0; i < extent; ++i) {
777+
mlir::Value pos = builder.createIntegerConstant(loc, idxTy, i);
778+
mlir::Value coord = fir::CoordinateOp::create(builder, loc, refTy,
779+
fir::getBase(args[1]), pos);
780+
mlir::Value value = fir::LoadOp::create(builder, loc, coord);
781+
values.push_back(value);
782+
}
783+
784+
auto inlinePtx = mlir::NVVM::InlinePtxOp::create(
785+
builder, loc, {f32Ty, f32Ty, f32Ty, f32Ty},
786+
{ptr, values[0], values[1], values[2], values[3]}, {},
787+
"atom.add.v4.f32 {%0, %1, %2, %3}, [%4], {%5, %6, %7, %8};", {});
788+
789+
llvm::SmallVector<mlir::Value> results;
790+
results.push_back(inlinePtx.getResult(0));
791+
results.push_back(inlinePtx.getResult(1));
792+
results.push_back(inlinePtx.getResult(2));
793+
results.push_back(inlinePtx.getResult(3));
794+
795+
mlir::Type vecF32Ty = mlir::VectorType::get({extent}, f32Ty);
796+
mlir::Value undef = mlir::LLVM::UndefOp::create(builder, loc, vecF32Ty);
797+
mlir::Type i32Ty = builder.getI32Type();
798+
for (unsigned i = 0; i < extent; ++i)
799+
undef = mlir::LLVM::InsertElementOp::create(
800+
builder, loc, undef, results[i],
801+
builder.createIntegerConstant(loc, i32Ty, i));
802+
803+
auto i128Ty = builder.getIntegerType(128);
804+
auto i128VecTy = mlir::VectorType::get({1}, i128Ty);
805+
mlir::Value vec128 =
806+
mlir::vector::BitCastOp::create(builder, loc, i128VecTy, undef);
807+
return mlir::vector::ExtractOp::create(builder, loc, vec128,
808+
mlir::ArrayRef<int64_t>{0});
809+
}
810+
761811
mlir::Value
762812
CUDAIntrinsicLibrary::genAtomicAnd(mlir::Type resultType,
763813
llvm::ArrayRef<mlir::Value> args) {

flang/test/Lower/CUDA/cuda-atomicadd.cuf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@ attributes(global) subroutine test_atomicadd_r4x4()
3232
end subroutine
3333

3434
! CHECK-LABEL: func.func @_QPtest_atomicadd_r4x4() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
35-
! CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, vector<4xf32>
35+
! CHECK: atom.add.v4.f32

0 commit comments

Comments
 (0)