Skip to content

Commit c16aefa

Browse files
committed
clean up device handling
1 parent d1c7161 commit c16aefa

File tree

5 files changed

+193
-63
lines changed

5 files changed

+193
-63
lines changed

src/backends/backends.jl

Lines changed: 97 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
global current_context, make_current
2+
23
function default_backend()
34
if is_backend_supported(:cudanative)
45
CUBackend
@@ -9,22 +10,7 @@ function default_backend()
910
end
1011
end
1112

12-
let compute_contexts = Context[]
13-
function current_context()
14-
if isempty(compute_contexts)
15-
default_backend().init()
16-
end
17-
last(compute_contexts)
18-
end
19-
all_contexts() = copy(compute_contexts)
20-
function make_current(ctx)
21-
idx = findfirst(compute_contexts, ctx)
22-
if idx != 0
23-
splice!(compute_contexts, idx) # remove
24-
end
25-
push!(compute_contexts, ctx)
26-
end
27-
end
13+
2814
#interface
2915
function create_buffer(ctx, array) end
3016
"""
@@ -34,38 +20,61 @@ function synchronize(A::AbstractArray)
3420
# fallback is a noop, for backends not needing synchronization. This
3521
# makes it easier to write generic code that also works for AbstractArrays
3622
end
23+
3724
"""
3825
`A` must be a gpu Array and will help to dispatch to the correct GPU backend
3926
and can supply queues and contexts.
4027
Calls `f` on args on the GPU, falls back to a normal call if there is no backend.
4128
"""
42-
function gpu_call(A::AbstractArray, f, args, worksize, localsize = nothing)
29+
function gpu_call(f, A::AbstractArray, args::Tuple, worksize = length(A), localsize = nothing)
4330
f(args...)
4431
end
4532

46-
function free(x::AbstractArray)
33+
free(x::AbstractArray) = nothing
4734

48-
end
4935
#=
5036
Functions to select contexts
5137
=#
5238

53-
is_gpu(ctx) = false
54-
is_cpu(ctx) = false
55-
is_opencl(ctx) = false
56-
is_cudanative(ctx) = false
57-
is_julia(ctx) = false
58-
is_opengl(ctx) = false
59-
has_atleast(ctx, attribute, value) = error("has_atleast not implemented yet")
39+
threads(device) = 0
40+
blocks(device) = 0
41+
global_memory(device) = 0
42+
free_global_memory(device) = NaN
43+
local_memory(device) = 0
44+
name(device) = "Undefined"
45+
46+
function device_summary(io::IO, device)
47+
println(io, "Device: ", name(device))
48+
for (n, f) in (:threads => threads, :blocks => blocks)
49+
@printf(io, "%19s: %s\n", string(n), string(f(device)))
50+
end
51+
for (n, f) in (:global_memory => global_memory, :free_global_memory => free_global_memory, :local_memory => local_memory)
52+
@printf(io, "%19s: %f mb\n", string(n), f(device) / 10^6)
53+
end
54+
return
55+
end
56+
57+
################################
58+
# Device selection functions for e.g. devices(filterfuncs)
59+
is_gpu(device) = false
60+
is_cpu(device) = false
61+
has_atleast(device, attribute, value) = attribute(ctx_or_device) >= value
62+
63+
"""
64+
Creates a new context from `device` without caching the resulting context.
65+
"""
66+
function new_context(device)
67+
error("Device $device not supported")
68+
end
6069

6170
# BLAS support
6271
hasblas(x) = false
6372
include("blas.jl")
6473
include("supported_backends.jl")
6574
include("shared.jl")
6675

67-
function to_backend_module(backend::Symbol)
68-
if backend in supported_backends()
76+
function backend_module(sym::Symbol)
77+
if sym in supported_backends()
6978
if sym == :julia
7079
JLBackend
7180
elseif sym == :cudanative
@@ -82,17 +91,69 @@ end
8291
function init(sym::Symbol, args...; kw_args...)
8392
backend_module(sym).init(args...; kw_args...)
8493
end
94+
8595
function init(filterfuncs::Function...; kw_args...)
86-
init_from_device(first(devices(filterfuncs...)))
96+
devices = available_devices(filterfuncs...)
97+
if isempty(devices)
98+
error("No device found for: $(join(string.(filterfuncs), " "))")
99+
end
100+
current_backend().init(first(devices))
101+
end
102+
103+
active_backends() = backend_module.(supported_backends())
104+
105+
const global_current_backend = Ref{Module}(default_backend())
106+
107+
current_backend() = global_current_backend[]
108+
current_device() = current_backend().current_device()
109+
current_context() = current_backend().current_context()
110+
111+
"""
112+
Sets the current backend to be used globally. Accepts the symbols:
113+
:cudanative, :opencl, :julia.
114+
"""
115+
function setbackend!(backend::Symbol)
116+
setbackend!(backend_module(backend))
87117
end
88-
backend_modules() = to_backend_module.(supported_backends())
89118

119+
function setbackend!(backend::Module)
120+
global_current_backend[] = backend
121+
return
122+
end
90123

124+
"""
125+
Creates a temporary context for `device` and executes `f(context)` while this context is active.
126+
Context gets destroyed afterwards. Note, that creating a temporary context is expensive.
127+
"""
128+
function on_device(f, device = current_device())
129+
ctx = new_context(device)
130+
f(ctx)
131+
destroy!(ctx)
132+
return
133+
end
134+
135+
"""
136+
Returns all devices for the current backend.
137+
Can be filtered by passing `filter_funcs`, e.g. `is_gpu`, `is_cpu`, `(dev)-> has_atleast(dev, threads, 512)`
138+
"""
139+
function available_devices(filter_funcs...)
140+
result = []
141+
for device in current_backend().devices()
142+
if all(f-> f(device), filter_funcs)
143+
push!(result, device)
144+
end
145+
end
146+
result
147+
end
91148

92149

93-
function devices(filter_funcs...)
150+
"""
151+
Returns all devices from `backends = active_backends()`.
152+
Can be filtered by passing `filter_funcs`, e.g. `is_gpu`, `is_cpu`, `dev-> has_atleast(dev, threads, 512)`
153+
"""
154+
function all_devices(filter_funcs...; backends = active_backends())
94155
result = []
95-
for Module in backend_modules()
156+
for Module in backends
96157
for device in Module.devices()
97158
if all(f-> f(device), filter_funcs)
98159
push!(result, device)
@@ -113,11 +174,11 @@ function perbackend(f)
113174
end
114175

115176
"""
116-
Iterates through all available devices and calls `f` after initializing the current one!
177+
Iterates through all available devices and calls `f(context)` after initializing the standard context for that device.
117178
"""
118179
function forall_devices(f, filterfuncs...)
119-
for device in devices(filterfunc)
120-
make_current(device)
121-
f(device)
180+
for device in all_devices(filterfunc...)
181+
ctx = init(device)
182+
f(ctx)
122183
end
123184
end

src/backends/cudanative/cudanative.jl

Lines changed: 68 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ import CUDAdrv, CUDArt #, CUFFT
66

77
import GPUArrays: buffer, create_buffer, acc_mapreduce
88
import GPUArrays: Context, GPUArray, context, linear_index, gpu_call
9-
import GPUArrays: blas_module, blasbuffer, is_blas_supported, hasblas
9+
import GPUArrays: blas_module, blasbuffer, is_blas_supported, hasblas, init
1010
import GPUArrays: default_buffer_type, broadcast_index, is_fft_supported, unsafe_reinterpret
11-
11+
import GPUArrays: is_gpu, name, threads, blocks, global_memory, local_memory, new_context
12+
using GPUArrays: device_summary
1213

1314
using CUDAdrv: CuDefaultStream
1415

@@ -23,36 +24,83 @@ immutable CUContext <: Context
2324
device::CUDAdrv.CuDevice
2425
end
2526

26-
Base.show(io::IO, ctx::CUContext) = print(io, "CUContext")
27+
function Base.show(io::IO, ctx::CUContext)
28+
println(io, "CUDAnative context with:")
29+
device_summary(io, ctx.device)
30+
end
31+
2732

28-
function any_context()
29-
dev = CUDAdrv.CuDevice(0)
30-
ctx = CUDAdrv.CuContext(dev)
31-
CUContext(ctx, dev)
33+
devices() = CUDAdrv.devices()
34+
is_gpu(dev::CUDAdrv.CuDevice) = true
35+
name(dev::CUDAdrv.CuDevice) = CUDAdrv.name(dev)
36+
threads(dev::CUDAdrv.CuDevice) = CUDAdrv.attribute(dev, CUDAdrv.MAX_THREADS_PER_BLOCK)
37+
38+
function blocks(dev::CUDAdrv.CuDevice)
39+
(
40+
CUDAdrv.attribute(dev, CUDAdrv.MAX_BLOCK_DIM_X),
41+
CUDAdrv.attribute(dev, CUDAdrv.MAX_BLOCK_DIM_Y),
42+
CUDAdrv.attribute(dev, CUDAdrv.MAX_BLOCK_DIM_Z),
43+
)
3244
end
3345

46+
global_memory(dev::CUDAdrv.CuDevice) = CUDAdrv.totalmem(dev)
47+
local_memory(dev::CUDAdrv.CuDevice) = CUDAdrv.attribute(dev, CUDAdrv.TOTAL_CONSTANT_MEMORY)
48+
49+
3450
#const GLArrayImg{T, N} = GPUArray{T, N, gl.Texture{T, N}, GLContext}
3551
const CUArray{T, N, B} = GPUArray{T, N, B, CUContext} #, GLArrayImg{T, N}}
3652
const CUArrayBuff{T, N} = CUArray{T, N, CUDAdrv.CuArray{T, N}}
3753

3854

39-
global init, all_contexts, current_context
40-
let contexts = CUContext[]
41-
all_contexts() = copy(contexts)::Vector{CUContext}
42-
current_context() = last(contexts)::CUContext
43-
function init(;ctx = nothing)
44-
ctx = if ctx == nothing
45-
if isempty(contexts)
46-
any_context()
47-
else
48-
current_context()
49-
end
55+
global init, all_contexts, current_context, current_device
56+
57+
let contexts = Dict{CUDAdrv.CuDevice, CUContext}(), active_device = CUDAdrv.CuDevice[]
58+
59+
all_contexts() = values(contexts)
60+
function current_device()
61+
if isempty(active_device)
62+
push!(active_device, CUDAnative.default_device[])
5063
end
51-
GPUArrays.make_current(ctx)
52-
push!(contexts, ctx)
64+
active_device[]
65+
end
66+
current_context() = contexts[current_device()]
67+
function init(dev::CUDAdrv.CuDevice = current_device())
68+
if isempty(active_device)
69+
push!(active_device, dev)
70+
else
71+
active_device[] = dev
72+
end
73+
ctx = get!(()-> new_context(dev), contexts, dev)
74+
CUDAdrv.activate(ctx.ctx)
5375
ctx
5476
end
77+
78+
function destroy!(context::CUContext)
79+
# don't destroy primary device context
80+
dev = context.device
81+
if haskey(contexts, dev) && contexts[dev] == context
82+
error("Trying to destroy primary device context which is prohibited. Please use reset!(context)")
83+
end
84+
CUDAdrv.destroy!(context.ctx)
85+
return
86+
end
5587
end
88+
89+
function reset!(context::CUContext)
90+
dev = context.device
91+
CUDAdrv.destroy!(context.ctx)
92+
context.ctx = CUDAdrv.CuContext(dev)
93+
return
94+
end
95+
96+
function new_context(dev::CUDAdrv.CuDevice = current_device())
97+
cuctx = CUDAdrv.CuContext(dev)
98+
ctx = CUContext(cuctx, dev)
99+
CUDAdrv.activate(cuctx)
100+
return ctx
101+
end
102+
103+
56104
# synchronize
57105
function GPUArrays.synchronize{T, N}(x::CUArray{T, N})
58106
CUDAdrv.synchronize(context(x).ctx) # TODO figure out the diverse ways of synchronization

src/backends/julia/julia.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import GPUArrays: buffer, create_buffer, Context, context, mapidx, unpack_buffer
77
import GPUArrays: AbstractAccArray, AbstractSampler, acc_mapreduce, gpu_call
88
import GPUArrays: hasblas, blas_module, blasbuffer, default_buffer_type
99
import GPUArrays: unsafe_reinterpret, broadcast_index, linear_index
10+
import GPUArrays: is_cpu, name, threads, blocks, global_memory
1011

1112
import Base.Threads: @threads
1213

@@ -26,6 +27,16 @@ let contexts = JLContext[]
2627
end
2728
end
2829

30+
immutable JLDevice end
31+
32+
threads(x::JLDevice) = Base.Threads.nthreads()
33+
global_memory(x::JLDevice) = Sys.total_memory()
34+
free_global_memory(x::JLDevice) = Sys.free_memory()
35+
name(x::JLDevice) = Sys.cpu_info()[1].model # TODO,one could have multiple CPUs ?
36+
is_cpu(::JLDevice) = true
37+
38+
devices() = (JLDevice(),)
39+
2940

3041
immutable Sampler{T, N, Buffer} <: AbstractSampler{T, N}
3142
buffer::Buffer

src/backends/opencl/opencl.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import GPUArrays: Context, GPUArray, context, linear_index, free
1111
import GPUArrays: blasbuffer, blas_module, is_blas_supported, is_fft_supported
1212
import GPUArrays: synchronize, hasblas, LocalMemory, AccMatrix, AccVector, gpu_call
1313
import GPUArrays: default_buffer_type, broadcast_index, unsafe_reinterpret
14-
import GPUArrays: is_opencl, is_gpu, is_cpu
14+
import GPUArrays: is_gpu, is_cpu, name, threads, blocks, global_memory, local_memory
1515

1616
using Transpiler
1717
import Transpiler: cli, cli.get_global_id
@@ -55,12 +55,19 @@ function Base.show(io::IO, ctx::CLContext)
5555
end
5656

5757

58-
function devices()
59-
cl.devices()
60-
end
61-
is_opencl(ctx::CLContext) = true
62-
is_gpu(ctx::CLContext) = cl.info(ctx.device, :device_type) == :gpu
63-
is_cpu(ctx::CLContext) = cl.info(ctx.device, :device_type) == :cpu
58+
devices() = cl.devices()
59+
60+
is_gpu(dev::cl.Device) = cl.info(dev, :device_type) == :gpu
61+
is_cpu(dev::cl.Device) = cl.info(dev, :device_type) == :cpu
62+
63+
name(dev::cl.Device) = cl.info(dev, :name)
64+
65+
threads(dev::cl.Device) = cl.info(dev, :max_work_group_size) |> Int
66+
blocks(dev::cl.Device) = cl.info(dev, :max_work_item_size)
67+
68+
global_memory(dev::cl.Device) = cl.info(dev, :global_mem_size) |> Int
69+
local_memory(dev::cl.Device) = cl.info(dev, :local_mem_size) |> Int
70+
6471

6572

6673
global init, all_contexts, current_context

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,6 @@ log_gpu_mem()
6464
include("fft.jl")
6565
end
6666
log_gpu_mem()
67+
68+
69+
using GPUArrays

0 commit comments

Comments
 (0)