Skip to content

Commit 560c634

Browse files
committed
sophisticated device selection
1 parent c16aefa commit 560c634

File tree

15 files changed

+186
-159
lines changed

15 files changed

+186
-159
lines changed

src/abstractarray.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ end
2020
#=
2121
Interface for accessing the lower level
2222
=#
23-
2423
buffer(A::AbstractAccArray) = A.buffer
2524
context(A::AbstractAccArray) = A.context
2625
default_buffer_type(typ, context) = error("Found unsupported context: $context")
@@ -66,7 +65,6 @@ function Base.similar{N, ET}(x::AbstractAccArray, ::Type{ET}, sz::NTuple{N, Int}
6665
end
6766

6867

69-
using Compat.TypeUtils
7068
function Base.similar{T <: GPUArray, ET, N}(
7169
::Type{T}, ::Type{ET}, sz::NTuple{N, Int};
7270
context::Context = current_context(), kw_args...
@@ -77,9 +75,6 @@ end
7775

7876

7977

80-
81-
82-
8378
#=
8479
Host to Device data transfers
8580
=#

src/backends/backends.jl

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,22 +56,42 @@ end
5656

5757
################################
5858
# Device selection functions for e.g. devices(filterfuncs)
59+
is_gpu(ctx::Context) = is_gpu(ctx.device)
60+
is_cpu(ctx::Context) = is_cpu(ctx.device)
61+
has_atleast(ctx::Context, attribute, value) = has_atleast(ctx.device, attribute, value)
62+
5963
is_gpu(device) = false
6064
is_cpu(device) = false
6165
has_atleast(device, attribute, value) = attribute(ctx_or_device) >= value
6266

67+
68+
#################################
69+
# Context filter functions
70+
# Works for context objects as well but is overloaded in the backends
71+
is_opencl(ctx::Symbol) = ctx == :opencl
72+
is_cudanative(ctx::Symbol) = ctx == :cudanative
73+
is_julia(ctx::Symbol) = ctx == :julia
74+
is_opengl(ctx::Symbol) = ctx == :opengl
75+
76+
is_opencl(ctx) = false
77+
is_cudanative(ctx) = false
78+
is_julia(ctx) = false
79+
is_opengl(ctx) = false
80+
81+
6382
"""
6483
Creates a new context from `device` without caching the resulting context.
6584
"""
6685
function new_context(device)
6786
error("Device $device not supported")
6887
end
6988

70-
# BLAS support
71-
hasblas(x) = false
72-
include("blas.jl")
73-
include("supported_backends.jl")
74-
include("shared.jl")
89+
"""
90+
Resets a context freeing all resources and creating a new context.
91+
"""
92+
function reset!(context)
93+
error("Context $context not supported")
94+
end
7595

7696
function backend_module(sym::Symbol)
7797
if sym in supported_backends()
@@ -89,17 +109,27 @@ function backend_module(sym::Symbol)
89109
end
90110
end
91111
function init(sym::Symbol, args...; kw_args...)
92-
backend_module(sym).init(args...; kw_args...)
112+
mod = backend_module(sym)
113+
setbackend!(mod)
114+
init(args...; kw_args...)
93115
end
94116

95117
function init(filterfuncs::Function...; kw_args...)
96118
devices = available_devices(filterfuncs...)
97119
if isempty(devices)
98120
error("No device found for: $(join(string.(filterfuncs), " "))")
99121
end
100-
current_backend().init(first(devices))
122+
init(first(devices))
101123
end
102124

125+
# BLAS support
126+
hasblas(x) = false
127+
include("blas.jl")
128+
include("supported_backends.jl")
129+
include("shared.jl")
130+
131+
132+
103133
active_backends() = backend_module.(supported_backends())
104134

105135
const global_current_backend = Ref{Module}(default_backend())
@@ -182,3 +212,6 @@ function forall_devices(f, filterfuncs...)
182212
f(ctx)
183213
end
184214
end
215+
216+
217+
export is_cudanative, is_julia, is_opencl

src/backends/cudanative/cudanative.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ..GPUArrays, CUDAnative, StaticArrays
44

55
import CUDAdrv, CUDArt #, CUFFT
66

7-
import GPUArrays: buffer, create_buffer, acc_mapreduce
7+
import GPUArrays: buffer, create_buffer, acc_mapreduce, is_cudanative
88
import GPUArrays: Context, GPUArray, context, linear_index, gpu_call
99
import GPUArrays: blas_module, blasbuffer, is_blas_supported, hasblas, init
1010
import GPUArrays: default_buffer_type, broadcast_index, is_fft_supported, unsafe_reinterpret
@@ -23,7 +23,7 @@ immutable CUContext <: Context
2323
ctx::CUDAdrv.CuContext
2424
device::CUDAdrv.CuDevice
2525
end
26-
26+
is_cudanative(ctx::CUContext) = true
2727
function Base.show(io::IO, ctx::CUContext)
2828
println(io, "CUDAnative context with:")
2929
device_summary(io, ctx.device)
@@ -52,7 +52,7 @@ const CUArray{T, N, B} = GPUArray{T, N, B, CUContext} #, GLArrayImg{T, N}}
5252
const CUArrayBuff{T, N} = CUArray{T, N, CUDAdrv.CuArray{T, N}}
5353

5454

55-
global init, all_contexts, current_context, current_device
55+
global all_contexts, current_context, current_device
5656

5757
let contexts = Dict{CUDAdrv.CuDevice, CUContext}(), active_device = CUDAdrv.CuDevice[]
5858

@@ -63,8 +63,14 @@ let contexts = Dict{CUDAdrv.CuDevice, CUContext}(), active_device = CUDAdrv.CuDe
6363
end
6464
active_device[]
6565
end
66-
current_context() = contexts[current_device()]
67-
function init(dev::CUDAdrv.CuDevice = current_device())
66+
function current_context()
67+
dev = current_device()
68+
get!(contexts, dev) do
69+
new_context(dev)
70+
end
71+
end
72+
73+
function GPUArrays.init(dev::CUDAdrv.CuDevice)
6874
if isempty(active_device)
6975
push!(active_device, dev)
7076
else
@@ -93,7 +99,7 @@ function reset!(context::CUContext)
9399
return
94100
end
95101

96-
function new_context(dev::CUDAdrv.CuDevice = current_device())
102+
function new_context(dev::CUDAdrv.CuDevice)
97103
cuctx = CUDAdrv.CuContext(dev)
98104
ctx = CUContext(cuctx, dev)
99105
CUDAdrv.activate(cuctx)
@@ -241,14 +247,14 @@ function (f::CUFunction{F}){F <: CUDAdrv.CuFunction, T, N}(A::CUArray{T, N}, arg
241247
)
242248
end
243249

244-
function gpu_call{T, N}(f::Function, A::CUArray{T, N}, args, globalsize = length(A), localsize = nothing)
250+
function gpu_call{T, N}(f::Function, A::CUArray{T, N}, args::Tuple, globalsize = length(A), localsize = nothing)
245251
blocks, thread = thread_blocks_heuristic(globalsize)
246252
args = map(unpack_cu_array, args)
247253
#cu_kernel, rewritten = CUDAnative.rewrite_for_cudanative(kernel, map(typeof, args))
248254
#println(CUDAnative.@code_typed kernel(args...))
249255
@cuda (blocks, thread) f(0f0, args...)
250256
end
251-
function gpu_call{T, N}(f::Tuple{String, Symbol}, A::CUArray{T, N}, args, globalsize = size(A), localsize = nothing)
257+
function gpu_call{T, N}(f::Tuple{String, Symbol}, A::CUArray{T, N}, args::Tuple, globalsize = size(A), localsize = nothing)
252258
func = CUFunction(A, f, args...)
253259
# TODO cache
254260
func(A, args) # TODO pass through local/global size

src/backends/julia/julia.jl

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,48 @@ import GPUArrays: AbstractAccArray, AbstractSampler, acc_mapreduce, gpu_call
88
import GPUArrays: hasblas, blas_module, blasbuffer, default_buffer_type
99
import GPUArrays: unsafe_reinterpret, broadcast_index, linear_index
1010
import GPUArrays: is_cpu, name, threads, blocks, global_memory
11+
import GPUArrays: new_context, init, free_global_memory
1112

1213
import Base.Threads: @threads
1314

1415
immutable JLContext <: Context
1516
nthreads::Int
1617
end
18+
# TODO,one could have multiple CPUs ?
19+
immutable JLDevice <: Context
20+
index::Int
21+
end
22+
1723

18-
global current_context, make_current, init
19-
let contexts = JLContext[]
20-
all_contexts() = copy(contexts)::Vector{JLContext}
21-
current_context() = last(contexts)::JLContext
22-
function init()
23-
ctx = JLContext(Base.Threads.nthreads())
24-
GPUArrays.make_current(ctx)
25-
push!(contexts, ctx)
24+
global all_contexts, current_context, current_device
25+
let contexts = Dict{JLDevice, JLContext}(), active_device = JLDevice[]
26+
all_contexts() = values(contexts)
27+
function current_device()
28+
if isempty(active_device)
29+
push!(active_device, JLDevice(0))
30+
end
31+
active_device[]
32+
end
33+
current_context() = contexts[current_device()]
34+
function GPUArrays.init(dev::JLDevice)
35+
if isempty(active_device)
36+
push!(active_device, dev)
37+
else
38+
active_device[] = dev
39+
end
40+
ctx = get!(()-> new_context(dev), contexts, dev)
2641
ctx
2742
end
2843
end
2944

30-
immutable JLDevice end
31-
45+
new_context(dev::JLDevice) = JLContext(Threads.nthreads())
3246
threads(x::JLDevice) = Base.Threads.nthreads()
3347
global_memory(x::JLDevice) = Sys.total_memory()
3448
free_global_memory(x::JLDevice) = Sys.free_memory()
35-
name(x::JLDevice) = Sys.cpu_info()[1].model # TODO,one could have multiple CPUs ?
49+
name(x::JLDevice) = Sys.cpu_info()[1].model
3650
is_cpu(::JLDevice) = true
3751

38-
devices() = (JLDevice(),)
52+
devices() = (JLDevice(0),)
3953

4054

4155
immutable Sampler{T, N, Buffer} <: AbstractSampler{T, N}
@@ -109,8 +123,8 @@ Base.@propagate_inbounds Base.setindex!{T, N}(A::JLArray{T, N}, val, i::Integer)
109123
Base.IndexStyle{T, N}(::Type{JLArray{T, N}}) = IndexLinear()
110124

111125
function Base.show(io::IO, ctx::JLContext)
112-
cpu = Sys.cpu_info()
113-
print(io, "JLContext $(cpu[1].model) with $(ctx.nthreads) threads")
126+
println("Threaded Julia Context with:")
127+
GPUArrays.device_summary(io, JLDevice(0))
114128
end
115129
##############################################
116130
# Implement BLAS interface

src/backends/opencl/opencl.jl

Lines changed: 53 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,52 +6,34 @@ using OpenCL: cl
66

77
using ..GPUArrays, StaticArrays
88

9-
import GPUArrays: buffer, create_buffer, acc_mapreduce, mapidx
10-
import GPUArrays: Context, GPUArray, context, linear_index, free
9+
import GPUArrays: buffer, create_buffer, acc_mapreduce, mapidx, is_opencl
10+
import GPUArrays: Context, GPUArray, context, linear_index, free, init
1111
import GPUArrays: blasbuffer, blas_module, is_blas_supported, is_fft_supported
1212
import GPUArrays: synchronize, hasblas, LocalMemory, AccMatrix, AccVector, gpu_call
13-
import GPUArrays: default_buffer_type, broadcast_index, unsafe_reinterpret
13+
import GPUArrays: default_buffer_type, broadcast_index, unsafe_reinterpret, reset!
1414
import GPUArrays: is_gpu, is_cpu, name, threads, blocks, global_memory, local_memory
15+
using GPUArrays: device_summary
1516

1617
using Transpiler
1718
import Transpiler: cli, cli.get_global_id
1819

1920

20-
21-
immutable CLContext <: Context
21+
type CLContext <: Context
2222
device::cl.Device
2323
context::cl.Context
2424
queue::cl.CmdQueue
25-
function CLContext(device_type = nothing)
26-
device = if device_type == nothing
27-
devlist = cl.devices(:gpu)
28-
dev = if isempty(devlist)
29-
devlist = cl.devices(:cpu)
30-
if isempty(devlist)
31-
error("no device found to be supporting opencl")
32-
else
33-
first(devlist)
34-
end
35-
else
36-
first(devlist)
37-
end
38-
dev
39-
else
40-
# if device type supplied by user, assume it's actually existant!
41-
devlist = cl.devices(device_type)
42-
if isempty(devlist)
43-
error("Can't find OpenCL device for $device_type")
44-
end
45-
first(devlist)
46-
end
25+
function CLContext(device::cl.Device)
4726
ctx = cl.Context(device)
4827
queue = cl.CmdQueue(ctx)
4928
new(device, ctx, queue)
5029
end
5130
end
31+
32+
is_opencl(ctx::CLContext) = true
33+
5234
function Base.show(io::IO, ctx::CLContext)
53-
name = replace(ctx.device[:name], r"\s+", " ")
54-
print(io, "CLContext: $name")
35+
println(io, "OpenCL context with:")
36+
device_summary(io, ctx.device)
5537
end
5638

5739

@@ -69,27 +51,52 @@ global_memory(dev::cl.Device) = cl.info(dev, :global_mem_size) |> Int
6951
local_memory(dev::cl.Device) = cl.info(dev, :local_mem_size) |> Int
7052

7153

72-
73-
global init, all_contexts, current_context
74-
let contexts = CLContext[]
75-
all_contexts() = copy(contexts)::Vector{CLContext}
76-
current_context() = last(contexts)::CLContext
77-
function init(;device_type = nothing, ctx = nothing)
78-
context = if ctx == nothing
79-
if isempty(contexts)
80-
CLContext(device_type)
81-
else
82-
current_context()
83-
end
54+
global all_contexts, current_context, current_device
55+
let contexts = Dict{cl.Device, CLContext}(), active_device = cl.Device[]
56+
all_contexts() = values(contexts)
57+
function current_device()
58+
if isempty(active_device)
59+
push!(active_device, CUDAnative.default_device[])
60+
end
61+
active_device[]
62+
end
63+
function current_context()
64+
dev = current_device()
65+
get!(contexts, dev) do
66+
new_context(dev)
67+
end
68+
end
69+
function GPUArrays.init(dev::cl.Device)
70+
if isempty(active_device)
71+
push!(active_device, dev)
8472
else
85-
ctx
73+
active_device[] = dev
74+
end
75+
ctx = get!(()-> new_context(dev), contexts, dev)
76+
ctx
77+
end
78+
79+
function destroy!(context::CLContext)
80+
# don't destroy primary device context
81+
dev = context.device
82+
if haskey(contexts, dev) && contexts[dev] == context
83+
error("Trying to destroy primary device context which is prohibited. Please use reset!(context)")
8684
end
87-
GPUArrays.make_current(context)
88-
push!(contexts, context)
89-
context
85+
finalize(context.ctx)
86+
return
9087
end
9188
end
9289

90+
function reset!(context::CLContext)
91+
device = context.device
92+
finalize(context.context)
93+
context.context = cl.Context(device)
94+
context.queue = cl.CmdQueue(context.context)
95+
return
96+
end
97+
98+
new_context(dev::cl.Device) = CLContext(dev)
99+
93100
const CLArray{T, N} = GPUArray{T, N, B, CLContext} where B <: cl.Buffer
94101

95102
include("compilation.jl")
@@ -233,7 +240,7 @@ function thread_blocks_heuristic(len::Integer)
233240
end
234241

235242

236-
function gpu_call{T, N}(f, A::CLArray{T, N}, args, globalsize = length(A), localsize = nothing)
243+
function gpu_call{T, N}(f, A::CLArray{T, N}, args::Tuple, globalsize = length(A), localsize = nothing)
237244
ctx = GPUArrays.context(A)
238245
_args = if !isa(f, Tuple{String, Symbol})
239246
(0f0, args...)# include "state"

0 commit comments

Comments
 (0)