Skip to content

Commit ed192a4

Browse files
committed
More clean-up and reorganization.
1 parent 7437888 commit ed192a4

15 files changed

+160
-213
lines changed

src/GPUArrays.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
module GPUArrays
22

3-
export GPUArray, gpu_call, thread_blocks_heuristic, global_size, synchronize_threads
4-
export linear_index, @linearidx, @cartesianidx, convolution!, device, synchronize
5-
export JLArray
6-
73
using Serialization
84
using Random
95
using LinearAlgebra
@@ -17,18 +13,18 @@ using AbstractFFTs
1713
using Adapt
1814

1915
# GPU interface
16+
## core definition
2017
include("abstractarray.jl")
21-
include("abstract_gpu_interface.jl")
18+
include("devices.jl")
19+
include("execution.jl")
2220
include("ondevice.jl")
23-
include("base.jl")
2421
include("construction.jl")
25-
include("blas.jl")
26-
include("broadcast.jl")
27-
include("devices.jl")
28-
include("heuristics.jl")
22+
## integrations and specialized functionality
23+
include("base.jl")
2924
include("indexing.jl")
30-
include("linalg.jl")
25+
include("broadcast.jl")
3126
include("mapreduce.jl")
27+
include("linalg.jl")
3228
include("random.jl")
3329

3430
# CPU implementation

src/abstractarray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# core definition of the GPUArray type
22

3+
export GPUArray
4+
35
abstract type GPUArray{T, N} <: DenseArray{T, N} end
46

57
# Sampler type that acts like a texture/image and allows interpolated access

src/array.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
# Very simple Julia back-end which is just for testing the implementation and can be used as
2-
# a reference implementation
1+
# CPU implementation of the GPUArray interface
2+
3+
export JLArray
34

45
struct JLArray{T, N} <: GPUArray{T, N}
56
data::Array{T, N}

src/base.jl

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# common Base functionality
2+
13
allequal(x) = true
24
allequal(x, y, z...) = x == y && allequal(y, z...)
35
function Base.map!(f, y::GPUArray, xs::GPUArray...)
@@ -18,45 +20,6 @@ Base.map!(f, y::GPUArray, x1::GPUArray, x2::GPUArray) =
1820
invoke(map!, Tuple{Any,GPUArray, Vararg{GPUArray}}, f, y, x1, x2)
1921

2022

21-
# TODO find out why this segfaults julia without stack trace on AMD
22-
# produces wrong results on Titan X and passes on GTX 950..........
23-
24-
# @generated function nindex(i::T, ls::NTuple{N}) where {T, N}
25-
# quote
26-
# Base.@_inline_meta
27-
# $(foldr(:($T(0), $T(0)), T(1):T(N)) do n, els
28-
# :(i ≤ ls[$n] ? ($T($n), i) : (i -= $T(ls[$n]); $els))
29-
# end)
30-
# end
31-
# end
32-
# function catindex(dim, I::NTuple{N, T}, shapes) where {T, N}
33-
# xi = nindex(I[dim], map(s-> s[dim], shapes))
34-
# x = xi[1]; i = xi[2]
35-
# x, ntuple(n -> n == dim ? i : I[n], Val{N})
36-
# end
37-
#
38-
# function _cat(dim, dest, xs...)
39-
# gpu_call(dest, (Int(dim), dest, xs)) do state, dim, dest, xs
40-
# I = @cartesianidx dest state
41-
# nI = catindex(dim, I, size.(xs))
42-
# n = nI[1]; I′ = nI[2]
43-
# @inbounds dest[I...] = xs[n][I′...]
44-
# return
45-
# end
46-
# return dest
47-
# end
48-
#
49-
# function Base.cat_t(dims::Integer, T::Type, x::GPUArray, xs::GPUArray...)
50-
# catdims = Base.dims2cat(dims)
51-
# shape = Base.cat_shape(catdims, (), size.((x, xs...))...)
52-
# dest = Base.cat_similar(x, T, shape)
53-
# _cat(dims, dest, x, xs...)
54-
# end
55-
#
56-
# Base.vcat(xs::GPUArray...) = cat(1, xs...)
57-
# Base.hcat(xs::GPUArray...) = cat(2, xs...)
58-
59-
6023
# Base functions that are sadly not fit for the the GPU yet (they only work for Int64)
6124
Base.@pure @inline function gpu_ind2sub(A::AbstractArray, ind::T) where T
6225
_ind2sub(size(A), ind - T(1))

src/blas.jl

Lines changed: 0 additions & 116 deletions
This file was deleted.

src/broadcast.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# broadcasting operations
2+
13
using Base.Broadcast
24

35
import Base.Broadcast: BroadcastStyle, Broadcasted, ArrayStyle

src/construction.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# constructors and conversions
2+
13
function Base.fill(X::Type{<: GPUArray}, val::T, dims::NTuple{N, Integer}) where {T, N}
24
res = similar(X{T}, dims)
35
fill!(res, val)

src/devices.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
# device properties
1+
# device management and properties
2+
3+
"""
4+
device(A::AbstractArray)
5+
6+
Gets the device associated to the Array `A`
7+
"""
8+
device(A::AbstractArray) = error("Not implemented") # COV_EXCL_LINE
29

310
"""
411
Hardware threads of device

src/abstract_gpu_interface.jl renamed to src/execution.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1+
# kernel execution
2+
3+
export gpu_call, synchronize, thread_blocks_heuristic
4+
15
abstract type GPUBackend end
6+
27
backend(::Type{T}) where T = error("Can't choose GPU backend for $T")
38

49
"""
@@ -46,17 +51,6 @@ end
4651
# Internal GPU call function, that needs to be overloaded by the backends.
4752
_gpu_call(::Any, f, A, args, thread_blocks) = error("Not implemented") # COV_EXCL_LINE
4853

49-
50-
"""
51-
device(A::AbstractArray)
52-
53-
Gets the device associated to the Array `A`
54-
"""
55-
function device(A::AbstractArray)
56-
# fallback is a noop, for backends not needing synchronization. This
57-
# makes it easier to write generic code that also works for AbstractArrays
58-
end
59-
6054
"""
6155
synchronize(A::AbstractArray)
6256
@@ -66,3 +60,10 @@ function synchronize(A::AbstractArray)
6660
# fallback is a noop, for backends not needing synchronization. This
6761
# makes it easier to write generic code that also works for AbstractArrays
6862
end
63+
64+
function thread_blocks_heuristic(len::Integer)
65+
# TODO better threads default
66+
threads = clamp(len, 1, 256)
67+
blocks = max(ceil(Int, len / threads), 1)
68+
(blocks,), (threads,)
69+
end

src/heuristics.jl

Lines changed: 0 additions & 6 deletions
This file was deleted.

0 commit comments

Comments
 (0)