diff --git a/src/accumulate/accumulate.jl b/src/accumulate/accumulate.jl
index 0aff1bf..ec9e0ee 100644
--- a/src/accumulate/accumulate.jl
+++ b/src/accumulate/accumulate.jl
@@ -175,7 +175,7 @@ function _accumulate_impl!(
     temp_flags::Union{Nothing, AbstractArray}=nothing,
 )
     if isnothing(dims)
-        return if use_KA_algo(v, prefer_threads)
+        return if use_gpu_algo(backend, prefer_threads)
             accumulate_1d_gpu!(
                 op, v, backend, alg;
                 init, neutral, inclusive,
diff --git a/src/accumulate/accumulate_nd.jl b/src/accumulate/accumulate_nd.jl
index 5e213c2..0ba51a6 100644
--- a/src/accumulate/accumulate_nd.jl
+++ b/src/accumulate/accumulate_nd.jl
@@ -35,7 +35,7 @@ function accumulate_nd!(
         # Degenerate cases
     end
 
-    if !use_KA_algo(v, prefer_threads)
+    if !use_gpu_algo(backend, prefer_threads)
         _accumulate_nd_cpu_sections!(op, v; init, dims, inclusive, max_tasks, min_elems)
     else
         # On GPUs we have two parallelisation approaches, based on which dimension has more elements:
diff --git a/src/foreachindex.jl b/src/foreachindex.jl
index 24b9d78..388609a 100644
--- a/src/foreachindex.jl
+++ b/src/foreachindex.jl
@@ -130,7 +130,7 @@ function foreachindex(
     # GPU settings
     block_size=256,
 )
-    if use_KA_algo(itr, prefer_threads)
+    if use_gpu_algo(backend, prefer_threads)
         _forindices_gpu(f, eachindex(itr), backend; block_size)
     else
         _forindices_threads(f, eachindex(itr); max_tasks, min_elems)
@@ -232,7 +232,7 @@ function foraxes(
         )
     end
 
-    if use_KA_algo(itr, prefer_threads)
+    if use_gpu_algo(backend, prefer_threads)
         _forindices_gpu(f, axes(itr, dims), backend; block_size)
     else
         _forindices_threads(f, axes(itr, dims); max_tasks, min_elems)
diff --git a/src/predicates.jl b/src/predicates.jl
index 1f0a0dd..47f1e59 100644
--- a/src/predicates.jl
+++ b/src/predicates.jl
@@ -119,7 +119,7 @@ function _any_impl(
     # GPU settings
     block_size::Int=256,
 )
-    if use_KA_algo(v, prefer_threads)
+    if use_gpu_algo(backend, prefer_threads)
         @argcheck block_size > 0
 
         # Some platforms crash when multiple threads write to the same memory location in a global
@@ -253,7 +253,7 @@ function _all_impl(
     # GPU settings
     block_size::Int=256,
 )
-    if use_KA_algo(v, prefer_threads)
+    if use_gpu_algo(backend, prefer_threads)
         @argcheck block_size > 0
 
         # Some platforms crash when multiple threads write to the same memory location in a global
diff --git a/src/reduce/mapreduce_nd.jl b/src/reduce/mapreduce_nd.jl
index cf7d825..fba6f71 100644
--- a/src/reduce/mapreduce_nd.jl
+++ b/src/reduce/mapreduce_nd.jl
@@ -114,7 +114,7 @@ function mapreduce_nd(
     end
 
     dst_size = length(dst)
-    if !use_KA_algo(src, prefer_threads)
+    if !use_gpu_algo(backend, prefer_threads)
         _mapreduce_nd_cpu_sections!(
             f, op, dst, src;
             init,
diff --git a/src/reduce/reduce.jl b/src/reduce/reduce.jl
index 0332531..18906a4 100644
--- a/src/reduce/reduce.jl
+++ b/src/reduce/reduce.jl
@@ -183,7 +183,7 @@ function _mapreduce_impl(
     switch_below::Int=0,
 )
     if isnothing(dims)
-        if use_KA_algo(src, prefer_threads)
+        if use_gpu_algo(backend, prefer_threads)
             mapreduce_1d_gpu(
                 f, op, src, backend;
                 init, neutral,
diff --git a/src/sort/sort.jl b/src/sort/sort.jl
index 8e55e3a..e81e0c1 100644
--- a/src/sort/sort.jl
+++ b/src/sort/sort.jl
@@ -96,7 +96,7 @@ function _sort_impl!(
     # Temporary buffer, same size as `v`
     temp::Union{Nothing, AbstractArray}=nothing,
 )
-    if use_KA_algo(v, prefer_threads)
+    if use_gpu_algo(backend, prefer_threads)
         merge_sort!(
             v, backend;
             lt, by, rev, order,
@@ -207,7 +207,7 @@ function _sortperm_impl!(
     # Temporary buffer, same size as `v`
     temp::Union{Nothing, AbstractArray}=nothing,
 )
-    if use_KA_algo(v, prefer_threads)
+    if use_gpu_algo(backend, prefer_threads)
         merge_sortperm_lowmem!(
             ix, v, backend;
             lt, by, rev, order,
diff --git a/src/utils.jl b/src/utils.jl
index f601b44..86b3d47 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -3,8 +3,9 @@ function ispow2(x)
 end
 
-# Helper function to check whether the package cpu implementation of an algorithm should be used
-@inline function use_KA_algo(output_array, prefer_threads)
-    return output_array isa AnyGPUArray || !prefer_threads
-end
+# Helper function to check whether the KernelAbstractions (GPU/backend) implementation of an
+# algorithm should be used, as opposed to the package's CPU-threads implementation
+const CPU_BACKEND = get_backend([])
+@inline function use_gpu_algo(backend, prefer_threads)
+    return !(backend isa typeof(CPU_BACKEND)) || !prefer_threads
+end
 
 """