@@ -5,7 +5,7 @@ include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) #
55
66# Function to use as a baseline for CPU metrics
77function create_histogram (input)
8- histogram_output = zeros (Int , maximum (input))
8+ histogram_output = zeros (eltype (input) , maximum (input))
99 for i in input
1010 histogram_output[i] += 1
1111 end
2222 @uniform gs = @groupsize ()[1 ]
2323 @uniform N = length (histogram_output)
2424
25- shared_histogram = @localmem Int (gs)
25+ shared_histogram = @localmem eltype (input) (gs)
2626
2727 # This will go through all input elements and assign them to a location in
2828 # shmem. Note that if there is not enough shem, we create different shmem
@@ -74,9 +74,10 @@ function move(backend, input)
7474end
7575
7676@testset " histogram tests" begin
77- rand_input = [rand (1 : 128 ) for i in 1 : 1000 ]
78- linear_input = [i for i in 1 : 1024 ]
79- all_two = [2 for i in 1 : 512 ]
77+ # Use Int32 as some backends don't support 64-bit atomics
78+ rand_input = Int32 .(rand (1 : 128 , 1000 ))
79+ linear_input = Int32 .(rand (1 : 128 , 1024 ))
80+ all_two = fill (Int32 (2 ), 512 )
8081
8182 histogram_rand_baseline = create_histogram (rand_input)
8283 histogram_linear_baseline = create_histogram (linear_input)
8687 linear_input = move (backend, linear_input)
8788 all_two = move (backend, all_two)
8889
89- rand_histogram = KernelAbstractions. zeros (backend, Int , 128 )
90- linear_histogram = KernelAbstractions. zeros (backend, Int , 1024 )
91- two_histogram = KernelAbstractions. zeros (backend, Int , 2 )
90+ rand_histogram = KernelAbstractions. zeros (backend, eltype (rand_input) , 128 )
91+ linear_histogram = KernelAbstractions. zeros (backend, eltype (linear_input) , 1024 )
92+ two_histogram = KernelAbstractions. zeros (backend, eltype (all_two) , 2 )
9293
9394 histogram! (rand_histogram, rand_input)
9495 histogram! (linear_histogram, linear_input)
9596 histogram! (two_histogram, all_two)
96- KernelAbstractions. synchronize (CPU () )
97+ KernelAbstractions. synchronize (backend )
9798
9899 @test isapprox (Array (rand_histogram), histogram_rand_baseline)
99100 @test isapprox (Array (linear_histogram), histogram_linear_baseline)
0 commit comments