diff --git a/src/accumulate/accumulate.jl b/src/accumulate/accumulate.jl index e88fd18..0b00beb 100644 --- a/src/accumulate/accumulate.jl +++ b/src/accumulate/accumulate.jl @@ -139,7 +139,7 @@ end function accumulate!( - op, dst::AbstractArray, src::AbstractArray, backend::Backend=get_backend(v); + op, dst::AbstractArray, src::AbstractArray, backend::Backend=get_backend(dst); init, neutral=GPUArrays.neutral_element(op, eltype(dst)), dims::Union{Nothing, Int}=nothing, diff --git a/test/runtests.jl b/test/runtests.jl index d3ceca3..5907119 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1307,9 +1307,9 @@ end for _ in 1:100 num_elems = rand(1:100_000) x = array_from_host(rand(1:1000, num_elems), Int32) - y = copy(x) + y = similar(x) init = rand(-1000:1000) - AK.accumulate!(+, y; init=Int32(init)) + AK.accumulate!(+, y, x; init=Int32(init)) @test all(Array(y) .== accumulate(+, Array(x), init=init)) end @@ -1395,7 +1395,7 @@ end n3 = rand(1:100) vh = rand(Float32, n1, n2, n3) v = array_from_host(vh) - + s = AK.accumulate(+, v; init=Float32(0), dims=dims) sh = Array(s) @test all(sh .≈ accumulate(+, vh, init=Float32(0), dims=dims)) @@ -1530,7 +1530,7 @@ end AK.searchsortedfirst(v, x, by=abs, lt=(>), rev=true, block_size=64) AK.searchsortedlast!(ix, v, x, by=abs, lt=(>), rev=true, block_size=64) AK.searchsortedlast(v, x, by=abs, lt=(>), rev=true, block_size=64) - + vh = Array(v) xh = Array(x) ixh = similar(xh, Int32) @@ -1809,7 +1809,7 @@ end @testset "cumsum" begin - + Random.seed!(0) # Simple correctness tests @@ -1858,7 +1858,7 @@ end @testset "cumprod" begin - + Random.seed!(0) # Simple correctness tests