Commit 4b05834

fix mapreduce, make jlbackend async
1 parent 72a0637 commit 4b05834

16 files changed: +296 -110 lines changed

src/GPUArrays.jl

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ include("vectors.jl")
 include("testsuite/testsuite.jl")
 include("jlbackend.jl")
 
-export GPUArray, gpu_call, thread_blocks_heuristic
+export GPUArray, gpu_call, thread_blocks_heuristic, global_size
 export linear_index, @linearidx, @cartesianidx
 
 end # module

src/abstract_gpu_interface.jl

Lines changed: 15 additions & 9 deletions
@@ -2,11 +2,14 @@
 Abstraction over the GPU thread indexing functions.
 Uses CUDA like names
 =#
-for f in (:blockidx, :blockdim, :threadidx), sym in (:x, :y, :z)
-    fname = Symbol(string(f, '_', sym))
-    @eval $fname(state)::Cuint = error("Not implemented")
-    @eval export $fname
+for sym in (:x, :y, :z)
+    for f in (:blockidx, :blockdim, :threadidx, :griddim)
+        fname = Symbol(string(f, '_', sym))
+        @eval $fname(state)::Cuint = error("Not implemented")
+        @eval export $fname
+    end
 end
+
 """
 in CUDA terms `__synchronize`
 """
@@ -15,11 +18,14 @@ function synchronize_threads(state)
 end
 
 """
-linear index in a GPU kernel
+linear index in a GPU kernel (equal to OpenCL.get_global_id)
 """
 @inline function linear_index(state)
     Cuint((blockidx_x(state) - Cuint(1)) * blockdim_x(state) + threadidx_x(state))
 end
+@inline function global_size(state)
+    griddim_x(state) * blockdim_x(state)
+end
 
 """
 Blocks until all operations are finished on `A`
@@ -36,10 +42,10 @@ function device(A::GPUArray)
     # makes it easier to write generic code that also works for AbstractArrays
 end
 
-
-@inline function synchronize_threads(state)
-    CUDAnative.__syncthreads()
-end
+#
+# @inline function synchronize_threads(state)
+#     CUDAnative.__syncthreads()
+# end
 
 macro linearidx(A, statesym = :state)
     quote

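For orientation, here is a minimal sketch of how `linear_index` and the new `global_size` are typically combined in a grid-stride loop. The kernel, `dst`, `src`, and `n` are hypothetical names for illustration, not part of this commit:

    # copies `src` into `dst`, assuming fewer launched threads than elements
    function copy_kernel(state, dst, src, n)
        i = linear_index(state)        # 1-based global thread index
        stride = global_size(state)    # total number of launched threads
        while i <= n
            @inbounds dst[i] = src[i]
            i += stride
        end
        return
    end

    # launched via gpu_call, which picks blocks/threads with thread_blocks_heuristic:
    # gpu_call(copy_kernel, dst, (dst, src, Cuint(length(dst))))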
src/abstractarray.jl

Lines changed: 6 additions & 0 deletions
@@ -234,3 +234,9 @@ function _reshape(A::GPUArray{T}, dims::Dims) where T
     prod(dims) == n || throw(DimensionMismatch("parent has $n elements, which is incompatible with size $dims"))
     return unsafe_reinterpret(T, A, dims)
 end
+#ambig
+function _reshape(A::GPUArray{T, 1}, dims::Tuple{Int}) where T
+    n = Base._length(A)
+    prod(dims) == n || throw(DimensionMismatch("parent has $n elements, which is incompatible with size $dims"))
+    return unsafe_reinterpret(T, A, dims)
+end

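The extra method only resolves a dispatch ambiguity with Base's `_reshape(::AbstractVector, ::Dims{1})`; its body is identical to the general method above. A hypothetical call that hits the ambiguous signature pair (assuming a 1D backend array `a`):

    a = JLArray(rand(Float32, 16))
    b = reshape(a, (16,))   # 1D array + 1-tuple of Int: now dispatches unambiguously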
src/base.jl

Lines changed: 1 addition & 3 deletions
@@ -1,8 +1,6 @@
 import Base: count, map!, permutedims!, cat_t, vcat, hcat
 using Base: @pure
 
-count(pred, A::GPUArray) = Int(mapreduce(pred, +, Cuint(0), A))
-
 allequal(x) = true
 allequal(x, y, z...) = x == y && allequal(y, z...)
 function map!(f, y::GPUArray, xs::GPUArray...)
@@ -74,7 +72,7 @@ end
     (ind-l*indnext+f, _ind2sub(Base.tail(inds), indnext)...)
 end
 
-@pure function gpu_sub2ind{N, T}(dims::NTuple{N}, I::NTuple{N, T})
+@pure function gpu_sub2ind{N, N2, T}(dims::NTuple{N}, I::NTuple{N2, T})
     Base.@_inline_meta
     _sub2ind(NTuple{N, T}(dims), T(1), T(1), I...)
 end

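As a reminder of what `gpu_sub2ind` computes (column-major linearisation, now also accepting an index tuple whose length differs from `dims`), a small worked example; the values are illustrative and not part of the commit:

    # dims = (3, 4), I = (2, 3): linear index = 2 + (3 - 1) * 3 = 8
    gpu_sub2ind((Cuint(3), Cuint(4)), (Cuint(2), Cuint(3)))   # == 8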
src/blas.jl

Lines changed: 3 additions & 2 deletions
@@ -39,6 +39,7 @@ for elty in (Float64, Float32)
     end
 end
 
+Base.scale!(s::Real, X::GPUArray) = scale!(X, s)
 function Base.scale!(X::GPUArray{T}, s::Real) where T <: BLAS.BlasComplex
     R = typeof(real(zero(T)))
     buff = reinterpret(R, vec(X))
@@ -81,8 +82,8 @@ for elty in (Float32, Float64, Complex64, Complex128)
         if length(x) != length(y)
             throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
         end
-        blasmod = blas_module(A)
-        blasmod.axpy!($elty(alpha), blasbuffer(dx), blasbuffer(dx))
+        blasmod = blas_module(x)
+        blasmod.axpy!($elty(alpha), blasbuffer(vec(x)), blasbuffer(vec(y)))
         y
     end
 end

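The axpy! fix matters because BLAS `axpy!` updates `y` in place as `y = alpha*x + y`; the old body referenced `A` and `dx` (names not in scope here) and passed the same buffer twice, so `y` was never written. A hedged usage sketch, assuming the generated method is the usual `Base.LinAlg.axpy!` for these element types and that the array names are hypothetical:

    x = JLArray(rand(Float32, 8))
    y = JLArray(rand(Float32, 8))
    Base.LinAlg.axpy!(2f0, x, y)   # y now holds 2*x + old y
    Base.scale!(3f0, y)            # new method: forwards to scale!(y, 3f0)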
src/broadcast.jl

Lines changed: 20 additions & 1 deletion
@@ -27,6 +27,19 @@ function broadcast!(f::typeof(identity), A::GPUArray, val::Number)
     gpu_call(const_kernel2, A, (A, valconv, Cuint(length(A))))
     A
 end
+@inline function broadcast_t(f, T::Type{Bool}, shape, it, A::GPUArrays.GPUArray, Bs::Vararg{Any,N}) where N
+    C = similar(A, T, shape)
+    keeps, Idefaults = map_newindexer(shape, A, Bs)
+    _broadcast!(f, C, keeps, Idefaults, A, Bs, Val{N}, it)
+    return C
+end
+@inline function broadcast_t(f, T::Type{Bool}, shape, it, A::GPUArrays.GPUArray, B::GPUArrays.GPUArray, Bs::Vararg{Any,N}) where N
+    C = similar(A, T, shape)
+    Bs = (B, Bs...)
+    keeps, Idefaults = map_newindexer(shape, A, Bs)
+    _broadcast!(f, C, keeps, Idefaults, A, Bs, Val{N}, it)
+    return C
+end
 
 @inline function broadcast_t(
     f, ::Type{T}, shape, iter, A::GPUArray, Bs::Vararg{Any,N}
@@ -195,7 +208,13 @@ end
 @pure newindex(I, ilin, keep::Tuple{}, Idefault::Tuple{}, size::Tuple{}) = Cuint(1)
 
 # optimize for 1D arrays
-@pure newindex(I::NTuple{1}, ilin, keep::NTuple{1}, Idefault, size) = ilin
+@pure function newindex(I::NTuple{1}, ilin, keep::NTuple{1}, Idefault, size)
+    if Bool(keep[1])
+        return ilin
+    else
+        return Idefault[1]
+    end
+end
 
 # differently shaped arrays
 @generated function newindex{N, T}(I, ilin::T, keep::NTuple{N}, Idefault, size)

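The `newindex` change fixes 1D broadcasts where one argument has to be "repeated": when `keep[1]` is false (that argument has extent 1 along the dimension), its index must collapse to `Idefault[1]` instead of following the output's linear index. A hypothetical case that exercises it (array names are illustrative only):

    a = JLArray(rand(Float32, 8))
    b = JLArray(rand(Float32, 1))
    a .+ b    # b's index must stay at 1 for every output element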
src/construction.jl

Lines changed: 1 addition & 5 deletions
@@ -1,4 +1,4 @@
-import Base: fill!, rand, similar, eye, zeros, fill
+import Base: fill!, similar, eye, zeros, fill
 
 
 function fill(X::Type{<: GPUArray}, val, dims::Integer...)
@@ -32,10 +32,6 @@ function eye(T::Type{<: GPUArray}, dims::NTuple{2, Integer})
     res
 end
 
-function rand{T <: GPUArray, ET}(::Type{T}, ::Type{ET}, size...)
-    T(rand(ET, size...))
-end
-
 (T::Type{<: GPUArray})(dims::Integer...) = T(dims)
 (T::Type{<: GPUArray{X} where X})(dims::NTuple{N, Integer}) where N = similar(T, eltype(T), dims)
 

src/jlbackend.jl

Lines changed: 45 additions & 17 deletions
@@ -7,12 +7,25 @@ struct JLArray{T, N} <: GPUArray{T, N}
     size::NTuple{N, Int}
 end
 
+"""
+Thread group local memory
+"""
+immutable LocalMem{N, T}
+    x::NTuple{N, Vector{T}}
+end
+
 size(x::JLArray) = x.size
 pointer(x::JLArray) = pointer(x.data)
-to_device(x::JLArray) = x.data
-to_device(x::Tuple) = to_device.(x)
-to_device(x::RefValue{<: JLArray}) = RefValue(to_device(x[]))
-to_device(x) = x
+to_device(state, x::JLArray) = x.data
+to_device(state, x::Tuple) = to_device.(state, x)
+to_device(state, x::RefValue{<: JLArray}) = RefValue(to_device(state, x[]))
+to_device(state, x) = x
+# creates a `local` vector for each thread group
+to_device(state, x::LocalMemory{T}) where T = LocalMem(ntuple(i-> Vector{T}(x.size), blockdim_x(state)))
+
+to_blocks(state, x) = x
+# unpacks local memory for each block
+to_blocks(state, x::LocalMem) = x.x[blockidx_x(state)]
 
 function (::Type{JLArray{T, N}})(size::NTuple{N, Integer}) where {T, N}
     JLArray{T, N}(Array{T, N}(size), size)
@@ -47,31 +60,35 @@ end
 
 mutable struct JLState{N}
     blockdim::NTuple{N, Int}
-    threads::NTuple{N, Int}
+    griddim::NTuple{N, Int}
 
     blockidx::NTuple{N, Int}
     threadidx::NTuple{N, Int}
 end
 
-
 function gpu_call(f, A::JLArray, args::Tuple, blocks = nothing, threads = C_NULL)
     if blocks == nothing
         blocks, threads = thread_blocks_heuristic(length(A))
     elseif isa(blocks, Integer)
         blocks = (blocks,)
-        if threads == C_NULL
-            threads = (1,)
-        end
+    end
+    if threads == C_NULL
+        threads = (1,)
     end
     idx = ntuple(i-> 1, length(blocks))
     blockdim = ceil.(Int, blocks ./ threads)
-    state = JLState(threads, threads, idx, idx)
-    device_args = to_device.(args)
+    state = JLState(threads, blockdim, idx, idx)
+    device_args = to_device.(state, args)
+    tasks = Vector{Task}(threads...)
     for blockidx in CartesianRange(blockdim)
        state.blockidx = blockidx.I
+        block_args = to_blocks.(state, device_args)
        for threadidx in CartesianRange(threads)
-            state.threadidx = threadidx.I
-            f(state, device_args...)
+            thread_state = JLState(state.blockdim, state.griddim, state.blockidx, threadidx.I)
+            tasks[threadidx] = @async f(thread_state, block_args...)
+        end
+        for t in tasks
+            wait(t)
        end
    end
    return
@@ -83,11 +100,22 @@ device(x::JLArray) = JLDevice()
 threads(dev::JLDevice) = 256
 
 
-@inline synchronize_threads(::JLState) = nothing
+@inline function synchronize_threads(::JLState)
+    #=
+    All threads are getting started asynchronously, so a yield will
+    yield to the next execution of the same function, which should call yield
+    at the exact same point in the program, leading to a chain of yields effectively syncing
+    the tasks (threads).
+    =#
+    yield()
+    return
+end
 
-for f in (:blockidx, :blockdim, :threadidx), (i, sym) in enumerate((:x, :y, :z))
-    fname = Symbol(string(f, '_', sym))
-    @eval $fname(state::JLState) = Cuint(state.$f[$i])
+for (i, sym) in enumerate((:x, :y, :z))
+    for f in (:blockidx, :blockdim, :threadidx, :griddim)
+        fname = Symbol(string(f, '_', sym))
+        @eval $fname(state::JLState) = Cuint(state.$f[$i])
+    end
 end
 
 blas_module(::JLArray) = Base.LinAlg.BLAS

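The yield-based `synchronize_threads` relies on the cooperative, roughly round-robin scheduling of `@async` tasks: each task runs until it reaches the same `yield()`, so every write before the barrier happens before any task resumes after it. A standalone sketch of that idea, independent of GPUArrays (all names hypothetical) and assuming the scheduler's FIFO behaviour described in the comment above:

    results = zeros(Int, 4)
    tasks = map(1:4) do i
        @async begin
            results[i] = i               # phase 1: every task writes its slot
            yield()                      # cooperative "barrier"
            @assert all(results .!= 0)   # phase 2: observes all phase-1 writes
        end
    end
    foreach(wait, tasks)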
src/linalg.jl

Lines changed: 39 additions & 39 deletions
@@ -1,48 +1,48 @@
-# function transpose_kernel!(
-#         state, At, A, width, height, A_local, ::Val{BLOCK}
-#     ) where BLOCK
-#
-#     ui1 = UInt32(1)
-#     bidx_x = blockidx_x(state) - ui1
-#     bidx_y = blockidx_y(state) - ui1
-#     tidx_x = threadidx_x(state) - ui1
-#     tidx_y = threadidx_y(state) - ui1
-#
-#     base_idx_a = bidx_x * BLOCK + bidx_y * (BLOCK * width)
-#     base_idx_a_t = bidx_y * BLOCK + bidx_x * (BLOCK * height)
-#
-#     glob_idx_a = base_idx_a + tidx_x + width * tidx_y
-#     glob_idx_a_t = base_idx_a_t + tidx_x + height * tidx_y
-#
-#     A_local[tidx_y * BLOCK + tidx_x + ui1] = A[glob_idx_a + ui1]
-#
-#     cli.barrier(cli.CLK_LOCAL_MEM_FENCE)
-#     At[glob_idx_a_t + ui1] = A_local[tidx_x * BLOCK + tidx_y + ui1]
-#     return
-# end
-#
-# function max_block_size(dev, h::Int, w::Int)
-#     dim1, dim2 = GPUArrays.blocks(dev)[1:2]
-#     wgsize = GPUArrays.threads(dev)
-#     wglimit = floor(Int, sqrt(wgsize))
-#     return gcd(dim1, dim2, h, w, wglimit)
-# end
-#
-# function Base.transpose!{T}(At::GPUArray{T, 2}, A::GPUArray{T, 2})
-#     dev = GPUArrays.device(A)
-#     block_size = max_block_size(dev, size(A)...)
-#     outsize = UInt32.(size(At))
-#     lmem = GPUArrays.LocalMemory{T}(block_size * (block_size + 1))
-#     args = (At, A, outsize..., lmem, Val{block_size}())
-#     gpu_call(transpose_kernel!, At, args, (block_size, block_size))
-#     At
-# end
+function transpose_kernel!(
+        state, At, A, width, height, A_local, ::Val{BLOCK}
+    ) where BLOCK
+
+    ui1 = UInt32(1)
+    bidx_x = blockidx_x(state) - ui1
+    bidx_y = blockidx_y(state) - ui1
+    tidx_x = threadidx_x(state) - ui1
+    tidx_y = threadidx_y(state) - ui1
+
+    base_idx_a = bidx_x * BLOCK + bidx_y * (BLOCK * width)
+    base_idx_a_t = bidx_y * BLOCK + bidx_x * (BLOCK * height)
+
+    glob_idx_a = base_idx_a + tidx_x + width * tidx_y
+    glob_idx_a_t = base_idx_a_t + tidx_x + height * tidx_y
+
+    A_local[tidx_y * BLOCK + tidx_x + ui1] = A[glob_idx_a + ui1]
+    synchronize_threads(state)
+    At[glob_idx_a_t + ui1] = A_local[tidx_x * BLOCK + tidx_y + ui1]
+    return
+end
+
+function max_block_size(dev, h::Int, w::Int)
+    dim1, dim2 = GPUArrays.blocks(dev)[1:2]
+    wgsize = GPUArrays.threads(dev)
+    wglimit = floor(Int, sqrt(wgsize))
+    return gcd(dim1, dim2, h, w, wglimit)
+end
+
+function Base.transpose!{T}(At::GPUArray{T, 2}, A::GPUArray{T, 2})
+    dev = GPUArrays.device(A)
+    block_size = max_block_size(dev, size(A)...)
+    outsize = UInt32.(size(At))
+    lmem = GPUArrays.LocalMemory{T}(block_size * (block_size + 1))
+    args = (At, A, outsize..., lmem, Val{block_size}())
+    gpu_call(transpose_kernel!, At, args, (block_size, block_size))
+    At
+end
 
 function genperm(I::NTuple{N}, perm::NTuple{N}) where N
     ntuple(d-> I[perm[d]], Val{N})
 end
 
 function Base.permutedims!(dest::GPUArray, src::GPUArray, perm)
+    perm = Cuint.((perm...,))
     gpu_call(dest, (dest, src, perm)) do state, dest, src, perm
         I = @cartesianidx dest state
         @inbounds dest[I...] = src[genperm(I, perm)...]

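`genperm` just permutes an index tuple, e.g. `genperm((2, 3, 4), (3, 1, 2)) == (4, 2, 3)`; converting `perm` to a tuple of `Cuint` up front keeps the kernel argument a plain isbits tuple. A hedged usage sketch (the array construction and sizes here are illustrative, not part of the commit):

    src = JLArray(rand(Float32, 2, 3, 4))
    dest = JLArray{Float32, 3}((4, 2, 3))
    permutedims!(dest, src, (3, 1, 2))   # perm becomes Cuint.((3, 1, 2)) internally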