Skip to content

Commit 493261f

Browse files
committed
start working on better device select
1 parent 4fa2706 commit 493261f

File tree

2 files changed

+67
-10
lines changed

2 files changed

+67
-10
lines changed

src/backends/backends.jl

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ function default_backend()
88
JLBackend
99
end
1010
end
11+
1112
let compute_contexts = Context[]
1213
function current_context()
1314
if isempty(compute_contexts)
@@ -45,34 +46,78 @@ end
4546
function free(x::AbstractArray)
4647

4748
end
49+
#=
50+
Functions to select contexts
51+
=#
52+
53+
is_gpu(ctx) = false
54+
is_cpu(ctx) = false
55+
is_opencl(ctx) = false
56+
is_cudanative(ctx) = false
57+
is_julia(ctx) = false
58+
is_opengl(ctx) = false
59+
has_atleast(ctx, attribute, value) = error("has_atleast not implemented yet")
4860

4961
# BLAS support
5062
hasblas(x) = false
5163
include("blas.jl")
5264
include("supported_backends.jl")
5365
include("shared.jl")
5466

55-
function init(sym::Symbol, args...; kw_args...)
56-
if sym == :julia
57-
JLBackend.init(args...; kw_args...)
58-
elseif sym == :cudanative
59-
CUBackend.init(args...; kw_args...)
60-
elseif sym == :opencl
61-
CLBackend.init(args...; kw_args...)
62-
elseif sym == :opengl
63-
GLBackend.init(args...; kw_args...)
67+
function to_backend_module(backend::Symbol)
68+
if backend in supported_backends()
69+
if sym == :julia
70+
JLBackend
71+
elseif sym == :cudanative
72+
CUBackend
73+
elseif sym == :opencl
74+
CLBackend
75+
elseif sym == :opengl
76+
GLBackend
77+
end
6478
else
6579
error("$sym not a supported backend. Try one of: $(supported_backends())")
6680
end
6781
end
82+
function init(sym::Symbol, args...; kw_args...)
83+
backend_module(sym).init(args...; kw_args...)
84+
end
85+
function init(filterfuncs::Function...; kw_args...)
86+
init_from_device(first(devices(filterfuncs...)))
87+
end
88+
backend_modules() = to_backend_module.(supported_backends())
89+
90+
6891

6992

93+
function devices(filter_funcs...)
94+
result = []
95+
for Module in backend_modules()
96+
for device in Module.devices()
97+
if all(f-> f(device), filter_funcs)
98+
push!(result, device)
99+
end
100+
end
101+
end
102+
result
103+
end
104+
70105
"""
71106
Iterates through all backends and calls `f` after initializing the current one!
72107
"""
73108
function perbackend(f)
74109
for backend in supported_backends()
75110
ctx = GPUArrays.init(backend)
76-
f(backend)
111+
f(ctx)
112+
end
113+
end
114+
115+
"""
116+
Iterates through all available devices and calls `f` after initializing the current one!
117+
"""
118+
function forall_devices(f, filterfuncs...)
119+
for device in devices(filterfunc)
120+
make_current(device)
121+
f(device)
77122
end
78123
end

src/backends/opencl/opencl.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@ import GPUArrays: Context, GPUArray, context, linear_index, free
1212
import GPUArrays: blasbuffer, blas_module, is_blas_supported, is_fft_supported
1313
import GPUArrays: synchronize, hasblas, LocalMemory, AccMatrix, AccVector, gpu_call
1414
import GPUArrays: default_buffer_type, broadcast_index, unsafe_reinterpret
15+
import GPUArrays: is_opencl, is_gpu, is_cpu
1516

1617
using Transpiler
1718
import Transpiler: cli, CLFunction, cli.get_global_id
1819

20+
21+
1922
immutable CLContext <: Context
2023
device::cl.Device
2124
context::cl.Context
@@ -52,6 +55,15 @@ function Base.show(io::IO, ctx::CLContext)
5255
print(io, "CLContext: $name")
5356
end
5457

58+
59+
function devices()
60+
cl.devices()
61+
end
62+
is_opencl(ctx::CLContext) = true
63+
is_gpu(ctx::CLContext) = cl.info(ctx.device, :device_type) == :gpu
64+
is_cpu(ctx::CLContext) = cl.info(ctx.device, :device_type) == :cpu
65+
66+
5567
global init, all_contexts, current_context
5668
let contexts = CLContext[]
5769
all_contexts() = copy(contexts)::Vector{CLContext}

0 commit comments

Comments
 (0)