Skip to content

Commit 684b507

Browse files
committed
Split in device and host folders.
1 parent ed192a4 commit 684b507

15 files changed

+46
-163
lines changed

src/GPUArrays.jl

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,27 @@ using AbstractFFTs
1212

1313
using Adapt
1414

15-
# GPU interface
16-
## core definition
17-
include("abstractarray.jl")
18-
include("devices.jl")
19-
include("execution.jl")
20-
include("ondevice.jl")
21-
include("construction.jl")
15+
# device array
16+
include("device/abstractarray.jl")
17+
include("device/indexing.jl")
18+
include("device/gpu.jl")
19+
20+
# host array
21+
include("host/abstractarray.jl")
22+
include("host/devices.jl")
23+
include("host/execution.jl")
24+
include("host/construction.jl")
2225
## integrations and specialized functionality
23-
include("base.jl")
24-
include("indexing.jl")
25-
include("broadcast.jl")
26-
include("mapreduce.jl")
27-
include("linalg.jl")
28-
include("random.jl")
29-
30-
# CPU implementation
26+
include("host/base.jl")
27+
include("host/indexing.jl")
28+
include("host/broadcast.jl")
29+
include("host/mapreduce.jl")
30+
include("host/linalg.jl")
31+
include("host/random.jl")
32+
include("host/quirks.jl")
33+
34+
# CPU reference implementation
3135
include("array.jl")
3236

33-
include("quirks.jl")
3437

3538
end # module

src/array.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# CPU implementation of the GPUArray interface
1+
# reference implementation of the GPUArray interfaces
22

33
export JLArray
44

@@ -88,8 +88,6 @@ to_device(state, x::JLArray) = x.data
8888
to_device(state, x::Tuple) = to_device.(Ref(state), x)
8989
to_device(state, x::Base.RefValue{<: JLArray}) = Base.RefValue(to_device(state, x[]))
9090
to_device(state, x) = x
91-
# creates a `local` vector for each thread group
92-
to_device(state, x::LocalMemory{T}) where T = LocalMem(ntuple(i-> Vector{T}(x.size), blockdim_x(state)))
9391

9492
to_blocks(state, x) = x
9593
# unpacks local memory for each block

src/device/abstractarray.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# on-device functionality
2+
3+
export AbstractDeviceArray
4+
5+
6+
## device array
7+
8+
abstract type AbstractDeviceArray{T, N} <: AbstractArray{T, N} end
9+
10+
Base.IndexStyle(::AbstractDeviceArray) = IndexLinear()
11+
12+
@inline function Base.iterate(A::AbstractDeviceArray, i=1)
13+
if (i % UInt) - 1 < length(A)
14+
(@inbounds A[i], i + 1)
15+
else
16+
nothing
17+
end
18+
end
19+
20+
function Base.sum(A::AbstractDeviceArray{T}) where T
21+
acc = zero(T)
22+
for elem in A
23+
acc += elem
24+
end
25+
acc
26+
end

src/abstractarray.jl renamed to src/host/abstractarray.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,6 @@ const GPUVector{T} = GPUArray{T, 1}
1111
const GPUMatrix{T} = GPUArray{T, 2}
1212
const GPUVecOrMat{T} = Union{GPUArray{T, 1}, GPUArray{T, 2}}
1313

14-
# GPU Local Memory
15-
struct LocalMemory{T} <: GPUArray{T, 1}
16-
size::Int
17-
LocalMemory{T}(x::Integer) where T = new{T}(x)
18-
end
19-
20-
2114
# input/output
2215

2316
## serialization
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)