Skip to content

Commit 2084a80

Browse files
committed
Remove GPUArrays dependency
1 parent 0b99fbf commit 2084a80

File tree

7 files changed

+35
-21
lines changed

7 files changed

+35
-21
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ version = "0.3.3"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
8-
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
8+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
99
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1010
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1111
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
@@ -21,7 +21,7 @@ AcceleratedKernelsoneAPIExt = "oneAPI"
2121

2222
[compat]
2323
ArgCheck = "2"
24-
GPUArrays = "10, 11"
24+
GPUArraysCore = "0.2.0"
2525
KernelAbstractions = "0.9.34"
2626
Markdown = "1"
2727
Metal = "1"

src/AcceleratedKernels.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ module AcceleratedKernels
1212

1313
# Internal dependencies
1414
using ArgCheck: @argcheck
15-
using GPUArrays: GPUArrays, AbstractGPUVector, AbstractGPUArray, @allowscalar
15+
using GPUArraysCore: GPUArrays, AbstractGPUVector, AbstractGPUArray, @allowscalar
1616
using KernelAbstractions
1717
using Polyester: @batch
1818
import OhMyThreads as OMT
@@ -21,7 +21,6 @@ import OhMyThreads as OMT
2121
# Exposed functions from upstream packages
2222
const synchronize = KernelAbstractions.synchronize
2323
const get_backend = KernelAbstractions.get_backend
24-
const neutral_element = GPUArrays.neutral_element
2524

2625

2726
# Include code from other files

src/accumulate/accumulate.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ include("accumulate_cpu.jl")
3131
accumulate!(
3232
op, v::AbstractArray, backend::Backend=get_backend(v);
3333
init,
34-
neutral=GPUArrays.neutral_element(op, eltype(v)),
34+
neutral=neutral_element(op, eltype(v)),
3535
dims::Union{Nothing, Int}=nothing,
3636
inclusive::Bool=true,
3737
@@ -47,7 +47,7 @@ include("accumulate_cpu.jl")
4747
accumulate!(
4848
op, dst::AbstractArray, src::AbstractArray, backend::Backend=get_backend(v);
4949
init,
50-
neutral=GPUArrays.neutral_element(op, eltype(dst)),
50+
neutral=neutral_element(op, eltype(dst)),
5151
dims::Union{Nothing, Int}=nothing,
5252
inclusive::Bool=true,
5353
@@ -117,7 +117,7 @@ AK.accumulate!(+, v, alg=AK.ScanPrefixes())
117117
function accumulate!(
118118
op, v::AbstractArray, backend::Backend=get_backend(v);
119119
init,
120-
neutral=GPUArrays.neutral_element(op, eltype(v)),
120+
neutral=neutral_element(op, eltype(v)),
121121
dims::Union{Nothing, Int}=nothing,
122122
inclusive::Bool=true,
123123

@@ -141,7 +141,7 @@ end
141141
function accumulate!(
142142
op, dst::AbstractArray, src::AbstractArray, backend::Backend=get_backend(v);
143143
init,
144-
neutral=GPUArrays.neutral_element(op, eltype(dst)),
144+
neutral=neutral_element(op, eltype(dst)),
145145
dims::Union{Nothing, Int}=nothing,
146146
inclusive::Bool=true,
147147

@@ -166,7 +166,7 @@ end
166166
function _accumulate_impl!(
167167
op, v::AbstractArray, backend::Backend;
168168
init,
169-
neutral=GPUArrays.neutral_element(op, eltype(v)),
169+
neutral=neutral_element(op, eltype(v)),
170170
dims::Union{Nothing, Int}=nothing,
171171
inclusive::Bool=true,
172172

@@ -211,7 +211,7 @@ end
211211
accumulate(
212212
op, v::AbstractArray, backend::Backend=get_backend(v);
213213
init,
214-
neutral=GPUArrays.neutral_element(op, eltype(v)),
214+
neutral=neutral_element(op, eltype(v)),
215215
dims::Union{Nothing, Int}=nothing,
216216
inclusive::Bool=true,
217217
@@ -229,7 +229,7 @@ Out-of-place version of [`accumulate!`](@ref).
229229
function accumulate(
230230
op, v::AbstractArray, backend::Backend=get_backend(v);
231231
init,
232-
neutral=GPUArrays.neutral_element(op, eltype(v)),
232+
neutral=neutral_element(op, eltype(v)),
233233
dims::Union{Nothing, Int}=nothing,
234234
inclusive::Bool=true,
235235

src/accumulate/accumulate_nd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ end
254254
function accumulate_nd!(
255255
op, v::AbstractArray, backend::GPU;
256256
init,
257-
neutral=GPUArrays.neutral_element(op, eltype(v)),
257+
neutral=neutral_element(op, eltype(v)),
258258
dims::Int,
259259
inclusive::Bool=true,
260260

src/reduce/mapreduce_1d.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ end
102102
function mapreduce_1d(
103103
f, op, src::AbstractArray, backend::GPU;
104104
init,
105-
neutral=GPUArrays.neutral_element(op, eltype(src)),
105+
neutral=neutral_element(op, eltype(src)),
106106

107107
block_size::Int=256,
108108
temp::Union{Nothing, AbstractArray}=nothing,

src/reduce/mapreduce_nd.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ end
188188
function mapreduce_nd(
189189
f, op, src::AbstractArray, backend::GPU;
190190
init,
191-
neutral=GPUArrays.neutral_element(op, eltype(src)),
191+
neutral=neutral_element(op, eltype(src)),
192192
dims::Int,
193193
block_size::Int=256,
194194
temp::Union{Nothing, AbstractArray}=nothing,
@@ -320,7 +320,7 @@ end
320320
function _mapreduce_nd_apply_init!(f, op, dst, src, backend, init, block_size)
321321
foreachindex(
322322
dst, backend,
323-
block_size=block_size,
323+
block_size=block_size,
324324
) do i
325325
dst[i] = op(init, f(src[i]))
326326
end

src/reduce/reduce.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# neutral_element moved over from GPUArrays.jl
2+
neutral_element(op, T) =
3+
error("""AcceleratedKernels.jl needs to know the neutral element for your operator `$op`.
4+
Please pass it as an explicit keyword argument `neutral`.""")
5+
neutral_element(::typeof(Base.:(|)), T) = zero(T)
6+
neutral_element(::typeof(Base.:(+)), T) = zero(T)
7+
neutral_element(::typeof(Base.add_sum), T) = zero(T)
8+
neutral_element(::typeof(Base.:(&)), T) = one(T)
9+
neutral_element(::typeof(Base.:(*)), T) = one(T)
10+
neutral_element(::typeof(Base.mul_prod), T) = one(T)
11+
neutral_element(::typeof(Base.min), T) = typemax(T)
12+
neutral_element(::typeof(Base.max), T) = typemin(T)
13+
neutral_element(::typeof(Base._extrema_rf), ::Type{<:NTuple{2,T}}) where {T} = typemax(T), typemin(T)
14+
15+
116
include("mapreduce_1d.jl")
217
include("mapreduce_nd.jl")
318

@@ -6,7 +21,7 @@ include("mapreduce_nd.jl")
621
reduce(
722
op, src::AbstractArray, backend::Backend=get_backend(src);
823
init,
9-
neutral=GPUArrays.neutral_element(op, eltype(src)),
24+
neutral=neutral_element(op, eltype(src)),
1025
dims::Union{Nothing, Int}=nothing,
1126
1227
# CPU settings
@@ -72,7 +87,7 @@ mcolsum = AK.reduce(+, m; init=zero(eltype(m)), dims=2)
7287
function reduce(
7388
op, src::AbstractArray, backend::Backend=get_backend(src);
7489
init,
75-
neutral=GPUArrays.neutral_element(op, eltype(src)),
90+
neutral=neutral_element(op, eltype(src)),
7691
dims::Union{Nothing, Int}=nothing,
7792

7893
# CPU settings
@@ -103,7 +118,7 @@ end
103118
function _reduce_impl(
104119
op, src::AbstractArray, backend;
105120
init,
106-
neutral=GPUArrays.neutral_element(op, eltype(src)),
121+
neutral=neutral_element(op, eltype(src)),
107122
dims::Union{Nothing, Int}=nothing,
108123

109124
# CPU settings
@@ -137,7 +152,7 @@ end
137152
mapreduce(
138153
f, op, src::AbstractArray, backend::Backend=get_backend(src);
139154
init,
140-
neutral=GPUArrays.neutral_element(op, eltype(src)),
155+
neutral=neutral_element(op, eltype(src)),
141156
dims::Union{Nothing, Int}=nothing,
142157
143158
# CPU settings
@@ -203,7 +218,7 @@ mcolsumsq = AK.mapreduce(f, +, m; init=zero(eltype(m)), dims=2)
203218
function mapreduce(
204219
f, op, src::AbstractArray, backend::Backend=get_backend(src);
205220
init,
206-
neutral=GPUArrays.neutral_element(op, eltype(src)),
221+
neutral=neutral_element(op, eltype(src)),
207222
dims::Union{Nothing, Int}=nothing,
208223

209224
# CPU settings
@@ -234,7 +249,7 @@ end
234249
function _mapreduce_impl(
235250
f, op, src::AbstractArray, backend::Backend;
236251
init,
237-
neutral=GPUArrays.neutral_element(op, eltype(src)),
252+
neutral=neutral_element(op, eltype(src)),
238253
dims::Union{Nothing, Int}=nothing,
239254

240255
# CPU settings

0 commit comments

Comments
 (0)