# Review notes. These lines sit ahead of the first `diff --git` header, so they
# are not part of any hunk.
#
# Purpose: replace the `error("not implemented")` stubs in ext/AtomixCUDAExt.jl
# with UnsafeAtomics calls, and enable the get/set test in
# test/test_atomix_cuda.jl.
#
# ext/AtomixCUDAExt.jl
#   * Adds `using Atomix.Internal: UnsafeAtomics`.
#   * `Atomix.get(ref::CuIndexableRef, order)` now does
#     `UnsafeAtomics.load(Atomix.pointer(ref), UnsafeAtomics.monotonic)`.
#   * `Atomix.set!(ref::CuIndexableRef, v, order)` now does
#     `UnsafeAtomics.store!(Atomix.pointer(ref), v, UnsafeAtomics.monotonic)`.
#   * In the `op` dispatch chain of the second hunk, the final `else` branch now
#     does `return UnsafeAtomics.modify(ptr, op, x, UnsafeAtomics.monotonic)`.
#     That branch returns from the function directly, so the trailing
#     `return old => op(old, x)` context line is not reached on that path.
#     NOTE(review): presumably `UnsafeAtomics.modify` already yields an
#     `old => new` pair matching the other branches; confirm.
#   * NOTE(review): none of the added lines reads the `order` argument; every
#     added call passes `UnsafeAtomics.monotonic`. A caller asking for a stronger
#     ordering gets no error and no fence from these lines. Confirm this is the
#     intended interim behaviour.
#   * The added `# function atomic_load ...` and `# function atomic_store! ...`
#     blocks are commented-out text attributed to
#     https://github.com/JuliaGPU/CUDA.jl/pull/1644 and are never executed.
#     They sketch order-dependent handling (fences for Seq_Cst, a
#     compute_capability() >= sv"7.0" split).
#     NOTE(review): presumably kept as a reference for future ordering support;
#     consider replacing them with a short TODO that links to the PR.
#
# test/test_atomix_cuda.jl
#   * Removes the `# Not implemented:` note and the `#=` / `=#` wrapper around
#     `function test_get_set()`, turning it into the active
#     `@testset "AtomixCUDAExt:test_get_set"`.
#   * The testset still ends with `@test collect(A) == [-1, 1, 1]`.
#
# NOTE(review): the patch text below has its line breaks collapsed. The hunk
# headers (`@@ -2,16 +2,67 @@`, `@@ -51,7 +102,7 @@`, `@@ -15,10 +15,7 @@`,
# `@@ -29,8 +26,6 @@`) describe a multi-line layout that must be restored before
# the patch can be applied. The last hunk is cut off after `idx = (`.
diff --git a/ext/AtomixCUDAExt.jl b/ext/AtomixCUDAExt.jl index fc92895..1948430 100644 --- a/ext/AtomixCUDAExt.jl +++ b/ext/AtomixCUDAExt.jl @@ -2,16 +2,67 @@ module AtomixCUDAExt using Atomix: Atomix, IndexableRef +using Atomix.Internal: UnsafeAtomics using CUDA: CUDA, CuDeviceArray const CuIndexableRef{Indexable<:CuDeviceArray} = IndexableRef{Indexable} +# from https://github.com/JuliaGPU/CUDA.jl/pull/1644 +# function atomic_load(ptr::LLVMPtr{T}, order, scope::System=System()) where T +# if order == Acq_Rel() || order == Release() +# assert(false) +# end +# if compute_capability() >= sv"7.0" +# if order == Relaxed() +# val = __load(ptr, Relaxed(), scope) +# return val +# end +# if order == Seq_Cst() +# atomic_thread_fence(Seq_Cst(), scope) +# end +# val = __load(ptr, Acquire(), scope) +# return val +# else +# if order == Seq_Cst() +# atomic_thread_fence(Seq_Cst(), scope) +# end +# val = __load_volatile(ptr) +# if order == Relaxed() +# return val +# end +# atomic_thread_fence(order, scope) +# return val +# end +# end + function Atomix.get(ref::CuIndexableRef, order) - error("not implemented") + ptr = Atomix.pointer(ref) + return UnsafeAtomics.load(ptr, UnsafeAtomics.monotonic) end +# function atomic_store!(ptr::LLVMPtr{T}, val::T, order, scope::System=System()) where T +# if order == Acq_Rel() || order == Consume() || order == Acquire() +# assert(false) +# end +# if compute_capability() >= sv"7.0" +# if order == Release() +# __store!(ptr, val, Release(), scope) +# return +# end +# if order == Seq_Cst() +# atomic_thread_fence(Seq_Cst(), scope) +# end +# __store!(ptr, val, Relaxed(), scope) +# else +# if order == Seq_Cst() +# atomic_thread_fence(Seq_Cst(), scope) +# end +# __store_volatile!(ptr, val) +# end +# end function Atomix.set!(ref::CuIndexableRef, v, order) - error("not implemented") + ptr = Atomix.pointer(ref) + return UnsafeAtomics.store!(ptr, v, UnsafeAtomics.monotonic) end @inline function Atomix.replace!( @@ -51,7 +102,7 @@ end elseif op === 
Atomix.right CUDA.atomic_xchg!(ptr, x) else - error("not implemented") + return UnsafeAtomics.modify(ptr, op, x, UnsafeAtomics.monotonic) end end return old => op(old, x) diff --git a/test/test_atomix_cuda.jl b/test/test_atomix_cuda.jl index 487eae2..dd153db 100644 --- a/test/test_atomix_cuda.jl +++ b/test/test_atomix_cuda.jl @@ -15,10 +15,7 @@ function cuda(f) CUDA.@cuda g() end - -# Not implemented: -#= -function test_get_set() +@testset "AtomixCUDAExt:test_get_set" begin A = CUDA.ones(Int, 3) cuda() do GC.@preserve A begin @@ -29,8 +26,6 @@ function test_get_set() end @test collect(A) == [-1, 1, 1] end -=# - @testset "AtomixCUDAExt:test_cas" begin idx = (