@@ -5,31 +5,28 @@ include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) #
55
66# Function to use as a baseline for CPU metrics
"""
    create_histogram(input)

Count the occurrences of each value in `input` on the host. This is the CPU
baseline that the kernel-produced histograms are compared against.

Returns a vector `h` of length `maximum(input)` whose element type is
`eltype(input)`, where `h[v]` is the number of times `v` appears in `input`.
Every value is used directly as a 1-based bin index, so values must be positive
integers. An empty `input` yields an empty histogram.
"""
function create_histogram(input)
    # `maximum` throws on an empty collection, so handle that case up front.
    isempty(input) && return zeros(eltype(input), 0)

    # Bins share the input's element type so the baseline matches the device
    # histogram, which is allocated with the same eltype.
    histogram_output = zeros(eltype(input), maximum(input))
    for i in input
        histogram_output[i] += 1
    end
    return histogram_output
end
1414
1515# This is a 1D histogram kernel where the histogramming happens on shmem
16- @kernel function histogram_kernel! (histogram_output, input)
16+ @kernel unsafe_indices = true function histogram_kernel! (histogram_output, input)
1717 tid = @index (Global, Linear)
1818 lid = @index (Local, Linear)
1919
20- @uniform warpsize = Int (32 )
21-
22- @uniform gs = @groupsize ()[1 ]
20+ @uniform gs = prod (@groupsize ())
2321 @uniform N = length (histogram_output)
2422
25- shared_histogram = @localmem Int (gs)
23+ shared_histogram = @localmem eltype (input) (gs)
2624
2725 # This will go through all input elements and assign them to a location in
2826 # shmem. Note that if there is not enough shmem, we create different shmem
2927 # blocks to write to. For example, if shmem is of size 256, but it's
3028 # possible to get a value of 312, then we will have 2 separate shmem blocks,
3129 # one from 1->256, and another from 257->512
32- @uniform max_element = 1
3330 for min_element in 1 : gs: N
3431
3532 # Setting shared_histogram to 0
4239 end
4340
4441 # Defining bin on shared memory and writing to it if possible
45- bin = input[tid]
42+ bin = tid <= length ( input) ? input [tid] : 0
4643 if bin >= min_element && bin < max_element
4744 bin -= min_element - 1
4845 @atomic shared_histogram[bin] += 1
5855
5956end
6057
"""
    histogram!(histogram_output, input, groupsize = 256)

Launch `histogram_kernel!` to accumulate the counts of `input` into
`histogram_output`.

The backend is taken from `histogram_output`, and the kernel is instantiated
with a static workgroup size of `(groupsize,)` and launched over
`ndrange = size(input)`. Returns `nothing`.

NOTE(review): this function does not synchronize; the tests call
`KernelAbstractions.synchronize(backend)` before reading the result, so callers
should presumably do the same.
"""
function histogram!(histogram_output, input, groupsize = 256)
    backend = get_backend(histogram_output)
    # Need static block size
    kernel! = histogram_kernel!(backend, (groupsize,))
    kernel!(histogram_output, input, ndrange = size(input))
    return
end
@@ -74,9 +71,10 @@ function move(backend, input)
7471end
7572
7673@testset " histogram tests" begin
77- rand_input = [rand (1 : 128 ) for i in 1 : 1000 ]
78- linear_input = [i for i in 1 : 1024 ]
79- all_two = [2 for i in 1 : 512 ]
74+ # Use Int32 as some backends don't support 64-bit atomics
75+ rand_input = Int32 .(rand (1 : 128 , 1000 ))
76+ linear_input = Int32 .(1 : 1024 )
77+ all_two = fill (Int32 (2 ), 512 )
8078
8179 histogram_rand_baseline = create_histogram (rand_input)
8280 histogram_linear_baseline = create_histogram (linear_input)
8684 linear_input = move (backend, linear_input)
8785 all_two = move (backend, all_two)
8886
89- rand_histogram = KernelAbstractions. zeros (backend, Int, 128 )
90- linear_histogram = KernelAbstractions. zeros (backend, Int, 1024 )
91- two_histogram = KernelAbstractions. zeros (backend, Int, 2 )
87+ rand_histogram = KernelAbstractions. zeros (backend, eltype (rand_input), Int ( maximum (rand_input)) )
88+ linear_histogram = KernelAbstractions. zeros (backend, eltype (linear_input), Int ( maximum (linear_input)) )
89+ two_histogram = KernelAbstractions. zeros (backend, eltype (all_two), Int ( maximum (all_two)) )
9290
93- histogram! (rand_histogram, rand_input)
91+ histogram! (rand_histogram, rand_input, 6 )
9492 histogram! (linear_histogram, linear_input)
9593 histogram! (two_histogram, all_two)
96- KernelAbstractions. synchronize (CPU () )
94+ KernelAbstractions. synchronize (backend )
9795
9896 @test isapprox (Array (rand_histogram), histogram_rand_baseline)
9997 @test isapprox (Array (linear_histogram), histogram_linear_baseline)
0 commit comments