Skip to content
This repository was archived by the owner on Sep 27, 2021. It is now read-only.

Commit 881efc9

Browse files
committed
fixes for CuArray changes
1 parent 0da97eb commit 881efc9

File tree

6 files changed

+83
-58
lines changed

6 files changed

+83
-58
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ sudo: false
33

44
os:
55
- linux
6-
- osx
6+
#- osx
77

88
julia:
99
- 0.6

src/CLArrays.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ include("array.jl")
1414
include("ondevice.jl")
1515
include("device.jl")
1616
include("context.jl")
17-
include("intrinsics.jl")
1817
include("compilation.jl")
1918
include("3rdparty.jl")
2019

src/compilation.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
11
using OpenCL: cl
22
using Transpiler: CLMethod, EmptyStruct
33
using Sugar
4-
import GPUArrays: gpu_call, linear_index
4+
import GPUArrays: _gpu_call, linear_index
55
using Transpiler: CLMethod
66
using Sugar: method_nargs, getslots!, isintrinsic, getcodeinfo!, sugared
77
using Sugar: returntype, type_ast, rewrite_ast, newslot!, to_tuple
88
using Sugar: isfunction
99

1010
using Base: tail
1111

12-
function gpu_call(f, A::CLArray, args::Tuple, blocks = nothing, thread = C_NULL)
12+
13+
14+
function _gpu_call(f, A::CLArray, args::Tuple, blocks_threads::Tuple{T, T}) where T <: NTuple{N, Integer} where N
1315
ctx = context(A)
1416
_args = (KernelState(), args...) # CLArrays "state"
1517
clfunc = CLFunction(f, _args, ctx)
16-
if blocks == nothing
17-
blocks, thread = thread_blocks_heuristic(length(A))
18-
elseif isa(blocks, Integer)
19-
blocks = (blocks,)
20-
end
21-
clfunc(_args, blocks, thread)
18+
blocks, threads = blocks_threads
19+
global_size = blocks .* threads
20+
clfunc(_args, global_size, threads)
2221
end
2322

2423

src/intrinsics.jl

Lines changed: 0 additions & 36 deletions
This file was deleted.

src/ondevice.jl

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
1-
import Base: setindex!, getindex, size, IndexStyle, next, done, start, sum, eltype
1+
import Base: setindex!, getindex, size, IndexStyle, sum, eltype
22
using Base: IndexLinear
3+
34
using Transpiler.cli: LocalPointer
5+
using Transpiler.cli: get_local_id, get_global_id, barrier, CLK_LOCAL_MEM_FENCE
6+
using Transpiler.cli: get_local_size, get_global_size, get_group_id, get_num_groups
7+
import GPUArrays: synchronize, synchronize_threads, device, global_size, linear_index
8+
49
import GPUArrays: LocalMemory
10+
using GPUArrays: AbstractDeviceArray
511

612

713
"""
814
Array type on the device
915
"""
10-
struct DeviceArray{T, N, Ptr} <: AbstractArray{T, N}
16+
struct DeviceArray{T, N, Ptr} <: AbstractDeviceArray{T, N}
1117
ptr::Ptr
1218
size::NTuple{N, Cuint}
1319
end
@@ -27,11 +33,7 @@ const LocalArray{T, N} = DeviceArray{T, N, LocalPointer{T}}
2733

2834
const OnDeviceArray{T, N} = Union{GlobalArray{T, N}, LocalArray{T, N}} # Variant on the device containing the correct pointer
2935

30-
size(x::OnDeviceArray) = x.size
31-
IndexStyle(::OnDeviceArray) = IndexLinear()
32-
start(x::OnDeviceArray) = Cuint(1)
33-
next(x::OnDeviceArray, state::Cuint) = x[state], state + Cuint(1)
34-
done(x::OnDeviceArray, state::Cuint) = state > length(x)
36+
# Dimensions of an on-device array, read directly from its `size` field (an `NTuple{N, Cuint}`).
size(x::DeviceArray) = x.size
3537

3638
# Linear indexing: forwards the index unchanged to the underlying pointer (`x.ptr[ilin]`).
getindex(x::OnDeviceArray, ilin::Integer) = x.ptr[ilin]
3739
function getindex(x::OnDeviceArray{T, N}, i::Vararg{Integer, N}) where {T, N}
@@ -85,10 +87,44 @@ device_type(x::T) where T <: Tuple = Tuple{device_type.(x)...}
8587
end
8688

8789

88-
function sum(A::CLArrays.DeviceArray{T}) where T
89-
acc = zero(T)
90-
for elem in A
91-
acc += elem
90+
#synchronize
91+
function synchronize(x::CLArray)
92+
cl.finish(global_queue(x)) # TODO figure out the diverse ways of synchronization
93+
end
94+
95+
96+
"""
    KernelState()

State object that `_gpu_call` prepends to the argument tuple of every kernel and
that the GPUArrays intrinsics in this file dispatch on.

The single `Int32` field is always initialised to `0`; it is presumably only a
placeholder so that the struct is not empty (TODO confirm).
"""
struct KernelState
    empty::Int32
    # Only a zero-argument constructor is provided; the field is always 0.
    KernelState() = new(Int32(0))
end
100+
101+
for (i, sym) in enumerate((:x, :y, :z))
102+
for (f, fcl, isidx) in (
103+
(:blockidx, get_group_id, true),
104+
(:blockdim, get_local_size, false),
105+
(:threadidx, get_local_id, true),
106+
(:griddim, get_num_groups, false)
107+
)
108+
109+
fname = Symbol(string(f, '_', sym))
110+
if isidx
111+
@eval GPUArrays.$fname(::KernelState)::Cuint = $fcl($(i-1)) + Cuint(1)
112+
else
113+
@eval GPUArrays.$fname(::KernelState)::Cuint = $fcl($(i-1))
114+
end
92115
end
93-
acc
116+
end
117+
118+
# Result of `get_global_size(0)`; the `state` argument is accepted but not used in the body.
global_size(state::KernelState) = get_global_size(0)
119+
# `get_global_id(0) + Cuint(1)`: the +1 presumably maps a 0-based OpenCL id to a
# 1-based Julia index — TODO confirm. `state` is not used in the body.
linear_index(state::KernelState) = get_global_id(0) + Cuint(1)
120+
121+
# Issues a barrier with the `CLK_LOCAL_MEM_FENCE` flag.
# NOTE(review): `barrier` is imported by name from `Transpiler.cli` earlier in this file,
# but `cli` itself is not bound in the visible part of the file — confirm it is in scope
# elsewhere, or call `barrier` directly.
synchronize_threads(::KernelState) = cli.barrier(CLK_LOCAL_MEM_FENCE)
122+
123+
# Returns a fresh `Transpiler.cli.LocalPointer{T}()`; `state`, `N` and `C` are not used in the body.
LocalMemory(state::KernelState, ::Type{T}, ::Val{N}, ::Val{C}) where {T, N, C} = Transpiler.cli.LocalPointer{T}()
124+
125+
function (::Type{AbstractDeviceArray})(ptr::PtrT, shape::Vararg{Integer, N}) where PtrT <: Transpiler.cli.LocalPointer{T} where {T, N}
126+
DeviceArray{T, N, PtrT}(ptr, shape)
127+
end
128+
function (::Type{AbstractDeviceArray})(ptr::PtrT, shape::NTuple{N, Integer}) where PtrT <: Transpiler.cli.LocalPointer{T} where {T, N}
129+
DeviceArray{T, N, PtrT}(ptr, shape)
94130
end

test/runtests.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,32 @@ for dev in CLArrays.devices()
2020
CLArrays.init(dev)
2121
@testset "Device: $dev" begin
2222
TestSuite.run_tests(CLArray)
23+
24+
@testset "muladd & abs" begin
25+
a = rand(Float32, 32) - 0.5f0
26+
A = CLArray(a)
27+
x = abs.(A)
28+
@test Array(x) == abs.(a)
29+
y = muladd.(A, 2f0, x)
30+
@test Array(y) == muladd(a, 2f0, abs.(a))
31+
###########
32+
# issue #20
33+
34+
against_base(a-> abs.(a), CLArray{Float32}, (10, 10))
35+
end
2336
end
2437
end
38+
39+
40+
# The above is equal to:
41+
# Typ = CLArray
42+
# GPUArrays.allowslow(false)
43+
# TestSuite.run_gpuinterface(Typ)
44+
# TestSuite.run_base(Typ)
45+
# TestSuite.run_blas(Typ)
46+
# TestSuite.run_broadcasting(Typ)
47+
# TestSuite.run_construction(Typ)
48+
# TestSuite.run_fft(Typ)
49+
# TestSuite.run_linalg(Typ)
50+
# TestSuite.run_mapreduce(Typ)
51+
# TestSuite.run_indexing(Typ)

0 commit comments

Comments
 (0)