Skip to content

Commit d1c7161

Browse files
committed
Merge branch 'sd/device_select' into sd/move_cl
2 parents 388bde0 + 493261f commit d1c7161

File tree

2 files changed

+67
-10
lines changed

2 files changed

+67
-10
lines changed

src/backends/backends.jl

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ function default_backend()
88
JLBackend
99
end
1010
end
11+
1112
let compute_contexts = Context[]
1213
function current_context()
1314
if isempty(compute_contexts)
@@ -45,34 +46,78 @@ end
4546
function free(x::AbstractArray)
4647

4748
end
49+
#=
50+
Functions to select contexts
51+
=#
52+
53+
is_gpu(ctx) = false
54+
is_cpu(ctx) = false
55+
is_opencl(ctx) = false
56+
is_cudanative(ctx) = false
57+
is_julia(ctx) = false
58+
is_opengl(ctx) = false
59+
has_atleast(ctx, attribute, value) = error("has_atleast not implemented yet")
4860

4961
# BLAS support
5062
hasblas(x) = false
5163
include("blas.jl")
5264
include("supported_backends.jl")
5365
include("shared.jl")
5466

55-
function init(sym::Symbol, args...; kw_args...)
56-
if sym == :julia
57-
JLBackend.init(args...; kw_args...)
58-
elseif sym == :cudanative
59-
CUBackend.init(args...; kw_args...)
60-
elseif sym == :opencl
61-
CLBackend.init(args...; kw_args...)
62-
elseif sym == :opengl
63-
GLBackend.init(args...; kw_args...)
67+
function to_backend_module(backend::Symbol)
68+
if backend in supported_backends()
69+
if sym == :julia
70+
JLBackend
71+
elseif sym == :cudanative
72+
CUBackend
73+
elseif sym == :opencl
74+
CLBackend
75+
elseif sym == :opengl
76+
GLBackend
77+
end
6478
else
6579
error("$sym not a supported backend. Try one of: $(supported_backends())")
6680
end
6781
end
82+
function init(sym::Symbol, args...; kw_args...)
83+
backend_module(sym).init(args...; kw_args...)
84+
end
85+
function init(filterfuncs::Function...; kw_args...)
86+
init_from_device(first(devices(filterfuncs...)))
87+
end
88+
backend_modules() = to_backend_module.(supported_backends())
89+
90+
6891

6992

93+
function devices(filter_funcs...)
94+
result = []
95+
for Module in backend_modules()
96+
for device in Module.devices()
97+
if all(f-> f(device), filter_funcs)
98+
push!(result, device)
99+
end
100+
end
101+
end
102+
result
103+
end
104+
70105
"""
71106
Iterates through all backends and calls `f` after initializing the current one!
72107
"""
73108
function perbackend(f)
74109
for backend in supported_backends()
75110
ctx = GPUArrays.init(backend)
76-
f(backend)
111+
f(ctx)
112+
end
113+
end
114+
115+
"""
116+
Iterates through all available devices and calls `f` after initializing the current one!
117+
"""
118+
function forall_devices(f, filterfuncs...)
119+
for device in devices(filterfunc)
120+
make_current(device)
121+
f(device)
77122
end
78123
end

src/backends/opencl/opencl.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@ import GPUArrays: Context, GPUArray, context, linear_index, free
1111
import GPUArrays: blasbuffer, blas_module, is_blas_supported, is_fft_supported
1212
import GPUArrays: synchronize, hasblas, LocalMemory, AccMatrix, AccVector, gpu_call
1313
import GPUArrays: default_buffer_type, broadcast_index, unsafe_reinterpret
14+
import GPUArrays: is_opencl, is_gpu, is_cpu
1415

1516
using Transpiler
1617
import Transpiler: cli, cli.get_global_id
1718

19+
20+
1821
immutable CLContext <: Context
1922
device::cl.Device
2023
context::cl.Context
@@ -51,6 +54,15 @@ function Base.show(io::IO, ctx::CLContext)
5154
print(io, "CLContext: $name")
5255
end
5356

57+
58+
function devices()
59+
cl.devices()
60+
end
61+
is_opencl(ctx::CLContext) = true
62+
is_gpu(ctx::CLContext) = cl.info(ctx.device, :device_type) == :gpu
63+
is_cpu(ctx::CLContext) = cl.info(ctx.device, :device_type) == :cpu
64+
65+
5466
global init, all_contexts, current_context
5567
let contexts = CLContext[]
5668
all_contexts() = copy(contexts)::Vector{CLContext}

0 commit comments

Comments
 (0)