Skip to content

Commit 72a0637

Browse files
committed
fixing lots of Knet tests
1 parent c71a443 commit 72a0637

File tree

4 files changed

+83
-9
lines changed

4 files changed

+83
-9
lines changed

src/blas.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,28 @@ for elty in (Float32, Float64, Complex64, Complex128)
6262
elseif trans == 'T' && (length(X) != m || length(Y) != n)
6363
throw(DimensionMismatch("A.' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
6464
end
65-
ctx = context(A)
66-
blasmod = blas_module(ctx)
65+
blasmod = blas_module(A)
6766
blasmod.gemv!(
6867
trans, alpha,
69-
blasbuffer(ctx, A), blasbuffer(ctx, X), beta, blasbuffer(ctx, Y)
68+
blasbuffer(A), blasbuffer(X), beta, blasbuffer(Y)
7069
)
7170
Y
7271
end
7372
end
7473
end
74+
75+
76+
# BLAS axpy! (y ← alpha*x + y) for GPU arrays, one method per supported
# element type.
for elty in (Float32, Float64, Complex64, Complex128)
    @eval begin
        function Base.BLAS.axpy!(
                alpha::Number, x::GPUArray{$elty}, y::GPUArray{$elty}
            )
            if length(x) != length(y)
                throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
            end
            # Bug fix: the original referenced undefined variables `A` and
            # `dx`, and passed the x-buffer twice. Look the BLAS module up
            # from `x` and hand both buffers to the backend axpy!.
            blasmod = blas_module(x)
            blasmod.axpy!($elty(alpha), blasbuffer(x), blasbuffer(y))
            y
        end
    end
end

src/jlbackend.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ function gpu_call(f, A::JLArray, args::Tuple, blocks = nothing, threads = C_NULL
7878
end
7979

8080
# "intrinsics"
81+
# Singleton device type for the pure-Julia (CPU) reference backend.
struct JLDevice end

# Every JLArray lives on the single CPU "device".
device(x::JLArray) = JLDevice()

# Thread count reported for the CPU backend (used e.g. to size
# per-thread state buffers).
function threads(dev::JLDevice)
    return 256
end
84+
8185

8286
@inline synchronize_threads(::JLState) = nothing
8387

src/mapreduce.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#############################
22
# reduce
3+
# Functions that Base implements with a direct loop need to be
# overloaded to go through mapreduce on the GPU.
# Bug fix: `any` previously reduced over the undefined variable `(u)`
# and ignored its array argument; reduce over `A` itself.
any(pred, A::GPUArray) = Bool(mapreduce(pred, |, Cint(0), A))
count(pred, A::GPUArray) = Int(mapreduce(pred, +, Cuint(0), A))
6+
37

48
# hack to get around of fetching the first element of the GPUArray
59
# as a startvalue, which is a bit complicated with the current reduce implementation
@@ -28,8 +32,7 @@ function Base.mapreduce{T, N}(f::Function, op::Function, A::GPUArray{T, N})
2832
v0 = startvalue(op, OT) # TODO do this better
2933
mapreduce(f, op, v0, A)
3034
end
31-
32-
35+
# Backend-specific accelerated mapreduce; declared without methods here
# so each backend can add its own implementation.
function acc_mapreduce end

# mapreduce over two GPU arrays plus a scalar: forward to the backend
# implementation with the trailing arguments packed into a tuple.
function Base.mapreduce(f, op, v0, A::GPUArray, B::GPUArray, C::Number)
    return acc_mapreduce(f, op, v0, A, (B, C))
end
@@ -65,7 +68,3 @@ function Base._mapreducedim!(f, op, R::GPUArray, A::GPUArray)
6568
gpu_call(mapreducedim_kernel, R, (f, op, R, A, Cuint(slice_size), Cuint.(size(A)), Cuint(dim)))
6669
return R
6770
end
68-
69-
70-
any(pred, A::GPUArray) = Bool(mapreduce(isnan, |, Cint(0), (u)))
71-
count(pred, A::GPUArray) = Int(mapreduce(pred, +, Cuint(0), A))

src/random.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using GPUArrays
2+
"""
    TausStep(z, S1, S2, S3, M)

One step of a Tausworthe linear-feedback shift-register generator, a
component of the combined "hybrid Taus" RNG (GPU Gems 3, ch. 37).

Fix: the xor operators (`⊻`) were lost from both expressions, leaving
syntactically invalid code; restored per the standard Tausworthe step.
"""
function TausStep(z::Unsigned, S1::Integer, S2::Integer, S3::Integer, M::Unsigned)
    b = (((z << S1) ⊻ z) >> S2)
    return (((z & M) << S3) ⊻ b)
end
6+
7+
# One step of a linear congruential generator: z ← A*z + C, with
# wrapping unsigned arithmetic.
function LCGStep(z::Unsigned, A::Unsigned, C::Unsigned)
    return A * z + C
end
8+
9+
# Scale a raw 32-bit random draw into [0, 1); the constant is ≈ 2^-32.
make_rand_num(::Type{Float64}, tmp) = Float64(tmp) * 2.3283064365387e-10
make_rand_num(::Type{Float32}, tmp) = Float32(tmp) * 2.3283064f-10
11+
12+
13+
"""
    next_rand(FT, state)

Advance the hybrid Taus/LCG RNG `state` (three Tausworthe components
plus one LCG component) by one step and produce a random `FT` in
[0, 1). Returns the tuple `(new_state, value)`.
"""
function next_rand(::Type{FT}, state::NTuple{4, T}) where {FT, T <: Unsigned}
    state = (
        TausStep(state[1], Cint(13), Cint(19), Cint(12), T(4294967294)),
        TausStep(state[2], Cint(2), Cint(25), Cint(4), T(4294967288)),
        TausStep(state[3], Cint(3), Cint(11), Cint(17), T(4294967280)),
        LCGStep(state[4], T(1664525), T(1013904223))
    )
    # Fix: the four component generators are combined with xor — the `⊻`
    # operators were stripped from this line in the original.
    tmp = (state[1] ⊻ state[2] ⊻ state[3] ⊻ state[4])
    return (
        state,
        make_rand_num(FT, tmp)
    )
end
26+
27+
# Draw one random number of type `T` for the calling GPU thread, using
# and updating that thread's slot in the shared `randstate` vector.
function gpu_rand(::Type{T}, state, randstate::AbstractVector{NTuple{4, Cuint}}) where T
    tid = GPUArrays.threadidx_x(state)
    newstate, value = next_rand(T, randstate[tid])
    randstate[tid] = newstate
    return value
end
33+
34+
# Per-device cache of RNG states (one NTuple{4, Cuint} per thread).
# The dict is kept private inside the `let`; only the two closures
# escape as globals.
global cached_state, clear_cache
let state_cache = Dict()
    # Drop all cached per-device RNG state.
    clear_cache() = (empty!(state_cache); return)
    # Fetch — or lazily create and seed from the host RNG — the
    # random-state buffer for the device `x` lives on.
    function cached_state(x)
        dev = GPUArrays.device(x)
        get!(state_cache, dev) do
            nthreads = GPUArrays.threads(dev)
            states = similar(x, NTuple{4, Cuint}, nthreads)
            copy!(states, [ntuple(i-> rand(Cuint), 4) for i=1:nthreads])
            states
        end
    end
end
47+
# Fill `A` in place with uniform random numbers in [0, 1), generated on
# the device using the cached per-thread RNG states.
function Base.rand!(A::GPUArray{T}) where T <: AbstractFloat
    rstates = cached_state(A)
    gpu_call(A, (rstates, A,)) do state, randstates, a
        idx = linear_index(state)
        # The launch may round the thread count up past length(a);
        # out-of-range threads write nothing.
        if idx <= length(a)
            a[idx] = gpu_rand(T, state, randstates)
        end
        return
    end
    A
end

0 commit comments

Comments
 (0)