diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h index b679ef74870b1..f5971610694f0 100644 --- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h +++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h @@ -187,12 +187,15 @@ struct IntrinsicLibrary { mlir::Value genAtanpi(mlir::Type, llvm::ArrayRef); mlir::Value genAtomicAdd(mlir::Type, llvm::ArrayRef); mlir::Value genAtomicAnd(mlir::Type, llvm::ArrayRef); - mlir::Value genAtomicOr(mlir::Type, llvm::ArrayRef); + mlir::Value genAtomicCas(mlir::Type, llvm::ArrayRef); mlir::Value genAtomicDec(mlir::Type, llvm::ArrayRef); + mlir::Value genAtomicExch(mlir::Type, llvm::ArrayRef); mlir::Value genAtomicInc(mlir::Type, llvm::ArrayRef); mlir::Value genAtomicMax(mlir::Type, llvm::ArrayRef); mlir::Value genAtomicMin(mlir::Type, llvm::ArrayRef); + mlir::Value genAtomicOr(mlir::Type, llvm::ArrayRef); mlir::Value genAtomicSub(mlir::Type, llvm::ArrayRef); + mlir::Value genAtomicXor(mlir::Type, llvm::ArrayRef); fir::ExtendedValue genCommandArgumentCount(mlir::Type, llvm::ArrayRef); mlir::Value genAsind(mlir::Type, llvm::ArrayRef); diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index d98ee58ace2bc..28fbe83defb61 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -152,7 +152,39 @@ static constexpr IntrinsicHandler handlers[]{ {"atomicaddi", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false}, {"atomicaddl", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false}, {"atomicandi", &I::genAtomicAnd, {{{"a", asAddr}, {"v", asValue}}}, false}, + {"atomiccasd", + &I::genAtomicCas, + {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}}, + false}, + {"atomiccasf", + &I::genAtomicCas, + {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}}, + false}, + {"atomiccasi", + &I::genAtomicCas, + {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}}, + false}, + {"atomiccasul", + &I::genAtomicCas, + {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}}, + false}, {"atomicdeci", &I::genAtomicDec, {{{"a", asAddr}, {"v", asValue}}}, false}, + {"atomicexchd", + &I::genAtomicExch, + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicexchf", + &I::genAtomicExch, + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicexchi", + &I::genAtomicExch, + {{{"a", asAddr}, {"v", asValue}}}, + false}, + {"atomicexchul", + &I::genAtomicExch, + {{{"a", asAddr}, {"v", asValue}}}, + false}, {"atomicinci", &I::genAtomicInc, {{{"a", asAddr}, {"v", asValue}}}, false}, {"atomicmaxd", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false}, {"atomicmaxf", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false}, @@ -167,6 +199,7 @@ static constexpr IntrinsicHandler handlers[]{ {"atomicsubf", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false}, {"atomicsubi", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false}, {"atomicsubl", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false}, + {"atomicxori", &I::genAtomicXor, {{{"a", asAddr}, {"v", asValue}}}, false}, {"bessel_jn", &I::genBesselJn, {{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}}, @@ -2691,6 +2724,22 @@ mlir::Value IntrinsicLibrary::genAtomicOr(mlir::Type resultType, return genAtomBinOp(builder, loc, binOp, args[0], args[1]); } +// ATOMICCAS +mlir::Value IntrinsicLibrary::genAtomicCas(mlir::Type resultType, + llvm::ArrayRef args) { + assert(args.size() == 3); + assert(args[1].getType() == args[2].getType()); + auto successOrdering = mlir::LLVM::AtomicOrdering::acq_rel; + auto failureOrdering = mlir::LLVM::AtomicOrdering::monotonic; + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(resultType.getContext()); + auto address = + builder.create(loc, llvmPtrTy, args[0]) + .getResult(0); + auto cmpxchg = builder.create( + loc, address, args[1], args[2], successOrdering, failureOrdering); + return builder.create(loc, cmpxchg, 1); +} + mlir::Value IntrinsicLibrary::genAtomicDec(mlir::Type resultType, llvm::ArrayRef args) { assert(args.size() == 2); @@ -2700,6 +2749,16 @@ mlir::Value IntrinsicLibrary::genAtomicDec(mlir::Type resultType, return genAtomBinOp(builder, loc, binOp, args[0], args[1]); } +// ATOMICEXCH +mlir::Value IntrinsicLibrary::genAtomicExch(mlir::Type resultType, + llvm::ArrayRef args) { + assert(args.size() == 2); + assert(mlir::isa(args[1].getType())); + + mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::xchg; + return genAtomBinOp(builder, loc, binOp, args[0], args[1]); +} + mlir::Value IntrinsicLibrary::genAtomicInc(mlir::Type resultType, llvm::ArrayRef args) { assert(args.size() == 2); @@ -2731,6 +2790,16 @@ mlir::Value IntrinsicLibrary::genAtomicMin(mlir::Type resultType, return genAtomBinOp(builder, loc, binOp, args[0], args[1]); } +// ATOMICXOR +mlir::Value IntrinsicLibrary::genAtomicXor(mlir::Type resultType, + llvm::ArrayRef args) { + assert(args.size() == 2); + assert(mlir::isa(args[1].getType())); + + mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::_xor; + return genAtomBinOp(builder, loc, binOp, args[0], args[1]); +} + // ASSOCIATED fir::ExtendedValue IntrinsicLibrary::genAssociated(mlir::Type resultType, diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90 index 8b31c0c0856fd..af8ea66618e27 100644 --- a/flang/module/cudadevice.f90 +++ b/flang/module/cudadevice.f90 @@ -557,59 +557,117 @@ attributes(device) pure integer function atomicdeci(address, val) end function end interface + interface atomiccas + attributes(device) pure integer function atomiccasi(address, val, val2) + !dir$ ignore_tkr (rd) address, (d) val, (d) val2 + integer, intent(inout) :: address + integer, value :: val, val2 + end function + attributes(device) pure integer(8) function atomiccasul(address, val, val2) + !dir$ ignore_tkr (rd) address, (dk) val, (dk) val2 + integer(8), intent(inout) :: address + integer(8), value :: val, val2 + end function + attributes(device) pure real function atomiccasf(address, val, val2) + !dir$ ignore_tkr (rd) address, (d) val, (d) val2 + real, intent(inout) :: address + real, value :: val, val2 + end function + attributes(device) pure double precision function atomiccasd(address, val, val2) + !dir$ ignore_tkr (rd) address, (d) val, (d) val2 + double precision, intent(inout) :: address + double precision, value :: val, val2 + end function + end interface + + interface atomicexch + attributes(device) pure integer function atomicexchi(address, val) + !dir$ ignore_tkr (rd) address, (d) val + integer, intent(inout) :: address + integer, value :: val + end function + attributes(device) pure integer(8) function atomicexchul(address, val) + !dir$ ignore_tkr (rd) address, (dk) val + integer(8), intent(inout) :: address + integer(8), value :: val + end function + attributes(device) pure real function atomicexchf(address, val) + !dir$ ignore_tkr (rd) address, (d) val + real, intent(inout) :: address + real, value :: val + end function + attributes(device) pure double precision function atomicexchd(address, val) + !dir$ ignore_tkr (rd) address, (d) val + double precision, intent(inout) :: address + double precision, value :: val + end function + end interface + + interface atomicxor + attributes(device) pure integer function atomicxori(address, val) + !dir$ ignore_tkr (rd) address, (d) val + integer, intent(inout) :: address + integer, value :: val + end function + end interface + + ! Time function + interface attributes(device) integer(8) function clock64() end function end interface -interface match_all_sync - attributes(device) integer function match_all_syncjj(mask, val, pred) -!dir$ ignore_tkr(d) mask, (d) val, (d) pred - integer(4), value :: mask - integer(4), value :: val - integer(4) :: pred - end function - attributes(device) integer function match_all_syncjx(mask, val, pred) -!dir$ ignore_tkr(d) mask, (d) val, (d) pred - integer(4), value :: mask - integer(8), value :: val - integer(4) :: pred - end function - attributes(device) integer function match_all_syncjf(mask, val, pred) -!dir$ ignore_tkr(d) mask, (d) val, (d) pred - integer(4), value :: mask - real(4), value :: val - integer(4) :: pred - end function - attributes(device) integer function match_all_syncjd(mask, val, pred) -!dir$ ignore_tkr(d) mask, (d) val, (d) pred - integer(4), value :: mask - real(8), value :: val - integer(4) :: pred - end function -end interface - -interface match_any_sync - attributes(device) integer function match_any_syncjj(mask, val) -!dir$ ignore_tkr(d) mask, (d) val - integer(4), value :: mask - integer(4), value :: val - end function - attributes(device) integer function match_any_syncjx(mask, val) -!dir$ ignore_tkr(d) mask, (d) val - integer(4), value :: mask - integer(8), value :: val - end function - attributes(device) integer function match_any_syncjf(mask, val) -!dir$ ignore_tkr(d) mask, (d) val - integer(4), value :: mask - real(4), value :: val - end function - attributes(device) integer function match_any_syncjd(mask, val) -!dir$ ignore_tkr(d) mask, (d) val - integer(4), value :: mask - real(8), value :: val - end function -end interface + ! Warp Match Functions + + interface match_all_sync + attributes(device) integer function match_all_syncjj(mask, val, pred) + !dir$ ignore_tkr(d) mask, (d) val, (d) pred + integer(4), value :: mask + integer(4), value :: val + integer(4) :: pred + end function + attributes(device) integer function match_all_syncjx(mask, val, pred) + !dir$ ignore_tkr(d) mask, (d) val, (d) pred + integer(4), value :: mask + integer(8), value :: val + integer(4) :: pred + end function + attributes(device) integer function match_all_syncjf(mask, val, pred) + !dir$ ignore_tkr(d) mask, (d) val, (d) pred + integer(4), value :: mask + real(4), value :: val + integer(4) :: pred + end function + attributes(device) integer function match_all_syncjd(mask, val, pred) + !dir$ ignore_tkr(d) mask, (d) val, (d) pred + integer(4), value :: mask + real(8), value :: val + integer(4) :: pred + end function + end interface + + interface match_any_sync + attributes(device) integer function match_any_syncjj(mask, val) + !dir$ ignore_tkr(d) mask, (d) val + integer(4), value :: mask + integer(4), value :: val + end function + attributes(device) integer function match_any_syncjx(mask, val) + !dir$ ignore_tkr(d) mask, (d) val + integer(4), value :: mask + integer(8), value :: val + end function + attributes(device) integer function match_any_syncjf(mask, val) + !dir$ ignore_tkr(d) mask, (d) val + integer(4), value :: mask + real(4), value :: val + end function + attributes(device) integer function match_any_syncjd(mask, val) + !dir$ ignore_tkr(d) mask, (d) val + integer(4), value :: mask + real(8), value :: val + end function + end interface end module diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf index e7d1dba385bb8..fcfcc2e537039 100644 --- a/flang/test/Lower/CUDA/cuda-device-proc.cuf +++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf @@ -150,15 +150,15 @@ end subroutine ! CHECK: fir.convert %{{.*}} : (f64) -> i64 ! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p -! CHECK: func.func private @llvm.nvvm.barrier0() -! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32) -! CHECK: func.func private @llvm.nvvm.membar.gl() -! CHECK: func.func private @llvm.nvvm.membar.cta() -! CHECK: func.func private @llvm.nvvm.membar.sys() -! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32 -! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32 -! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32 -! CHECK: func.func private @llvm.nvvm.match.all.sync.i32p(i32, i32) -> tuple -! CHECK: func.func private @llvm.nvvm.match.all.sync.i64p(i32, i64) -> tuple -! CHECK: func.func private @llvm.nvvm.match.any.sync.i32p(i32, i32) -> i32 -! CHECK: func.func private @llvm.nvvm.match.any.sync.i64p(i32, i64) -> i32 +attributes(device) subroutine testAtomic() + integer :: a, istat, j + istat = atomicexch(a,0) + istat = atomicxor(a, j) + istat = atomiccas(a, i, 14) +end subroutine + +! CHECK-LABEL: func.func @_QPtestatomic() +! CHECK: llvm.atomicrmw xchg %{{.*}}, %c0{{.*}} seq_cst : !llvm.ptr, i32 +! CHECK: llvm.atomicrmw _xor %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32 +! CHECK: %[[ADDR:.*]] = builtin.unrealized_conversion_cast %{{.*}}#1 : !fir.ref to !llvm.ptr +! CHECK: llvm.cmpxchg %[[ADDR]], %{{.*}}, %c14{{.*}} acq_rel monotonic : !llvm.ptr, i32